Async rate limit sleeps
This commit is contained in:
parent
065d887a9f
commit
855e4c305e
3 changed files with 158 additions and 35 deletions
|
|
@ -5,10 +5,10 @@ WORKDIR /app
|
||||||
COPY docker-entrypoint.sh ./entrypoint.sh
|
COPY docker-entrypoint.sh ./entrypoint.sh
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
COPY schema.sql .
|
COPY schema.sql .
|
||||||
COPY mstbot.py .
|
ADD src src
|
||||||
|
|
||||||
RUN apt update && apt install sqlite3
|
RUN apt update && apt upgrade -y && apt install -y sqlite3
|
||||||
RUN pip install --upgrade pip && pip install -r requirements.txt
|
RUN pip install --upgrade pip && pip install -r requirements.txt
|
||||||
|
|
||||||
ENTRYPOINT [ "/app/entrypoint.sh" ]
|
ENTRYPOINT [ "/app/entrypoint.sh" ]
|
||||||
CMD [ "python3", "-u", "mstbot.py" ]
|
CMD [ "python3", "-u", "src/mstbot.py" ]
|
||||||
|
|
|
||||||
126
src/erl.py
Normal file
126
src/erl.py
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Awaitable, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointRateLimiter:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_interval: float,
|
||||||
|
func: Callable[..., Awaitable[Any]],
|
||||||
|
):
|
||||||
|
self.min_interval = min_interval
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
self._workers: Dict[str, "_EndpointWorker"] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def get_state(self, endpoint: str):
|
||||||
|
async with self._lock:
|
||||||
|
worker = self._workers.get(endpoint)
|
||||||
|
if worker is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await worker.get_state()
|
||||||
|
|
||||||
|
async def get_time_remaining(self, endpoint: str):
|
||||||
|
async with self._lock:
|
||||||
|
worker = self._workers.get(endpoint)
|
||||||
|
if worker is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await worker.get_sleep_time()
|
||||||
|
|
||||||
|
async def submit(self, endpoint: str, *args, **kwargs):
|
||||||
|
async with self._lock:
|
||||||
|
worker = self._workers.get(endpoint)
|
||||||
|
if worker is None:
|
||||||
|
worker = _EndpointWorker(
|
||||||
|
endpoint,
|
||||||
|
self.min_interval,
|
||||||
|
self.func,
|
||||||
|
)
|
||||||
|
self._workers[endpoint] = worker
|
||||||
|
asyncio.create_task(worker.run())
|
||||||
|
|
||||||
|
await worker.submit(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class _EndpointWorker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
min_interval: float,
|
||||||
|
func: Callable[..., Awaitable[Any]],
|
||||||
|
):
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.min_interval = min_interval
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
self._desired_state = None
|
||||||
|
self._applied_state = None
|
||||||
|
self._last_run = 0.0
|
||||||
|
|
||||||
|
self._event = asyncio.Event()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
# ---- State handling -------------------------------------------------
|
||||||
|
|
||||||
|
def state_from_args(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Override if args do not directly represent endpoint state.
|
||||||
|
Must return a comparable, immutable value.
|
||||||
|
"""
|
||||||
|
return args, frozenset(kwargs.items())
|
||||||
|
|
||||||
|
async def submit(self, *args, **kwargs):
|
||||||
|
desired = self.state_from_args(*args, **kwargs)
|
||||||
|
async with self._lock:
|
||||||
|
self._desired_state = desired
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
async def get_state(self):
|
||||||
|
async with self._lock:
|
||||||
|
state = self._desired_state
|
||||||
|
if state is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
args, kw_items = state
|
||||||
|
return args, dict(kw_items)
|
||||||
|
|
||||||
|
async def get_sleep_time(self):
|
||||||
|
return self.min_interval - (time.monotonic() - self._last_run)
|
||||||
|
|
||||||
|
# ---- Worker loop ----------------------------------------------------
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
while self._running:
|
||||||
|
await self._event.wait()
|
||||||
|
self._event.clear()
|
||||||
|
|
||||||
|
# Rate limit
|
||||||
|
sleep_for = await self.get_sleep_time()
|
||||||
|
if sleep_for > 0:
|
||||||
|
await asyncio.sleep(sleep_for)
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
if self._desired_state == self._applied_state:
|
||||||
|
# Net-zero change → skip call
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("Applying state...")
|
||||||
|
state_to_apply = self._desired_state
|
||||||
|
|
||||||
|
try:
|
||||||
|
args, kw_items = state_to_apply
|
||||||
|
kwargs = dict(kw_items)
|
||||||
|
await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
self._applied_state = state_to_apply
|
||||||
|
self._last_run = time.monotonic()
|
||||||
|
except Exception:
|
||||||
|
print("An exception occurred in EndpointRateLimiter!")
|
||||||
|
# applied_state intentionally unchanged on failure
|
||||||
|
pass
|
||||||
|
|
@ -23,6 +23,13 @@ import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from mcstatus import JavaServer
|
from mcstatus import JavaServer
|
||||||
|
|
||||||
|
from erl import EndpointRateLimiter
|
||||||
|
|
||||||
|
async def change_channel_name(channel, name: str):
|
||||||
|
await channel.edit(name=name)
|
||||||
|
|
||||||
|
limiter = EndpointRateLimiter(301, change_channel_name)
|
||||||
|
|
||||||
# Bot Initialization
|
# Bot Initialization
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.messages = True
|
intents.messages = True
|
||||||
|
|
@ -478,6 +485,16 @@ def ascii_clean(s):
|
||||||
if unicodedata.category(c)[0] != "C"
|
if unicodedata.category(c)[0] != "C"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_last_channel_name(channel):
|
||||||
|
last_name = channel.name
|
||||||
|
latest = await limiter.get_state(channel.id)
|
||||||
|
if latest:
|
||||||
|
args, _ = latest
|
||||||
|
last_name = args[1]
|
||||||
|
|
||||||
|
return last_name
|
||||||
|
|
||||||
|
|
||||||
async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None,
|
async def safe_send(msg: str, ctx: Union[discord.ext.commands.context.Context, None] = None,
|
||||||
chan: Union[discord.TextChannel, None] = None, format: str = ''):
|
chan: Union[discord.TextChannel, None] = None, format: str = ''):
|
||||||
try:
|
try:
|
||||||
|
|
@ -530,8 +547,6 @@ async def status_task(sid: int):
|
||||||
|
|
||||||
deltaSeconds = (currTime - lastTime).total_seconds()
|
deltaSeconds = (currTime - lastTime).total_seconds()
|
||||||
|
|
||||||
print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"),
|
|
||||||
str(deltaSeconds))
|
|
||||||
|
|
||||||
setMCNames(sid, conn, names)
|
setMCNames(sid, conn, names)
|
||||||
setMCQueryTime(sid, conn, currTime)
|
setMCQueryTime(sid, conn, currTime)
|
||||||
|
|
@ -578,24 +593,17 @@ async def status_task(sid: int):
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(iChannels) > 0:
|
if len(iChannels) > 0:
|
||||||
lastIPName = iChannels[0].name
|
lastIPName = await get_last_channel_name(iChannels[0])
|
||||||
|
|
||||||
ipStr = "IP: " + ip
|
ipStr = "IP: " + ip
|
||||||
if lastIPName != ipStr:
|
if lastIPName != ipStr:
|
||||||
try:
|
print(currServ.name, "Update: IP changed!")
|
||||||
print(currServ.name, "Update: Ip changed!")
|
await limiter.submit(iChannels[0].id, iChannels[0], ipStr)
|
||||||
await iChannels[0].edit(name=ipStr)
|
|
||||||
wait = 301
|
|
||||||
except discord.errors.Forbidden:
|
|
||||||
print(currServ.name,
|
|
||||||
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
|
|
||||||
await do_bot_cleanup(sid)
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
await currServ.create_voice_channel("IP: " + ip, overwrites=overwrites)
|
await currServ.create_voice_channel("IP: " + ip, overwrites=overwrites)
|
||||||
|
|
||||||
if len(pChannels) > 0:
|
if len(pChannels) > 0:
|
||||||
lastPName = pChannels[0].name
|
lastPName = await get_last_channel_name(pChannels[0])
|
||||||
|
|
||||||
if players == -1:
|
if players == -1:
|
||||||
pStr = lastPName
|
pStr = lastPName
|
||||||
|
|
@ -603,15 +611,8 @@ async def status_task(sid: int):
|
||||||
pStr = "Players: " + str(players) + "/" + str(max)
|
pStr = "Players: " + str(players) + "/" + str(max)
|
||||||
|
|
||||||
if lastPName != pStr:
|
if lastPName != pStr:
|
||||||
try:
|
print(currServ.name, "Update: Players changed!")
|
||||||
print(currServ.name, "Update: Players changed!")
|
await limiter.submit(pChannels[0].id, pChannels[0], pStr)
|
||||||
await pChannels[0].edit(name=pStr)
|
|
||||||
wait = 301
|
|
||||||
except discord.errors.Forbidden:
|
|
||||||
print(currServ.name,
|
|
||||||
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
|
|
||||||
await do_bot_cleanup(sid)
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}",
|
await currServ.create_voice_channel(f"Players: {str(players)}/{str(max)}",
|
||||||
overwrites=overwrites)
|
overwrites=overwrites)
|
||||||
|
|
@ -620,18 +621,11 @@ async def status_task(sid: int):
|
||||||
if do_show_hours and not first_iter and lastSeconds != -1:
|
if do_show_hours and not first_iter and lastSeconds != -1:
|
||||||
tStr = "Player Hrs: " + str(round(getPersonSeconds(sid, conn)/3600))
|
tStr = "Player Hrs: " + str(round(getPersonSeconds(sid, conn)/3600))
|
||||||
if len(tChannels) > 0:
|
if len(tChannels) > 0:
|
||||||
lastTName = tChannels[0].name
|
lastTName = await get_last_channel_name(tChannels[0])
|
||||||
|
|
||||||
if lastTName != tStr:
|
if lastTName != tStr:
|
||||||
try:
|
print(currServ.name, "Update: Time changed!")
|
||||||
print(currServ.name, "Update: Time changed!")
|
await limiter.submit(tChannels[0].id, tChannels[0], tStr)
|
||||||
await tChannels[0].edit(name=tStr)
|
|
||||||
wait = 301
|
|
||||||
except discord.errors.Forbidden:
|
|
||||||
print(currServ.name,
|
|
||||||
"Error: I don't have permission to edit channels. Try deleting the channels I create. Then, run the `setup` command again.")
|
|
||||||
await do_bot_cleanup(sid)
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
await currServ.create_voice_channel(tStr,
|
await currServ.create_voice_channel(tStr,
|
||||||
overwrites=overwrites)
|
overwrites=overwrites)
|
||||||
|
|
@ -639,6 +633,9 @@ async def status_task(sid: int):
|
||||||
if len(tChannels) > 0:
|
if len(tChannels) > 0:
|
||||||
for channel in tChannels:
|
for channel in tChannels:
|
||||||
await channel.delete()
|
await channel.delete()
|
||||||
|
|
||||||
|
print(currServ.name, "Query:", ip, str(players) + "/" + str(max), datetime.now(timezone.utc).strftime("%H:%M:%S"),
|
||||||
|
str(deltaSeconds), "IP:", str(await limiter.get_time_remaining(iChannels[0].id)), "P:", str(await limiter.get_time_remaining(pChannels[0].id)), "T:", str(await limiter.get_time_remaining(tChannels[0].id)))
|
||||||
|
|
||||||
await asyncio.sleep(wait)
|
await asyncio.sleep(wait)
|
||||||
first_iter = False
|
first_iter = False
|
||||||
Loading…
Add table
Add a link
Reference in a new issue