Source code for automation_file.core.circuit_breaker

"""Three-state circuit breaker.

States: CLOSED (normal), OPEN (short-circuit after ``failure_threshold``
consecutive failures), HALF_OPEN (trial one call after ``recovery_timeout``
seconds; one success closes, one failure re-opens). Failures are counted
only for exceptions in ``retriable`` — internal errors surface as-is.
"""

from __future__ import annotations

import threading
import time
from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar

from automation_file.exceptions import CircuitOpenException
from automation_file.logging_config import file_automation_logger

F = TypeVar("F", bound=Callable[..., Any])

_STATE_CLOSED = "closed"
_STATE_OPEN = "open"
_STATE_HALF_OPEN = "half_open"


[docs] class CircuitBreaker: """Open-close-half-open breaker. ``failure_threshold`` — consecutive failures that trip the breaker. ``recovery_timeout`` — seconds spent in OPEN before transitioning to HALF_OPEN. ``retriable`` — exception types counted as failures; other exceptions pass through. """ def __init__( self, failure_threshold: int = 5, recovery_timeout: float = 30.0, retriable: tuple[type[BaseException], ...] = (Exception,), name: str = "circuit", ) -> None: if failure_threshold < 1: raise ValueError("failure_threshold must be >= 1") if recovery_timeout <= 0: raise ValueError("recovery_timeout must be > 0") self._failure_threshold = failure_threshold self._recovery_timeout = float(recovery_timeout) self._retriable = retriable self._name = name self._state = _STATE_CLOSED self._failures = 0 self._opened_at = 0.0 self._lock = threading.Lock() @property def state(self) -> str: with self._lock: self._maybe_transition_locked() return self._state def _maybe_transition_locked(self) -> None: if self._state == _STATE_OPEN and ( time.monotonic() - self._opened_at >= self._recovery_timeout ): self._state = _STATE_HALF_OPEN file_automation_logger.info("circuit %s: open -> half_open", self._name) def _on_success_locked(self) -> None: if self._state == _STATE_HALF_OPEN: file_automation_logger.info("circuit %s: half_open -> closed", self._name) self._state = _STATE_CLOSED self._failures = 0 def _on_failure_locked(self) -> None: self._failures += 1 if self._state == _STATE_HALF_OPEN or self._failures >= self._failure_threshold: self._state = _STATE_OPEN self._opened_at = time.monotonic() file_automation_logger.warning( "circuit %s: opened after %d failures", self._name, self._failures )
[docs] def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Invoke ``func`` through the breaker.""" with self._lock: self._maybe_transition_locked() if self._state == _STATE_OPEN: raise CircuitOpenException(f"circuit {self._name!r} is open") try: result = func(*args, **kwargs) except self._retriable as error: with self._lock: self._on_failure_locked() raise error with self._lock: self._on_success_locked() return result
[docs] def wraps(self) -> Callable[[F], F]: """Return a decorator that routes every call through :meth:`call`.""" def decorator(func: F) -> F: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: return self.call(func, *args, **kwargs) return wrapper # type: ignore[return-value] return decorator
[docs] def reset(self) -> None: """Force the breaker back to CLOSED, clearing failure count.""" with self._lock: self._state = _STATE_CLOSED self._failures = 0 self._opened_at = 0.0