Source code for automation_file.core.rate_limit

"""Token-bucket rate limiter.

:class:`RateLimiter` refills at ``rate`` tokens/second up to a burst capacity.
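Callers acquire N tokens before issuing a protected call. When the bucket is
empty, :meth:`RateLimiter.acquire` blocks for up to ``timeout`` seconds and then
raises :class:`RateLimitExceededException`, while :meth:`RateLimiter.try_acquire`
returns ``False`` without blocking.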
"""

from __future__ import annotations

import threading
import time
from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar

from automation_file.exceptions import RateLimitExceededException

F = TypeVar("F", bound=Callable[..., Any])


class RateLimiter:
    """Thread-safe token bucket.

    The bucket holds at most :attr:`capacity` tokens (``burst``, defaulting to
    ``rate``) and refills continuously at ``rate`` tokens per second. All
    mutable state is guarded by a single :class:`threading.Condition`.
    """

    def __init__(self, rate: float, burst: float | None = None) -> None:
        """Create a limiter that starts with a full bucket.

        :param rate: refill rate in tokens per second; must be > 0.
        :param burst: bucket capacity; defaults to ``rate``. Must be > 0.
        :raises ValueError: if ``rate`` or the resulting capacity is not > 0
            (NaN is rejected too).
        """
        # ``not x > 0`` rather than ``x <= 0``: the latter is False for NaN,
        # which would otherwise slip through and poison the refill arithmetic.
        if not rate > 0:
            raise ValueError("rate must be > 0")
        cap = float(burst) if burst is not None else float(rate)
        if not cap > 0:
            raise ValueError("burst must be > 0")
        self._rate = float(rate)
        self._capacity = cap
        self._tokens = cap  # start full so an initial burst is allowed
        self._updated = time.monotonic()
        self._cv = threading.Condition()

    @property
    def capacity(self) -> float:
        """Maximum number of tokens the bucket can hold."""
        return self._capacity

    def _refill_locked(self) -> None:
        """Credit tokens accrued since the last refill; caller must hold ``self._cv``."""
        now = time.monotonic()
        elapsed = now - self._updated
        if elapsed > 0:
            self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
            self._updated = now

    def try_acquire(self, tokens: float = 1.0) -> bool:
        """Take ``tokens`` without blocking. Return True on success.

        Returns ``False`` when the bucket currently holds fewer than
        ``tokens``. A request larger than :attr:`capacity` therefore always
        returns ``False``.

        :raises ValueError: if ``tokens`` is not > 0 (including NaN).
        """
        if not tokens > 0:
            raise ValueError("tokens must be > 0")
        with self._cv:
            self._refill_locked()
            if self._tokens >= tokens:
                self._tokens -= tokens
                return True
            return False

    def acquire(self, tokens: float = 1.0, timeout: float | None = None) -> None:
        """Block until ``tokens`` are available, then take them.

        ``timeout=None`` waits indefinitely. With a finite ``timeout`` the call
        waits at most that many seconds. ``timeout=0`` (or a negative value)
        never sleeps: it succeeds if the tokens are already available and
        raises otherwise.

        :raises ValueError: if ``tokens`` is not > 0 (including NaN) or exceeds
            :attr:`capacity`, since such a request could never be satisfied.
        :raises RateLimitExceededException: if ``timeout`` elapses first.
        """
        if not tokens > 0:
            raise ValueError("tokens must be > 0")
        if tokens > self._capacity:
            raise ValueError(f"tokens {tokens} exceeds capacity {self._capacity}")
        deadline = None if timeout is None else time.monotonic() + timeout
        with self._cv:
            while True:
                self._refill_locked()
                if self._tokens >= tokens:
                    self._tokens -= tokens
                    return
                # Time until the current deficit is refilled at ``rate``.
                needed = tokens - self._tokens
                wait_for = needed / self._rate
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise RateLimitExceededException(
                            f"rate limit: could not acquire {tokens} tokens within timeout"
                        )
                    wait_for = min(wait_for, remaining)
                # Nothing in this class calls notify(); wait() serves as a
                # sleep that releases the lock so other threads can proceed.
                # The loop re-checks after waking.
                self._cv.wait(timeout=wait_for)

    def wraps(self, tokens: float = 1.0, timeout: float | None = None) -> Callable[[F], F]:
        """Return a decorator that acquires ``tokens`` before each call.

        Each call of the decorated function first runs :meth:`acquire` with
        the given ``tokens`` and ``timeout``. It therefore blocks, or raises,
        exactly as :meth:`acquire` would.
        """

        def decorator(func: F) -> F:
            # The bare name ``wraps`` here is the module-level functools.wraps
            # import, not this method: class scope is not visible to nested
            # functions.
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                self.acquire(tokens=tokens, timeout=timeout)
                return func(*args, **kwargs)

            return wrapper  # type: ignore[return-value]

        return decorator