Updates for discordpy 2

This commit is contained in:
Benjamin Morgan 2025-12-25 02:01:35 -07:00
parent 083d29bd8c
commit f8cfad3fba
2 changed files with 75 additions and 35 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
*.env
*.db

108
mstbot.py
View file

@ -1,11 +1,21 @@
# TODO: Containerize, include sql script to initialize empty database if one does not exist in volume
# TODO: CICD deploy over live instance via compose
# TODO: prevent sql injection
# TODO: annotate return types everywhere
# TODO: split cogs into separate files, add utils file, remove globals
# TODO: save invite link to README and show query port enable instructions
# TODO: https://discord.com/oauth2/authorize?client_id=911009295947165747&permissions=3088&response_type=code&redirect_uri=https%3A%2F%2Fwww.benrmorgan.com%2Fportfolio%2Fminecraft-server-tools-bot&integration_type=0&scope=messages.read+bot
import asyncio import asyncio
import os import os
import re import re
import socket import socket
import unicodedata
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Union, List from typing import Union, List
import mysql.connector import sqlite3
import discord import discord
from discord.ext import commands from discord.ext import commands
from dotenv import load_dotenv from dotenv import load_dotenv
@ -15,7 +25,10 @@ from mcstatus import JavaServer
load_dotenv('.env') load_dotenv('.env')
# Bot Initialization # Bot Initialization
bot = commands.Bot(command_prefix='$') intents = discord.Intents.default()
intents.messages = True
intents.message_content = True
bot = commands.Bot(command_prefix='$', intents=intents)
servers = [] servers = []
@bot.event @bot.event
@ -61,7 +74,7 @@ class Admin(commands.Cog):
@commands.command(brief="Sets up new Minecraft server querier", @commands.command(brief="Sets up new Minecraft server querier",
description="Sets up new Minecraft server querier. Removes previous querier, if applicable. Use a valid domain name or IP address. The default port (25565) is used unless otherwise specified. Specify a channel ID to announce player joins there.") description="Sets up new Minecraft server querier. Removes previous querier, if applicable. Use a valid domain name or IP address. The default port (25565) is used unless otherwise specified. Specify a channel ID to announce player joins there.")
@commands.has_permissions(administrator=True) @commands.has_permissions(administrator=True)
async def setup(self, ctx: discord.ext.commands.context.Context, ip: str, port: int = 25565, annChanID: Union[int, None] = None): async def setup(self, ctx: discord.ext.commands.context.Context, ip: str, port: int = 25565, annChanID: Union[int, str, None] = None):
"""Sets up a new Minecraft server querier.""" """Sets up a new Minecraft server querier."""
val_ip, val_port = await validate_ip_port(ctx, ip, port) val_ip, val_port = await validate_ip_port(ctx, ip, port)
if not val_ip: if not val_ip:
@ -71,7 +84,7 @@ class Admin(commands.Cog):
try: try:
query = mc.query() query = mc.query()
names = query.players.names names = query.players.list
except asyncio.exceptions.TimeoutError: except asyncio.exceptions.TimeoutError:
await log(ctx, "Setup query error, query port enabled?") await log(ctx, "Setup query error, query port enabled?")
except ConnectionRefusedError: except ConnectionRefusedError:
@ -79,11 +92,11 @@ class Admin(commands.Cog):
else: else:
mydb, cursor = connect() mydb, cursor = connect()
cursor.execute(f"INSERT INTO servers (id, ip, port) VALUES({str(ctx.guild.id)}, \"{val_ip}\", {str(port)}) ON DUPLICATE KEY UPDATE ip=\"{val_ip}\", port={str(port)}") cursor.execute(f"INSERT INTO servers (id, ip, port) VALUES({str(ctx.guild.id)}, \"{val_ip}\", {str(port)}) ON CONFLICT(id) DO UPDATE SET ip=\"{val_ip}\", port={str(port)}")
setMCNames(ctx.guild.id, cursor, names) setMCNames(ctx.guild.id, cursor, names)
cursor.execute(f"INSERT INTO times (id) VALUES({ctx.guild.id}) ON DUPLICATE KEY UPDATE id=id;") cursor.execute(f"INSERT INTO times (id) VALUES({ctx.guild.id}) ON CONFLICT(id) DO UPDATE SET id=id;")
mydb.commit() mydb.commit()
mydb.close() mydb.close()
@ -108,7 +121,7 @@ class Admin(commands.Cog):
@commands.command(brief="Turns on player join announcements", @commands.command(brief="Turns on player join announcements",
description="Turns on player join announcements in specified channel.") description="Turns on player join announcements in specified channel.")
@commands.has_permissions(administrator=True) @commands.has_permissions(administrator=True)
async def announce(self, ctx: discord.ext.commands.context.Context, chanid: int): async def announce(self, ctx: discord.ext.commands.context.Context, chanid: Union[int, str]):
await setAnn(ctx, True, chanid) await setAnn(ctx, True, chanid)
@commands.command(brief="Turns off player join announcements", description="Turns off player join announcements.") @commands.command(brief="Turns off player join announcements", description="Turns off player join announcements.")
@ -174,7 +187,7 @@ class Other(commands.Cog):
minutes = seconds // 60 minutes = seconds // 60
seconds %= 60 seconds %= 60
await safe_send("Status:\n " + ip + "\n " + motd + "\n\n Players: " + str(players) + "/" + str(max) + "\n Total Player Time: " + "%d:%02d:%02d" % (hour, minutes, seconds), ctx=ctx, format="```") await safe_send("Status:\n " + ip + "\n " + motd.to_plain() + "\n\n Players: " + str(players) + "/" + str(max) + "\n Total Player Time: " + "%d:%02d:%02d" % (hour, minutes, seconds), ctx=ctx, format="```")
mydb.close() mydb.close()
@ -211,7 +224,17 @@ class Other(commands.Cog):
await safe_send("I last queried " + ip + " at " + str(last) + " UTC", ctx=ctx) await safe_send("I last queried " + ip + " at " + str(last) + " UTC", ctx=ctx)
async def setAnn(ctx: discord.ext.commands.context.Context, ann: bool, cid: Union[int, None] = None): async def setAnn(ctx: discord.ext.commands.context.Context, ann: bool, cid: Union[int, str, None] = None):
if cid is not None:
if type(cid) == str:
res = re.match(r"^<#(\d+)>$", cid)
if res is not None:
cid = int(res.group(1))
if type(cid) != int:
await log(ctx, f"Announcement channel ID was provided but was malformed.")
return
if cid is not None and find_channels(serv=ctx.guild, chanid=cid) is None: if cid is not None and find_channels(serv=ctx.guild, chanid=cid) is None:
await log(ctx, "Channel", str(cid), "does not exist.") await log(ctx, "Channel", str(cid), "does not exist.")
return return
@ -245,7 +268,7 @@ async def setHours(ctx: discord.ext.commands.context.Context, hours: bool):
await log(ctx, ( await log(ctx, (
"Not displaying total player hours for " + ip + ".", "Not displaying total player hours for " + ip + ".",
"Displaying total player hours for " + ip + ".")[ "Displaying total player hours for " + ip + ".")[
hours]) hours] + " Please wait a moment for channel updates to take effect.")
async def log(ctx: commands.Context, *msg: str): async def log(ctx: commands.Context, *msg: str):
@ -270,16 +293,16 @@ async def validate_ip_port(ctx, ip, port):
await log(ctx, str(port), "is not a valid port number, please try again.") await log(ctx, str(port), "is not a valid port number, please try again.")
return None, None return None, None
domain = re.search("^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,6}$", ip) domain = re.search(r"^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,6}$", ip)
addr = re.search( addr = re.search(
"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$", r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$",
ip) ip)
if domain is None or domain.group(0) != ip: if domain is None or domain.group(0) != ip:
if addr is None or addr.group(0) != ip: if addr is None or addr.group(0) != ip:
await log(ctx, ip, "is not a valid domain or IP address, please try again.") await log(ctx, ip, "is not a valid domain or IP address, please try again.")
return None, None return None, None
if re.search("(^127\.)|(^10\.)|(^172\.1[6-9]\.)|(^172\.2[0-9]\.)|(^172\.3[0-1]\.)|(^192\.168\.)", if re.search(r"(^127\.)|(^10\.)|(^172\.1[6-9]\.)|(^172\.2[0-9]\.)|(^172\.3[0-1]\.)|(^192\.168\.)",
ip) is not None: ip) is not None:
await log(ctx, ip, "is a private IP. I won't be able to query it.") await log(ctx, ip, "is a private IP. I won't be able to query it.")
return None, None return None, None
@ -323,25 +346,20 @@ async def getStatus(serv: JavaServer):
status = await asyncio.wait_for(serv.async_query(), timeout=1.0) status = await asyncio.wait_for(serv.async_query(), timeout=1.0)
p = status.players.online p = status.players.online
m = status.players.max m = status.players.max
n = status.players.names n = status.players.list
d = status.motd d = status.motd
return p, m, n, d return p, m, n, d
def connect(): def connect():
mydb = mysql.connector.connect( mydb = sqlite3.connect('mstbot.db')
host="localhost",
user=os.getenv("USER"),
password=os.getenv("PASS"),
database="status"
)
cursor = mydb.cursor() cursor = mydb.cursor()
return mydb, cursor return mydb, cursor
def getMCIP(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection): def getMCIP(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT ip, port FROM servers WHERE id=" + str(sid)) cursor.execute("SELECT ip, port FROM servers WHERE id=" + str(sid))
row = cursor.fetchone() row = cursor.fetchone()
ip = row[0] ip = row[0]
@ -350,32 +368,35 @@ def getMCIP(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
return ip, port return ip, port
def getMCNames(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection): def getMCNames(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT name FROM names WHERE id=" + str(sid)) cursor.execute("SELECT name FROM names WHERE id=" + str(sid))
rows = cursor.fetchall() rows = cursor.fetchall()
return [row[0] for row in rows] return [row[0] for row in rows]
def getMCQueryTime(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection): def getMCQueryTime(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT last_query FROM servers WHERE id=" + str(sid)) cursor.execute("SELECT last_query FROM servers WHERE id=" + str(sid))
tstr = cursor.fetchone()[0] tstr = cursor.fetchone()[0]
if tstr is None:
return tstr
return datetime.strptime(tstr, '%Y-%m-%d %H:%M:%S')
return tstr def getPersonSeconds(sid: int, cursor: sqlite3.Cursor):
def getPersonSeconds(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection):
cursor.execute("SELECT time FROM times WHERE id=" + str(sid)) cursor.execute("SELECT time FROM times WHERE id=" + str(sid))
tstr = cursor.fetchone()[0] tstr = cursor.fetchone()[0]
return int(tstr) return int(tstr)
def getShowHours(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection): def getShowHours(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT hours FROM servers WHERE id=" + str(sid)) cursor.execute("SELECT hours FROM servers WHERE id=" + str(sid))
tstr = cursor.fetchone()[0] tstr = cursor.fetchone()[0]
return bool(tstr) return bool(tstr)
def getMCJoinAnnounce(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection): def getMCJoinAnnounce(sid: int, cursor: sqlite3.Cursor):
cursor.execute("SELECT announce_joins, announce_joins_id FROM servers WHERE id=" + str(sid)) cursor.execute("SELECT announce_joins, announce_joins_id FROM servers WHERE id=" + str(sid))
row = cursor.fetchone() row = cursor.fetchone()
ann = row[0] ann = row[0]
@ -384,7 +405,7 @@ def getMCJoinAnnounce(sid: int, cursor: mysql.connector.connection_cext.CMySQLCo
return ann, cid return ann, cid
def setMCNames(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection, names: List[str]): def setMCNames(sid: int, cursor: sqlite3.Cursor, names: List[str]):
cursor.execute("DELETE FROM names WHERE id=" + str(sid)) cursor.execute("DELETE FROM names WHERE id=" + str(sid))
if len(names) > 0: if len(names) > 0:
@ -397,10 +418,10 @@ def setMCNames(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnectio
cursor.execute("INSERT INTO names (id, name) VALUES " + qStr) cursor.execute("INSERT INTO names (id, name) VALUES " + qStr)
def setMCQueryTime(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection, dt: datetime): def setMCQueryTime(sid: int, cursor: sqlite3.Cursor, dt: datetime):
cursor.execute("UPDATE servers SET last_query=\"" + dt.strftime('%Y-%m-%d %H:%M:%S') + "\" WHERE id=" + str(sid)) cursor.execute("UPDATE servers SET last_query=\"" + dt.strftime('%Y-%m-%d %H:%M:%S') + "\" WHERE id=" + str(sid))
def incrementPersonSeconds(sid: int, cursor: mysql.connector.connection_cext.CMySQLConnection, seconds: int): def incrementPersonSeconds(sid: int, cursor: sqlite3.Cursor, seconds: int):
cursor.execute(f"UPDATE times SET time=time+{seconds} WHERE id={str(sid)};") cursor.execute(f"UPDATE times SET time=time+{seconds} WHERE id={str(sid)};")
async def do_bot_cleanup(sid: int, ctx: Union[discord.ext.commands.context.Context, None] = None): async def do_bot_cleanup(sid: int, ctx: Union[discord.ext.commands.context.Context, None] = None):
@ -441,6 +462,11 @@ async def do_bot_cleanup(sid: int, ctx: Union[discord.ext.commands.context.Conte
else: else:
await log(ctx, "Cleaned up! Removed", ip, "querier, deleted", ctx.guild.name + "'s data from my server,", ("but failed to remove my status channels. ", "and removed my status channels.")[deleted]) await log(ctx, "Cleaned up! Removed", ip, "querier, deleted", ctx.guild.name + "'s data from my server,", ("but failed to remove my status channels. ", "and removed my status channels.")[deleted])
def ascii_clean(s):
return "".join(
c for c in unicodedata.normalize("NFKD", s)
if unicodedata.category(c)[0] != "C"
)
async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None, async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None,
chan: Union[discord.TextChannel, None] = None, format: str = ''): chan: Union[discord.TextChannel, None] = None, format: str = ''):
@ -577,7 +603,8 @@ async def status_task(sid: int):
await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}", await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}",
overwrites=overwrites) overwrites=overwrites)
if getShowHours(sid, cursor) and not first_iter and lastSeconds != -1: do_show_hours = getShowHours(sid, cursor)
if do_show_hours and not first_iter and lastSeconds != -1:
tStr = "Player Hrs: " + str(round((lastSeconds + len(names) * (currTime - lastTime).total_seconds())/3600)) tStr = "Player Hrs: " + str(round((lastSeconds + len(names) * (currTime - lastTime).total_seconds())/3600))
if len(tChannels) > 0: if len(tChannels) > 0:
lastTName = tChannels[0].name lastTName = tChannels[0].name
@ -595,6 +622,10 @@ async def status_task(sid: int):
else: else:
await currServ.create_voice_channel(tStr, await currServ.create_voice_channel(tStr,
overwrites=overwrites) overwrites=overwrites)
elif not do_show_hours:
if len(tChannels) > 0:
for channel in tChannels:
await channel.delete()
mydb.commit() mydb.commit()
mydb.close() mydb.close()
@ -602,6 +633,13 @@ async def status_task(sid: int):
await asyncio.sleep(wait) await asyncio.sleep(wait)
first_iter = False first_iter = False
bot.add_cog(Admin()) @bot.event
bot.add_cog(Other(bot)) async def setup_hook():
bot.run(os.getenv('TOKEN')) await bot.add_cog(Admin())
await bot.add_cog(Other(bot))
token = os.getenv('TOKEN')
if token is not None:
bot.run(token)
else:
print("Token not found, ensure .env file exists and is well-formed.")