# Source code for stabilize.events.subscriptions

"""
Durable subscriptions for event processing.

Durable subscriptions persist their position and can resume
after restarts, ensuring at-least-once delivery of events.
"""

from __future__ import annotations

import logging
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from stabilize.events.base import Event, EventType

if TYPE_CHECKING:
    from stabilize.events.store.interface import EventStore

# Module-level logger; SubscriptionManager reports lifecycle events and
# handler/poll failures through it.
logger = logging.getLogger(__name__)


@dataclass
class DurableSubscription:
    """
    A subscription whose read position can be saved and restored.

    Bundles the delivery target (``handler``), the filtering criteria and the
    sequence number of the last consumed event. ``SubscriptionManager`` is
    what advances and persists ``last_sequence`` so that delivery can resume
    where it stopped after a restart.
    """

    # Unique identifier; also the key under which the position is persisted.
    id: str
    # Callable invoked once per matching event.
    handler: Callable[[Event], None]
    # Accepted event types; None accepts every type, an empty list accepts none.
    event_types: list[EventType] | None = None
    # Optional equality constraints; the recognised keys are "entity_type"
    # and "workflow_id" (any other key is ignored by matches()).
    entity_filter: dict[str, Any] | None = None
    # Sequence number of the last event consumed (0 = from the beginning).
    last_sequence: int = 0
    # Optional webhook target kept alongside the subscription.
    webhook_url: str | None = None
    # A disabled subscription matches no events.
    enabled: bool = True
    # Failure counter and the threshold at which the manager disables it.
    error_count: int = 0
    max_errors: int = 10

    def matches(self, event: Event) -> bool:
        """
        Return True when ``event`` should be delivered to this subscription.

        Disabled subscriptions never match. Otherwise the event must be one
        of ``event_types`` (when that list is set) and, when ``entity_filter``
        is set, agree with each of its "entity_type" / "workflow_id" entries
        that is present. Event attributes are only read for criteria that are
        actually configured.
        """
        if not self.enabled:
            return False

        accepted = self.event_types
        if accepted is not None and event.event_type not in accepted:
            return False

        criteria = self.entity_filter
        if criteria is None:
            return True

        # (filter key, lazy accessor for the event value it is compared with)
        lookups = (
            ("entity_type", lambda: event.entity_type.value),
            ("workflow_id", lambda: event.workflow_id),
        )
        for key, read_value in lookups:
            if key in criteria and read_value() != criteria[key]:
                return False
        return True
class SubscriptionManager:
    """
    Manages durable subscriptions that survive restarts.

    Features:
    - Persisted position (survives restarts) when the event store provides
      the optional persistence hooks listed below
    - Catch-up from last position
    - At-least-once delivery: the position never advances past an event
      whose handler raised
    - Error counting: a failing subscription is retried on every poll and is
      disabled after ``max_errors`` consecutive failed polls (no backoff)
    - A webhook URL stored and persisted with each subscription

    Optional event-store hooks, detected by attribute name:
    ``save_subscription``, ``get_subscription``, ``delete_subscription`` and
    ``update_subscription_sequence``. Without them subscriptions still work,
    but their positions live in memory only.

    NOTE(review): ``webhook_url`` is stored and persisted but never invoked
    by this class — confirm whether webhook delivery happens elsewhere.
    """

    def __init__(
        self,
        event_store: EventStore,
        poll_interval: float = 1.0,
        batch_size: int = 100,
    ) -> None:
        """
        Initialize the subscription manager.

        Args:
            event_store: The event store to poll.
            poll_interval: Seconds to wait between polls.
            batch_size: Maximum events fetched per subscription per poll.
        """
        self._event_store = event_store
        self._poll_interval = poll_interval
        self._batch_size = batch_size
        self._subscriptions: dict[str, DurableSubscription] = {}
        # Guards _subscriptions; handlers are invoked outside of it.
        self._lock = threading.Lock()
        self._running = False
        self._poll_thread: threading.Thread | None = None
        # Wakes and ends the current poll thread. start() installs a fresh
        # event each time, so a thread left over from an earlier start()
        # (e.g. one whose join timed out) can never be revived by a later one.
        self._stop_event = threading.Event()

    def create_subscription(
        self,
        subscription_id: str,
        handler: Callable[[Event], None],
        event_types: list[EventType] | None = None,
        entity_filter: dict[str, Any] | None = None,
        webhook_url: str | None = None,
        start_from: str = "latest",
    ) -> None:
        """
        Create a durable subscription and register it with this manager.

        A subscription already registered under the same id is replaced.

        Args:
            subscription_id: Unique identifier for this subscription.
            handler: Function to call with each event.
            event_types: Only receive these event types (None = all).
            entity_filter: Filter by entity properties ("entity_type",
                "workflow_id").
            webhook_url: Optional webhook URL stored with the subscription.
            start_from: "latest" (the store's current sequence), "beginning"
                (sequence 0), or a sequence number string.

        Raises:
            ValueError: If ``start_from`` is neither keyword nor an integer
                string.
        """
        # Determine starting position
        if start_from == "latest":
            last_sequence = self._event_store.get_current_sequence()
        elif start_from == "beginning":
            last_sequence = 0
        else:
            last_sequence = int(start_from)

        subscription = DurableSubscription(
            id=subscription_id,
            handler=handler,
            event_types=event_types,
            entity_filter=entity_filter,
            last_sequence=last_sequence,
            webhook_url=webhook_url,
        )
        with self._lock:
            self._subscriptions[subscription_id] = subscription

        # Persist subscription (optional store hook)
        save = getattr(self._event_store, "save_subscription", None)
        if save is not None:
            save(
                subscription_id=subscription_id,
                event_types=event_types,
                entity_filter=entity_filter,
                last_sequence=last_sequence,
                webhook_url=webhook_url,
            )

        logger.info(
            "Created subscription %s starting from sequence %d",
            subscription_id,
            last_sequence,
        )

    def delete_subscription(self, subscription_id: str) -> bool:
        """
        Delete a subscription.

        Args:
            subscription_id: The subscription to delete.

        Returns:
            True if the subscription was registered and has been deleted,
            False if no subscription with that id is registered.
        """
        with self._lock:
            if subscription_id not in self._subscriptions:
                return False
            del self._subscriptions[subscription_id]
            delete = getattr(self._event_store, "delete_subscription", None)
            if delete is not None:
                delete(subscription_id)
        logger.info("Deleted subscription %s", subscription_id)
        return True

    def load_subscription(
        self,
        subscription_id: str,
        handler: Callable[[Event], None],
    ) -> bool:
        """
        Load a persisted subscription and register it with this manager.

        Args:
            subscription_id: The subscription to load.
            handler: Handler function for events.

        Returns:
            True if the subscription was found and loaded; False if the store
            has no ``get_subscription`` hook or returned None.
        """
        get = getattr(self._event_store, "get_subscription", None)
        if get is None:
            return False
        data = get(subscription_id)
        if data is None:
            return False

        subscription = DurableSubscription(
            id=subscription_id,
            handler=handler,
            event_types=data.get("event_types"),
            entity_filter=data.get("entity_filter"),
            last_sequence=data.get("last_sequence", 0),
            webhook_url=data.get("webhook_url"),
        )
        with self._lock:
            self._subscriptions[subscription_id] = subscription

        logger.info(
            "Loaded subscription %s from sequence %d",
            subscription_id,
            subscription.last_sequence,
        )
        return True

    def start(self) -> None:
        """Start polling for events on a background daemon thread (no-op if running)."""
        if self._running:
            return
        self._running = True
        stop_event = threading.Event()
        self._stop_event = stop_event
        # The thread receives its own event, so a later stop()/start() pair
        # can never make an old thread keep running.
        self._poll_thread = threading.Thread(
            target=self._poll_loop,
            args=(stop_event,),
            daemon=True,
        )
        self._poll_thread.start()
        logger.info("Subscription manager started")

    def stop(self) -> None:
        """Stop polling and wait up to 5 seconds for the poll thread to exit."""
        self._running = False
        self._stop_event.set()  # wakes the poll thread immediately
        thread = self._poll_thread
        self._poll_thread = None
        # Joining the current thread raises RuntimeError (e.g. stop() from a handler).
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning("Subscription poll thread did not exit within 5 seconds")
        logger.info("Subscription manager stopped")

    def _poll_loop(self, stop_event: threading.Event | None = None) -> None:
        """
        Main polling loop run by the background thread.

        Args:
            stop_event: Ends this loop when set; defaults to the manager's
                current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event
        while self._running and not stop_event.is_set():
            try:
                self._process_pending_events()
            except Exception as e:
                logger.exception("Error in subscription poll loop: %s", e)
            # Interruptible sleep: stop() sets the event and wakes us at once.
            stop_event.wait(self._poll_interval)

    def _process_pending_events(self) -> None:
        """Process one batch for every enabled subscription, counting failures."""
        with self._lock:
            subscriptions = list(self._subscriptions.values())

        for subscription in subscriptions:
            if not subscription.enabled:
                continue
            try:
                self._process_subscription(subscription)
            except Exception as e:
                logger.exception(
                    "Error processing subscription %s: %s",
                    subscription.id,
                    e,
                )
                subscription.error_count += 1
                # Disable subscription if too many errors
                if subscription.error_count >= subscription.max_errors:
                    subscription.enabled = False
                    logger.error(
                        "Subscription %s disabled after %d errors",
                        subscription.id,
                        subscription.error_count,
                    )

    def _process_subscription(self, subscription: DurableSubscription) -> int:
        """
        Deliver one batch of events to a single subscription.

        Fetches up to ``batch_size`` events after ``last_sequence`` and calls
        the handler for each matching one. Non-matching events are skipped
        but still advance the position.

        Returns:
            Number of events passed to the handler.

        Raises:
            Exception: Whatever the handler raised. Progress made before the
                failing event is committed first, so events that were already
                handled are not redelivered. The failing event itself stays
                pending and is retried on the next poll.
        """
        events = self._event_store.get_events_since(
            subscription.last_sequence,
            limit=self._batch_size,
        )
        if not events:
            return 0

        processed_sequence = subscription.last_sequence
        handled = 0
        for event in events:
            if subscription.matches(event):
                try:
                    subscription.handler(event)
                except Exception as e:
                    logger.error(
                        "Handler error for subscription %s on event %s: %s",
                        subscription.id,
                        event.event_id,
                        e,
                    )
                    # Keep what was delivered; don't advance past the failing event.
                    if processed_sequence != subscription.last_sequence:
                        self._commit_position(subscription, processed_sequence)
                    raise
                handled += 1
            processed_sequence = event.sequence

        self._commit_position(subscription, processed_sequence)
        subscription.error_count = 0  # Reset on success
        return handled

    def _commit_position(self, subscription: DurableSubscription, sequence: int) -> None:
        """Set the subscription's position in memory and persist it if the store supports it."""
        subscription.last_sequence = sequence
        update = getattr(self._event_store, "update_subscription_sequence", None)
        if update is not None:
            update(subscription.id, sequence)

    def process_once(self) -> int:
        """
        Process one batch of pending events for every enabled subscription.

        Intended for tests and manual draining. It uses the same delivery
        path as the poll loop, so positions are persisted. Handler exceptions
        propagate to the caller and are not counted toward ``max_errors``.

        Returns:
            Number of events passed to handlers.
        """
        with self._lock:
            subscriptions = list(self._subscriptions.values())
        return sum(
            self._process_subscription(subscription)
            for subscription in subscriptions
            if subscription.enabled
        )

    def get_subscription_status(self, subscription_id: str) -> dict[str, Any] | None:
        """Return position, lag and error state of a subscription, or None if unknown."""
        with self._lock:
            subscription = self._subscriptions.get(subscription_id)
        if subscription is None:
            return None

        current_sequence = self._event_store.get_current_sequence()
        event_types = subscription.event_types
        return {
            "id": subscription.id,
            "enabled": subscription.enabled,
            "last_sequence": subscription.last_sequence,
            "current_sequence": current_sequence,
            "lag": current_sequence - subscription.last_sequence,
            "error_count": subscription.error_count,
            # None means "all types"; an empty list (matches nothing) stays [].
            "event_types": ([et.value for et in event_types] if event_types is not None else None),
        }

    def get_all_subscription_status(self) -> list[dict[str, Any]]:
        """Get status of all subscriptions (skipping any deleted concurrently)."""
        with self._lock:
            subscription_ids = list(self._subscriptions.keys())
        return [status for sub_id in subscription_ids if (status := self.get_subscription_status(sub_id)) is not None]