Async rate limit sleeps

This commit is contained in:
Benjamin Morgan 2026-01-10 15:58:50 -07:00
parent 065d887a9f
commit 855e4c305e
3 changed files with 158 additions and 35 deletions

View file

@ -5,10 +5,10 @@ WORKDIR /app
COPY docker-entrypoint.sh ./entrypoint.sh
COPY requirements.txt .
COPY schema.sql .
COPY mstbot.py .
ADD src src
RUN apt update && apt install sqlite3
RUN apt update && apt upgrade -y && apt install -y sqlite3
RUN pip install --upgrade pip && pip install -r requirements.txt
ENTRYPOINT [ "/app/entrypoint.sh" ]
CMD [ "python3", "-u", "mstbot.py" ]
CMD [ "python3", "-u", "src/mstbot.py" ]

126
src/erl.py Normal file
View file

@ -0,0 +1,126 @@
import asyncio
import time
from typing import Any, Callable, Awaitable, Dict
class EndpointRateLimiter:
def __init__(
self,
min_interval: float,
func: Callable[..., Awaitable[Any]],
):
self.min_interval = min_interval
self.func = func
self._workers: Dict[str, "_EndpointWorker"] = {}
self._lock = asyncio.Lock()
async def get_state(self, endpoint: str):
async with self._lock:
worker = self._workers.get(endpoint)
if worker is None:
return None
return await worker.get_state()
async def get_time_remaining(self, endpoint: str):
async with self._lock:
worker = self._workers.get(endpoint)
if worker is None:
return None
return await worker.get_sleep_time()
async def submit(self, endpoint: str, *args, **kwargs):
async with self._lock:
worker = self._workers.get(endpoint)
if worker is None:
worker = _EndpointWorker(
endpoint,
self.min_interval,
self.func,
)
self._workers[endpoint] = worker
asyncio.create_task(worker.run())
await worker.submit(*args, **kwargs)
class _EndpointWorker:
def __init__(
self,
endpoint: str,
min_interval: float,
func: Callable[..., Awaitable[Any]],
):
self.endpoint = endpoint
self.min_interval = min_interval
self.func = func
self._desired_state = None
self._applied_state = None
self._last_run = 0.0
self._event = asyncio.Event()
self._lock = asyncio.Lock()
self._running = True
# ---- State handling -------------------------------------------------
def state_from_args(self, *args, **kwargs):
"""
Override if args do not directly represent endpoint state.
Must return a comparable, immutable value.
"""
return args, frozenset(kwargs.items())
async def submit(self, *args, **kwargs):
desired = self.state_from_args(*args, **kwargs)
async with self._lock:
self._desired_state = desired
self._event.set()
async def get_state(self):
async with self._lock:
state = self._desired_state
if state is None:
return None
args, kw_items = state
return args, dict(kw_items)
async def get_sleep_time(self):
return self.min_interval - (time.monotonic() - self._last_run)
# ---- Worker loop ----------------------------------------------------
async def run(self):
while self._running:
await self._event.wait()
self._event.clear()
# Rate limit
sleep_for = await self.get_sleep_time()
if sleep_for > 0:
await asyncio.sleep(sleep_for)
async with self._lock:
if self._desired_state == self._applied_state:
# Net-zero change → skip call
continue
print("Applying state...")
state_to_apply = self._desired_state
try:
args, kw_items = state_to_apply
kwargs = dict(kw_items)
await self.func(*args, **kwargs)
async with self._lock:
self._applied_state = state_to_apply
self._last_run = time.monotonic()
except Exception:
print("An exception occurred in EndpointRateLimiter!")
# applied_state intentionally unchanged on failure
pass

View file

@ -23,6 +23,13 @@ import discord
from discord.ext import commands
from mcstatus import JavaServer
from erl import EndpointRateLimiter
async def change_channel_name(channel, name: str):
await channel.edit(name=name)
limiter = EndpointRateLimiter(301, change_channel_name)
# Bot Initialization
intents = discord.Intents.default()
intents.messages = True
@ -478,6 +485,16 @@ def ascii_clean(s):
if unicodedata.category(c)[0] != "C"
)
async def get_last_channel_name(channel):
last_name = channel.name
latest = await limiter.get_state(channel.id)
if latest:
args, _ = latest
last_name = args[1]
return last_name
async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None,
chan: Union[discord.TextChannel, None] = None, format: str = ''):
try:
@ -530,8 +547,6 @@ async def status_task(sid: int):
deltaSeconds = (currTime - lastTime).total_seconds()
print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"),
str(deltaSeconds))
setMCNames(sid, conn, names)
setMCQueryTime(sid, conn, currTime)
@ -578,24 +593,17 @@ async def status_task(sid: int):
}
if len(iChannels) > 0:
lastIPName = iChannels[0].name
lastIPName = await get_last_channel_name(iChannels[0])
ipStr = "IP: " + ip
if lastIPName != ipStr:
try:
print(currServ.name, "Update: Ip changed!")
await iChannels[0].edit(name=ipStr)
wait = 301
except discord.errors.Forbidden:
print(currServ.name,
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
await do_bot_cleanup(sid)
return
print(currServ.name, "Update: IP changed!")
await limiter.submit(iChannels[0].id, iChannels[0], ipStr)
else:
await currServ.create_voice_channel("IP: " + ip, overwrites=overwrites)
if len(pChannels) > 0:
lastPName = pChannels[0].name
lastPName = await get_last_channel_name(pChannels[0])
if players == -1:
pStr = lastPName
@ -603,15 +611,8 @@ async def status_task(sid: int):
pStr = "Players: " + str(players) + "/" + str(max)
if lastPName != pStr:
try:
print(currServ.name, "Update: Players changed!")
await pChannels[0].edit(name=pStr)
wait = 301
except discord.errors.Forbidden:
print(currServ.name,
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
await do_bot_cleanup(sid)
return
await limiter.submit(pChannels[0].id, pChannels[0], pStr)
else:
await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}",
overwrites=overwrites)
@ -620,18 +621,11 @@ async def status_task(sid: int):
if do_show_hours and not first_iter and lastSeconds != -1:
tStr = "Player Hrs: " + str(round(getPersonSeconds(sid, conn)/3600))
if len(tChannels) > 0:
lastTName = tChannels[0].name
lastTName = await get_last_channel_name(tChannels[0])
if lastTName != tStr:
try:
print(currServ.name, "Update: Time changed!")
await tChannels[0].edit(name=tStr)
wait = 301
except discord.errors.Forbidden:
print(currServ.name,
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
await do_bot_cleanup(sid)
return
await limiter.submit(tChannels[0].id, tChannels[0], tStr)
else:
await currServ.create_voice_channel(tStr,
overwrites=overwrites)
@ -640,6 +634,9 @@ async def status_task(sid: int):
for channel in tChannels:
await channel.delete()
print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"),
str(deltaSeconds), "IP:", str(await limiter.get_time_remaining(iChannels[0].id)), "P:", str(await limiter.get_time_remaining(pChannels[0].id)), "T:", str(await limiter.get_time_remaining(tChannels[0].id)))
await asyncio.sleep(wait)
first_iter = False