Updates for discordpy 2

This commit is contained in:
Benjamin Morgan 2025-12-25 02:01:35 -07:00
parent 083d29bd8c
commit f8cfad3fba
2 changed files with 75 additions and 35 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
*.env
*.db

108
mstbot.py
View file

@ -1,11 +1,21 @@
# TODO: Containerize, include sql script to initialize empty database if one does not exist in volume
# TODO: CICD deploy over live instance via compose
# TODO: prevent sql injection
# TODO: annotate return types everywhere
# TODO: split cogs into separate files, add utils file, remove globals
# TODO: save invite link to README and show query port enable instructions
# TODO: https://discord.com/oauth2/authorize?client_id=911009295947165747&permissions=3088&response_type=code&redirect_uri=https%3A%2F%2Fwww.benrmorgan.com%2Fportfolio%2Fminecraft-server-tools-bot&integration_type=0&scope=messages.read+bot
import asyncio
import os
import re
import socket
import unicodedata
from datetime import datetime, timezone
from typing import Union, List
import mysql.connector
import sqlite3
import discord
from discord.ext import commands
from dotenv import load_dotenv
@ -15,7 +25,10 @@ from mcstatus import JavaServer
load_dotenv('.env')
# Bot Initialization
bot = commands.Bot(command_prefix='$')
intents = discord.Intents.default()
intents.messages = True
intents.message_content = True
bot = commands.Bot(command_prefix='$', intents=intents)
servers = []
@bot.event
@ -61,7 +74,7 @@ class Admin(commands.Cog):
@commands.command(brief="Sets up new Minecraft server querier",
description="Sets up new Minecraft server querier. Removes previous querier, if applicable. Use a valid domain name or IP address. The default port (25565) is used unless otherwise specified. Specify a channel ID to announce player joins there.")
@commands.has_permissions(administrator=True)
async def setup(self, ctx: discord.ext.commands.context.Context, ip: str, port: int = 25565, annChanID: Union[int, None] = None):
async def setup(self, ctx: discord.ext.commands.context.Context, ip: str, port: int = 25565, annChanID: Union[int, str, None] = None):
"""Sets up a new Minecraft server querier."""
val_ip, val_port = await validate_ip_port(ctx, ip, port)
if not val_ip:
@ -71,7 +84,7 @@ class Admin(commands.Cog):
try:
query = mc.query()
names = query.players.names
names = query.players.list
except asyncio.exceptions.TimeoutError:
await log(ctx, "Setup query error, query port enabled?")
except ConnectionRefusedError:
@ -79,11 +92,11 @@ class Admin(commands.Cog):
else:
mydb, cursor = connect()
cursor.execute(f"INSERT INTO servers (id, ip, port) VALUES({str(ctx.guild.id)}, \"{val_ip}\", {str(port)}) ON DUPLICATE KEY UPDATE ip=\"{val_ip}\", port={str(port)}")
cursor.execute(f"INSERT INTO servers (id, ip, port) VALUES({str(ctx.guild.id)}, \"{val_ip}\", {str(port)}) ON CONFLICT(id) DO UPDATE SET ip=\"{val_ip}\", port={str(port)}")
setMCNames(ctx.guild.id, cursor, names)
cursor.execute(f"INSERT INTO times (id) VALUES({ctx.guild.id}) ON DUPLICATE KEY UPDATE id=id;")
cursor.execute(f"INSERT INTO times (id) VALUES({ctx.guild.id}) ON CONFLICT(id) DO UPDATE SET id=id;")
mydb.commit()
mydb.close()
@ -108,7 +121,7 @@ class Admin(commands.Cog):
@commands.command(brief="Turns on player join announcements",
description="Turns on player join announcements in specified channel.")
@commands.has_permissions(administrator=True)
async def announce(self, ctx: discord.ext.commands.context.Context, chanid: int):
async def announce(self, ctx: discord.ext.commands.context.Context, chanid: Union[int, str]):
await setAnn(ctx, True, chanid)
@commands.command(brief="Turns off player join announcements", description="Turns off player join announcements.")
@ -174,7 +187,7 @@ class Other(commands.Cog):
minutes = seconds // 60
seconds %= 60
await safe_send("Status:\n " + ip + "\n " + motd + "\n\n Players: " + str(players) + "/" + str(max) + "\n Total Player Time: " + "%d:%02d:%02d" % (hour, minutes, seconds), ctx=ctx, format="```")
await safe_send("Status:\n " + ip + "\n " + motd.to_plain() + "\n\n Players: " + str(players) + "/" + str(max) + "\n Total Player Time: " + "%d:%02d:%02d" % (hour, minutes, seconds), ctx=ctx, format="```")
mydb.close()
@ -211,7 +224,17 @@ class Other(commands.Cog):
await safe_send("I last queried " + ip + " at " + str(last) + " UTC", ctx=ctx)
async def setAnn(ctx: discord.ext.commands.context.Context, ann: bool, cid: Union[int, None] = None):
async def setAnn(ctx: discord.ext.commands.context.Context, ann: bool, cid: Union[int, str, None] = None):
if cid is not None:
if type(cid) == str:
res = re.match(r"^<#(\d+)>$", cid)
if res is not None:
cid = int(res.group(1))
if type(cid) != int:
await log(ctx, f"Announcement channel ID was provided but was malformed.")
return
if cid is not None and find_channels(serv=ctx.guild, chanid=cid) is None:
await log(ctx, "Channel", str(cid), "does not exist.")
return
@ -245,7 +268,7 @@ async def setHours(ctx: discord.ext.commands.context.Context, hours: bool):
await log(ctx, (
"Not displaying total player hours for " + ip + ".",
"Displaying total player hours for " + ip + ".")[
hours])
hours] + " Please wait a moment for channel updates to take effect.")
async def log(ctx: commands.Context, *msg: str):
@ -270,16 +293,16 @@ async def validate_ip_port(ctx, ip, port):
await log(ctx, str(port), "is not a valid port number, please try again.")
return None, None
domain = re.search("^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,6}$", ip)
domain = re.search(r"^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,6}$", ip)
addr = re.search(
"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$",
r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$",
ip)
if domain is None or domain.group(0) != ip:
if addr is None or addr.group(0) != ip:
await log(ctx, ip, "is not a valid domain or IP address, please try again.")
return None, None
if re.search("(^127\.)|(^10\.)|(^172\.1[6-9]\.)|(^172\.2[0-9]\.)|(^172\.3[0-1]\.)|(^192\.168\.)",
if re.search(r"(^127\.)|(^10\.)|(^172\.1[6-9]\.)|(^172\.2[0-9]\.)|(^172\.3[0-1]\.)|(^192\.168\.)",
ip) is not None:
await log(ctx, ip, "is a private IP. I won't be able to query it.")
return None, None
@ -323,25 +346,20 @@ async def getStatus(serv: JavaServer):
status = await asyncio.wait_for(serv.async_query(), timeout=1.0)
p = status.players.online
m = status.players.max
n = status.players.names
n = status.players.list
d = status.motd
return p, m, n, d
def connect():
mydb = mysql.connector.connect(
host="localhost",
user=os.getenv("USER"),
password=os.getenv("PASS"),
database="status"
)
mydb = sqlite3.connect('mstbot.db')
cursor = mydb.cursor()
return mydb, cursor
def getMCIP(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
def getMCIP(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT ip, port FROM servers WHERE id=" + str(sid))
row = cursor.fetchone()
ip = row[0]
@ -350,32 +368,35 @@ def getMCIP(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
return ip, port
def getMCNames(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
def getMCNames(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT name FROM names WHERE id=" + str(sid))
rows = cursor.fetchall()
return [row[0] for row in rows]
def getMCQueryTime(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
def getMCQueryTime(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT last_query FROM servers WHERE id=" + str(sid))
tstr = cursor.fetchone()[0]
if tstr is None:
return tstr
return datetime.strptime(tstr, '%Y-%m-%d %H:%M:%S')
return tstr
def getPersonSeconds(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
def getPersonSeconds(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT time FROM times WHERE id=" + str(sid))
tstr = cursor.fetchone()[0]
return int(tstr)
def getShowHours(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
def getShowHours(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT hours FROM servers WHERE id=" + str(sid))
tstr = cursor.fetchone()[0]
return bool(tstr)
def getMCJoinAnnounce(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
def getMCJoinAnnounce(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT announce_joins, announce_joins_id FROM servers WHERE id=" + str(sid))
row = cursor.fetchone()
ann = row[0]
@ -384,7 +405,7 @@ def getMCJoinAnnounce(sid: int, cursor: mysql.connector.connection_cext.CMySQLCo
return ann, cid
def setMCNames(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection, names: List[str]):
def setMCNames(sid: int, cursor: sqlite3.Cursor, names: List[str]):
cursor.execute("DELETE FROM names WHERE id=" + str(sid))
if len(names) > 0:
@ -397,10 +418,10 @@ def setMCNames(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnectio
cursor.execute("INSERT INTO names (id, name) VALUES " + qStr)
def setMCQueryTime(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection, dt: datetime):
def setMCQueryTime(sid: int, cursor: sqlite3.Cursor, dt: datetime):
cursor.execute("UPDATE servers SET last_query=\"" + dt.strftime('%Y-%m-%d %H:%M:%S') + "\" WHERE id=" + str(sid))
def incrementPersonSeconds(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection, seconds: int):
def incrementPersonSeconds(sid: int, cursor: sqlite3.Cursor, seconds: int):
cursor.execute(f"UPDATE times SET time=time+{seconds} WHERE id={str(sid)};")
async def do_bot_cleanup(sid: int, ctx: Union[discord.ext.commands.context.Context, None] = None):
@ -441,6 +462,11 @@ async def do_bot_cleanup(sid: int, ctx: Union[discord.ext.commands.context.Conte
else:
await log(ctx, "Cleaned up! Removed", ip, "querier, deleted", ctx.guild.name + "'s data from my server,", ("but failed to remove my status channels. ", "and removed my status channels.")[deleted])
def ascii_clean(s):
return "".join(
c for c in unicodedata.normalize("NFKD", s)
if unicodedata.category(c)[0] != "C"
)
async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None,
chan: Union[discord.TextChannel, None] = None, format: str = ''):
@ -577,7 +603,8 @@ async def status_task(sid: int):
await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}",
overwrites=overwrites)
if getShowHours(sid, cursor) and not first_iter and lastSeconds != -1:
do_show_hours = getShowHours(sid, cursor)
if do_show_hours and not first_iter and lastSeconds != -1:
tStr = "Player Hrs: " + str(round((lastSeconds + len(names) * (currTime - lastTime).total_seconds())/3600))
if len(tChannels) > 0:
lastTName = tChannels[0].name
@ -595,6 +622,10 @@ async def status_task(sid: int):
else:
await currServ.create_voice_channel(tStr,
overwrites=overwrites)
elif not do_show_hours:
if len(tChannels) > 0:
for channel in tChannels:
await channel.delete()
mydb.commit()
mydb.close()
@ -602,6 +633,13 @@ async def status_task(sid: int):
await asyncio.sleep(wait)
first_iter = False
bot.add_cog(Admin())
bot.add_cog(Other(bot))
bot.run(os.getenv('TOKEN'))
@bot.event
async def setup_hook():
await bot.add_cog(Admin())
await bot.add_cog(Other(bot))
token = os.getenv('TOKEN')
if token is not None:
bot.run(token)
else:
print("Token not found, ensure .env file exists and is well-formed.")