diff --git a/Dockerfile b/Dockerfile index d393682..70b1e34 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,10 +5,10 @@ WORKDIR /app COPY docker-entrypoint.sh ./entrypoint.sh COPY requirements.txt . COPY schema.sql . -COPY mstbot.py . +ADD src src -RUN apt update && apt install sqlite3 +RUN apt update && apt upgrade -y && apt install -y sqlite3 RUN pip install --upgrade pip && pip install -r requirements.txt ENTRYPOINT [ "/app/entrypoint.sh" ] -CMD [ "python3", "-u", "mstbot.py" ] +CMD [ "python3", "-u", "src/mstbot.py" ] diff --git a/src/erl.py b/src/erl.py new file mode 100644 index 0000000..8154338 --- /dev/null +++ b/src/erl.py @@ -0,0 +1,126 @@ +import asyncio +import time +from typing import Any, Callable, Awaitable, Dict + + +class EndpointRateLimiter: + def __init__( + self, + min_interval: float, + func: Callable[..., Awaitable[Any]], + ): + self.min_interval = min_interval + self.func = func + + self._workers: Dict[str, "_EndpointWorker"] = {} + self._lock = asyncio.Lock() + + async def get_state(self, endpoint: str): + async with self._lock: + worker = self._workers.get(endpoint) + if worker is None: + return None + + return await worker.get_state() + + async def get_time_remaining(self, endpoint: str): + async with self._lock: + worker = self._workers.get(endpoint) + if worker is None: + return None + + return await worker.get_sleep_time() + + async def submit(self, endpoint: str, *args, **kwargs): + async with self._lock: + worker = self._workers.get(endpoint) + if worker is None: + worker = _EndpointWorker( + endpoint, + self.min_interval, + self.func, + ) + self._workers[endpoint] = worker + asyncio.create_task(worker.run()) + + await worker.submit(*args, **kwargs) + + +class _EndpointWorker: + def __init__( + self, + endpoint: str, + min_interval: float, + func: Callable[..., Awaitable[Any]], + ): + self.endpoint = endpoint + self.min_interval = min_interval + self.func = func + + self._desired_state = None + self._applied_state = None + self._last_run = 0.0 + + self._event = asyncio.Event() + self._lock = asyncio.Lock() + self._running = True + + # ---- State handling ------------------------------------------------- + + def state_from_args(self, *args, **kwargs): + """ + Override if args do not directly represent endpoint state. + Must return a comparable, immutable value. + """ + return args, frozenset(kwargs.items()) + + async def submit(self, *args, **kwargs): + desired = self.state_from_args(*args, **kwargs) + async with self._lock: + self._desired_state = desired + self._event.set() + + async def get_state(self): + async with self._lock: + state = self._desired_state + if state is None: + return None + + args, kw_items = state + return args, dict(kw_items) + + async def get_sleep_time(self): + return self.min_interval - (time.monotonic() - self._last_run) + + # ---- Worker loop ---------------------------------------------------- + + async def run(self): + while self._running: + await self._event.wait() + self._event.clear() + + # Rate limit + sleep_for = await self.get_sleep_time() + if sleep_for > 0: + await asyncio.sleep(sleep_for) + + async with self._lock: + if self._desired_state == self._applied_state: + # Net-zero change → skip call + continue + + print("Applying state...") + state_to_apply = self._desired_state + + try: + args, kw_items = state_to_apply + kwargs = dict(kw_items) + await self.func(*args, **kwargs) + + async with self._lock: + self._applied_state = state_to_apply + self._last_run = time.monotonic() + except Exception: + print("An exception occurred in EndpointRateLimiter!") + # applied_state intentionally unchanged on failure + pass diff --git a/mstbot.py b/src/mstbot.py similarity index 91% rename from mstbot.py rename to src/mstbot.py index 892c8b6..c3cd1ee 100644 --- a/mstbot.py +++ b/src/mstbot.py @@ -23,6 +23,13 @@ import discord from discord.ext import commands from mcstatus import JavaServer +from erl import EndpointRateLimiter + +async def change_channel_name(channel, name: str): + await channel.edit(name=name) + +limiter = EndpointRateLimiter(301, change_channel_name) + # Bot Initialization intents = discord.Intents.default() intents.messages = True @@ -478,6 +485,16 @@ def ascii_clean(s): if unicodedata.category(c)[0] != "C" ) +async def get_last_channel_name(channel): + last_name = channel.name + latest = await limiter.get_state(channel.id) + if latest: + args, _ = latest + last_name = args[1] + + return last_name + + async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None, chan: Union[discord.TextChannel, None] = None, format: str = ''): try: @@ -530,8 +547,6 @@ async def status_task(sid: int): deltaSeconds = (currTime - lastTime).total_seconds() - print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"), - str(deltaSeconds)) setMCNames(sid, conn, names) setMCQueryTime(sid, conn, currTime) @@ -578,24 +593,17 @@ async def status_task(sid: int): } if len(iChannels) > 0: - lastIPName = iChannels[0].name + lastIPName = await get_last_channel_name(iChannels[0]) ipStr = "IP: " + ip if lastIPName != ipStr: - try: - print(currServ.name, "Update: Ip changed!") - await iChannels[0].edit(name=ipStr) - wait = 301 - except discord.errors.Forbidden: - print(currServ.name, - "Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.") - await do_bot_cleanup(sid) - return + print(currServ.name, "Update: IP changed!") + await limiter.submit(iChannels[0].id, iChannels[0], ipStr) else: await currServ.create_voice_channel("IP: " + ip, overwrites=overwrites) if len(pChannels) > 0: - lastPName = pChannels[0].name + lastPName = await get_last_channel_name(pChannels[0]) if players == -1: pStr = lastPName @@ -603,15 +611,8 @@ async def status_task(sid: int): pStr = "Players: " + str(players) + "/" + str(max) if lastPName != pStr: - try: - print(currServ.name, "Update: Players changed!") - await pChannels[0].edit(name=pStr) - wait = 301 - except discord.errors.Forbidden: - print(currServ.name, - "Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.") - await do_bot_cleanup(sid) - return + print(currServ.name, "Update: Players changed!") + await limiter.submit(pChannels[0].id, pChannels[0], pStr) else: await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}", overwrites=overwrites) @@ -620,18 +621,11 @@ async def status_task(sid: int): if do_show_hours and not first_iter and lastSeconds != -1: tStr = "Player Hrs: " + str(round(getPersonSeconds(sid, conn)/3600)) if len(tChannels) > 0: - lastTName = tChannels[0].name + lastTName = await get_last_channel_name(tChannels[0]) if lastTName != tStr: - try: - print(currServ.name, "Update: Time changed!") - await tChannels[0].edit(name=tStr) - wait = 301 - except discord.errors.Forbidden: - print(currServ.name, - "Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.") - await do_bot_cleanup(sid) - return + print(currServ.name, "Update: Time changed!") + await limiter.submit(tChannels[0].id, tChannels[0], tStr) else: await currServ.create_voice_channel(tStr, overwrites=overwrites) @@ -639,6 +633,9 @@ async def status_task(sid: int): if len(tChannels) > 0: for channel in tChannels: await channel.delete() + + print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"), + str(deltaSeconds), "IP:", str(await limiter.get_time_remaining(iChannels[0].id)), "P:", str(await limiter.get_time_remaining(pChannels[0].id)), "T:", str(await limiter.get_time_remaining(tChannels[0].id))) await asyncio.sleep(wait) first_iter = False