Async rate limit sleeps
This commit is contained in:
parent
065d887a9f
commit
855e4c305e
3 changed files with 158 additions and 35 deletions
|
|
@ -5,10 +5,10 @@ WORKDIR /app
|
|||
COPY docker-entrypoint.sh ./entrypoint.sh
|
||||
COPY requirements.txt .
|
||||
COPY schema.sql .
|
||||
COPY mstbot.py .
|
||||
ADD src src
|
||||
|
||||
RUN apt update && apt install sqlite3
|
||||
RUN apt update && apt upgrade -y && apt install -y sqlite3
|
||||
RUN pip install --upgrade pip && pip install -r requirements.txt
|
||||
|
||||
ENTRYPOINT [ "/app/entrypoint.sh" ]
|
||||
CMD [ "python3", "-u", "mstbot.py" ]
|
||||
CMD [ "python3", "-u", "src/mstbot.py" ]
|
||||
|
|
|
|||
126
src/erl.py
Normal file
126
src/erl.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import Any, Callable, Awaitable, Dict
|
||||
|
||||
|
||||
class EndpointRateLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
min_interval: float,
|
||||
func: Callable[..., Awaitable[Any]],
|
||||
):
|
||||
self.min_interval = min_interval
|
||||
self.func = func
|
||||
|
||||
self._workers: Dict[str, "_EndpointWorker"] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_state(self, endpoint: str):
|
||||
async with self._lock:
|
||||
worker = self._workers.get(endpoint)
|
||||
if worker is None:
|
||||
return None
|
||||
|
||||
return await worker.get_state()
|
||||
|
||||
async def get_time_remaining(self, endpoint: str):
|
||||
async with self._lock:
|
||||
worker = self._workers.get(endpoint)
|
||||
if worker is None:
|
||||
return None
|
||||
|
||||
return await worker.get_sleep_time()
|
||||
|
||||
async def submit(self, endpoint: str, *args, **kwargs):
|
||||
async with self._lock:
|
||||
worker = self._workers.get(endpoint)
|
||||
if worker is None:
|
||||
worker = _EndpointWorker(
|
||||
endpoint,
|
||||
self.min_interval,
|
||||
self.func,
|
||||
)
|
||||
self._workers[endpoint] = worker
|
||||
asyncio.create_task(worker.run())
|
||||
|
||||
await worker.submit(*args, **kwargs)
|
||||
|
||||
|
||||
class _EndpointWorker:
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
min_interval: float,
|
||||
func: Callable[..., Awaitable[Any]],
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.min_interval = min_interval
|
||||
self.func = func
|
||||
|
||||
self._desired_state = None
|
||||
self._applied_state = None
|
||||
self._last_run = 0.0
|
||||
|
||||
self._event = asyncio.Event()
|
||||
self._lock = asyncio.Lock()
|
||||
self._running = True
|
||||
|
||||
# ---- State handling -------------------------------------------------
|
||||
|
||||
def state_from_args(self, *args, **kwargs):
|
||||
"""
|
||||
Override if args do not directly represent endpoint state.
|
||||
Must return a comparable, immutable value.
|
||||
"""
|
||||
return args, frozenset(kwargs.items())
|
||||
|
||||
async def submit(self, *args, **kwargs):
|
||||
desired = self.state_from_args(*args, **kwargs)
|
||||
async with self._lock:
|
||||
self._desired_state = desired
|
||||
self._event.set()
|
||||
|
||||
async def get_state(self):
|
||||
async with self._lock:
|
||||
state = self._desired_state
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
args, kw_items = state
|
||||
return args, dict(kw_items)
|
||||
|
||||
async def get_sleep_time(self):
|
||||
return self.min_interval - (time.monotonic() - self._last_run)
|
||||
|
||||
# ---- Worker loop ----------------------------------------------------
|
||||
|
||||
async def run(self):
|
||||
while self._running:
|
||||
await self._event.wait()
|
||||
self._event.clear()
|
||||
|
||||
# Rate limit
|
||||
sleep_for = await self.get_sleep_time()
|
||||
if sleep_for > 0:
|
||||
await asyncio.sleep(sleep_for)
|
||||
|
||||
async with self._lock:
|
||||
if self._desired_state == self._applied_state:
|
||||
# Net-zero change → skip call
|
||||
continue
|
||||
|
||||
print("Applying state...")
|
||||
state_to_apply = self._desired_state
|
||||
|
||||
try:
|
||||
args, kw_items = state_to_apply
|
||||
kwargs = dict(kw_items)
|
||||
await self.func(*args, **kwargs)
|
||||
|
||||
async with self._lock:
|
||||
self._applied_state = state_to_apply
|
||||
self._last_run = time.monotonic()
|
||||
except Exception:
|
||||
print("An exception occurred in EndpointRateLimiter!")
|
||||
# applied_state intentionally unchanged on failure
|
||||
pass
|
||||
|
|
@ -23,6 +23,13 @@ import discord
|
|||
from discord.ext import commands
|
||||
from mcstatus import JavaServer
|
||||
|
||||
from erl import EndpointRateLimiter
|
||||
|
||||
async def change_channel_name(channel, name: str):
|
||||
await channel.edit(name=name)
|
||||
|
||||
limiter = EndpointRateLimiter(301, change_channel_name)
|
||||
|
||||
# Bot Initialization
|
||||
intents = discord.Intents.default()
|
||||
intents.messages = True
|
||||
|
|
@ -478,6 +485,16 @@ def ascii_clean(s):
|
|||
if unicodedata.category(c)[0] != "C"
|
||||
)
|
||||
|
||||
async def get_last_channel_name(channel):
|
||||
last_name = channel.name
|
||||
latest = await limiter.get_state(channel.id)
|
||||
if latest:
|
||||
args, _ = latest
|
||||
last_name = args[1]
|
||||
|
||||
return last_name
|
||||
|
||||
|
||||
async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None,
|
||||
chan: Union[discord.TextChannel, None] = None, format: str = ''):
|
||||
try:
|
||||
|
|
@ -530,8 +547,6 @@ async def status_task(sid: int):
|
|||
|
||||
deltaSeconds = (currTime - lastTime).total_seconds()
|
||||
|
||||
print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"),
|
||||
str(deltaSeconds))
|
||||
|
||||
setMCNames(sid, conn, names)
|
||||
setMCQueryTime(sid, conn, currTime)
|
||||
|
|
@ -578,24 +593,17 @@ async def status_task(sid: int):
|
|||
}
|
||||
|
||||
if len(iChannels) > 0:
|
||||
lastIPName = iChannels[0].name
|
||||
lastIPName = await get_last_channel_name(iChannels[0])
|
||||
|
||||
ipStr = "IP: " + ip
|
||||
if lastIPName != ipStr:
|
||||
try:
|
||||
print(currServ.name, "Update: Ip changed!")
|
||||
await iChannels[0].edit(name=ipStr)
|
||||
wait = 301
|
||||
except discord.errors.Forbidden:
|
||||
print(currServ.name,
|
||||
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
|
||||
await do_bot_cleanup(sid)
|
||||
return
|
||||
print(currServ.name, "Update: IP changed!")
|
||||
await limiter.submit(iChannels[0].id, iChannels[0], ipStr)
|
||||
else:
|
||||
await currServ.create_voice_channel("IP: " + ip, overwrites=overwrites)
|
||||
|
||||
if len(pChannels) > 0:
|
||||
lastPName = pChannels[0].name
|
||||
lastPName = await get_last_channel_name(pChannels[0])
|
||||
|
||||
if players == -1:
|
||||
pStr = lastPName
|
||||
|
|
@ -603,15 +611,8 @@ async def status_task(sid: int):
|
|||
pStr = "Players: " + str(players) + "/" + str(max)
|
||||
|
||||
if lastPName != pStr:
|
||||
try:
|
||||
print(currServ.name, "Update: Players changed!")
|
||||
await pChannels[0].edit(name=pStr)
|
||||
wait = 301
|
||||
except discord.errors.Forbidden:
|
||||
print(currServ.name,
|
||||
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
|
||||
await do_bot_cleanup(sid)
|
||||
return
|
||||
print(currServ.name, "Update: Players changed!")
|
||||
await limiter.submit(pChannels[0].id, pChannels[0], pStr)
|
||||
else:
|
||||
await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}",
|
||||
overwrites=overwrites)
|
||||
|
|
@ -620,18 +621,11 @@ async def status_task(sid: int):
|
|||
if do_show_hours and not first_iter and lastSeconds != -1:
|
||||
tStr = "Player Hrs: " + str(round(getPersonSeconds(sid, conn)/3600))
|
||||
if len(tChannels) > 0:
|
||||
lastTName = tChannels[0].name
|
||||
lastTName = await get_last_channel_name(tChannels[0])
|
||||
|
||||
if lastTName != tStr:
|
||||
try:
|
||||
print(currServ.name, "Update: Time changed!")
|
||||
await tChannels[0].edit(name=tStr)
|
||||
wait = 301
|
||||
except discord.errors.Forbidden:
|
||||
print(currServ.name,
|
||||
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
|
||||
await do_bot_cleanup(sid)
|
||||
return
|
||||
print(currServ.name, "Update: Time changed!")
|
||||
await limiter.submit(tChannels[0].id, tChannels[0], tStr)
|
||||
else:
|
||||
await currServ.create_voice_channel(tStr,
|
||||
overwrites=overwrites)
|
||||
|
|
@ -639,6 +633,9 @@ async def status_task(sid: int):
|
|||
if len(tChannels) > 0:
|
||||
for channel in tChannels:
|
||||
await channel.delete()
|
||||
|
||||
print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"),
|
||||
str(deltaSeconds), "IP:", str(await limiter.get_time_remaining(iChannels[0].id)), "P:", str(await limiter.get_time_remaining(pChannels[0].id)), "T:", str(await limiter.get_time_remaining(tChannels[0].id)))
|
||||
|
||||
await asyncio.sleep(wait)
|
||||
first_iter = False
|
||||
Loading…
Add table
Add a link
Reference in a new issue