# Source code for stabilize.handlers.base

"""
Base message handler classes.

This module provides the base classes for all message handlers in the
pipeline execution engine.
"""

from __future__ import annotations

import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Generic, TypeVar

from resilient_circuit import ExponentialDelay, RetryWithBackoffPolicy
from resilient_circuit.exceptions import RetryLimitReached

from stabilize.errors import ConcurrencyError
from stabilize.models.stage import StageExecution
from stabilize.models.status import WorkflowStatus, validate_transition
from stabilize.models.task import TaskExecution
from stabilize.models.workflow import Workflow
from stabilize.persistence.store import WorkflowNotFoundError
from stabilize.queue.messages import (
    CompleteWorkflow,
    ContinueParentStage,
    InvalidStageId,
    InvalidTaskId,
    InvalidWorkflowId,
    Message,
    StageLevel,
    StartStage,
    TaskLevel,
    WorkflowLevel,
)
from stabilize.resilience.config import HandlerConfig, get_handler_config

if TYPE_CHECKING:
    from stabilize.events.recorder import EventRecorder
    from stabilize.persistence.store import WorkflowStore
    from stabilize.queue import Queue

logger = logging.getLogger(__name__)

M = TypeVar("M", bound=Message)


class MessageHandler(ABC, Generic[M]):
    """
    Abstract contract shared by every message handler.

    A concrete handler is parameterised by the message type ``M`` it
    accepts and must supply both ``message_type`` and ``handle``.
    """

    @property
    @abstractmethod
    def message_type(self) -> type[M]:
        """The message class this handler is responsible for."""

    @abstractmethod
    def handle(self, message: M) -> None:
        """Process a single message of type ``M``."""
class StabilizeHandler(MessageHandler[M], ABC):
    """
    Base handler with common utilities.

    Provides helper methods for retrieving executions, stages, and tasks,
    the ``start_next()`` DAG-traversal step, a retry wrapper for
    ``ConcurrencyError``, and validated status setters.

    Attributes:
        queue: The message queue for pushing messages
        repository: The workflow store for persistence
        retry_delay: Delay before re-queuing messages (from config or override)
        handler_config: Configuration for retry behavior and other settings
    """

    def __init__(
        self,
        queue: Queue,
        repository: WorkflowStore,
        retry_delay: timedelta | None = None,
        handler_config: HandlerConfig | None = None,
        event_recorder: EventRecorder | None = None,
    ) -> None:
        """Store collaborators and resolve configuration defaults.

        Args:
            queue: Queue used to push follow-up messages.
            repository: Store used to load executions/stages and open transactions.
            retry_delay: Explicit re-queue delay; ``None`` means "use the config value".
            handler_config: Explicit configuration; ``None`` means "use get_handler_config()".
            event_recorder: Optional recorder; when ``None`` the ``event_recorder``
                property falls back to ``get_event_recorder()``.
        """
        self.queue = queue
        self.repository = repository
        self.handler_config = handler_config if handler_config is not None else get_handler_config()

        # Test against None, not truthiness: timedelta(0) is falsy, so the old
        # `retry_delay or ...` silently discarded an explicit zero delay.
        if retry_delay is not None:
            self.retry_delay = retry_delay
        else:
            self.retry_delay = timedelta(seconds=self.handler_config.handler_retry_delay_seconds)

        # Event recording is optional - use provided recorder or global instance
        self._event_recorder = event_recorder

    # ========== Event Recording ==========

    @property
    def event_recorder(self) -> EventRecorder | None:
        """Get the event recorder, preferring the instance's own over the global one."""
        if self._event_recorder is not None:
            return self._event_recorder

        # Fall back to global recorder if available (imported lazily, at call time)
        from stabilize.events.recorder import get_event_recorder

        return get_event_recorder()

    def set_event_context(self, workflow_id: str) -> None:
        """Set event context for correlation tracking.

        Call this at the start of handler processing to establish
        correlation context for any events recorded during handling.

        Args:
            workflow_id: The workflow ID, passed on as the correlation ID.
        """
        from stabilize.events.recorder import set_event_context

        set_event_context(correlation_id=workflow_id)

    # ========== Execution Retrieval ==========

    def _reject_unresolvable_message(
        self,
        message: WorkflowLevel | StageLevel,
        handler_type: str,
        rejection: Message,
    ) -> None:
        """Mark ``message`` as processed and push ``rejection`` in one transaction.

        Shared by ``with_execution``, ``with_stage`` and ``with_task`` for the
        case where the entity a message refers to cannot be found.

        Args:
            message: The message whose target could not be resolved.
            handler_type: Label recorded with the processed marker.
            rejection: The ``Invalid*`` message to push.
        """
        # Use atomic transaction with deduplication to prevent infinite retries
        with self.repository.transaction(self.queue) as txn:
            # Not every message is guaranteed to carry a message_id; skip the
            # processed marker when it is absent or empty.
            message_id = getattr(message, "message_id", None)
            if message_id:
                txn.mark_message_processed(
                    message_id=message_id,
                    handler_type=handler_type,
                    execution_id=message.execution_id,
                )
            txn.push_message(rejection)

    def with_execution(
        self,
        message: WorkflowLevel,
        block: Callable[[Workflow], None],
    ) -> None:
        """
        Execute a block with the execution for a message.

        If the store raises ``WorkflowNotFoundError`` an ``InvalidWorkflowId``
        message is pushed and ``block`` is not called. Any other retrieval
        error is logged and re-raised.

        Args:
            message: Message containing execution ID
            block: Function to call with the execution
        """
        try:
            execution = self.repository.retrieve(message.execution_id)
        except WorkflowNotFoundError:
            logger.error("Execution not found: %s", message.execution_id)
            self._reject_unresolvable_message(
                message,
                "with_execution",
                InvalidWorkflowId(
                    execution_type=message.execution_type,
                    execution_id=message.execution_id,
                ),
            )
            return
        except Exception as e:
            logger.error("Failed to retrieve execution %s: %s", message.execution_id, e)
            raise

        block(execution)

    def with_stage(
        self,
        message: StageLevel,
        block: Callable[[StageExecution], None],
    ) -> None:
        """
        Execute a block with the stage for a message.

        A ``ValueError`` from the store is treated as "stage not found": an
        ``InvalidStageId`` message is pushed and ``block`` is not called. Any
        other retrieval error is logged and re-raised.

        Args:
            message: Message containing stage ID
            block: Function to call with the stage
        """
        try:
            stage = self.repository.retrieve_stage(message.stage_id)
        except ValueError:
            logger.error("Stage not found: %s", message.stage_id)
            self._reject_unresolvable_message(
                message,
                "with_stage",
                InvalidStageId(
                    execution_type=message.execution_type,
                    execution_id=message.execution_id,
                    stage_id=message.stage_id,
                ),
            )
            return
        except Exception as e:
            logger.error("Failed to retrieve stage %s: %s", message.stage_id, e)
            raise

        block(stage)

    def with_task(
        self,
        message: TaskLevel,
        block: Callable[[StageExecution, TaskExecution], None],
    ) -> None:
        """
        Execute a block with the stage and task for a message.

        The stage is resolved through ``with_stage`` first; if the task ID is
        not present in that stage an ``InvalidTaskId`` message is pushed and
        ``block`` is not called.

        Args:
            message: Message containing task ID
            block: Function to call with (stage, task)
        """

        def on_stage(stage: StageExecution) -> None:
            task = self._find_task(stage, message.task_id)
            if task is not None:
                block(stage, task)
                return

            logger.error("Task not found: %s", message.task_id)
            self._reject_unresolvable_message(
                message,
                "with_task",
                InvalidTaskId(
                    execution_type=message.execution_type,
                    execution_id=message.execution_id,
                    stage_id=message.stage_id,
                    task_id=message.task_id,
                ),
            )

        self.with_stage(message, on_stage)

    def _find_task(
        self,
        stage: StageExecution,
        task_id: str,
    ) -> TaskExecution | None:
        """Return the first task in ``stage.tasks`` whose ``id`` equals ``task_id``, else ``None``."""
        return next((task for task in stage.tasks if task.id == task_id), None)

    # ========== Stage Navigation ==========

    def start_next(self, stage: StageExecution) -> None:
        """
        Start the next stage(s) after a stage completes.

        This is the critical method for DAG traversal:
        1. Find downstream stages (those that depend on this stage)
        2. Push StartStage for each downstream stage
        3. If this is a synthetic stage, notify parent
        4. If no downstream and not synthetic, complete execution

        A synthetic stage that has no ``parent_stage_id`` is logged as a data
        inconsistency and the workflow is completed instead of left hanging.

        Args:
            stage: The stage that has just completed.
        """
        execution = stage.execution
        downstream_stages = self.repository.get_downstream_stages(execution.id, stage.ref_id)
        phase = stage.synthetic_stage_owner

        if downstream_stages:
            # Start all downstream stages
            for downstream in downstream_stages:
                self.queue.push(
                    StartStage(
                        execution_type=execution.type.value,
                        execution_id=execution.id,
                        stage_id=downstream.id,
                    )
                )
            return

        if phase is not None:
            # Synthetic stage - notify parent
            parent_id = stage.parent_stage_id
            if parent_id:
                # Use atomic transaction to ensure message is persisted with stage state
                with self.repository.transaction(self.queue) as txn:
                    txn.push_message(
                        ContinueParentStage(
                            execution_type=execution.type.value,
                            execution_id=execution.id,
                            stage_id=parent_id,
                            phase=phase,
                        )
                    )
                return

            logger.error(
                "Synthetic stage %s has phase=%s but no parent_stage_id. "
                "Data inconsistency - completing workflow to prevent hang. Execution: %s",
                stage.id,
                phase,
                execution.id,
            )

        # Reached for a top-level stage with no downstream stages, and for the
        # orphaned synthetic stage logged above: complete the execution.
        self.queue.push(
            CompleteWorkflow(
                execution_type=execution.type.value,
                execution_id=execution.id,
            )
        )

    # ========== Utility Methods ==========

    def current_time_millis(self) -> int:
        """Get current wall-clock time (``time.time()``) in whole milliseconds."""
        return int(time.time() * 1000)

    def retry_on_concurrency_error(
        self,
        func: Callable[[], None],
        context: str = "operation",
    ) -> None:
        """Execute a function with retry on ConcurrencyError.

        Uses resilient-circuit's RetryWithBackoffPolicy for consistent retry behavior.
        Configuration comes from handler_config settings.
        Set concurrency_max_retries to 0 to disable retries entirely.

        Args:
            func: The function to execute
            context: Description for logging (e.g., "completing task", "starting stage")

        Raises:
            ConcurrencyError: If max retries exceeded
        """
        max_retries = self.handler_config.concurrency_max_retries

        # If retries disabled, just run once
        if max_retries == 0:
            func()
            return

        # Create backoff configuration from handler_config
        # NOTE(review): int() truncates a fractional backoff factor (e.g. 1.5 -> 1);
        # confirm whether ExponentialDelay requires an integer factor.
        concurrency_backoff = ExponentialDelay(
            min_delay=timedelta(milliseconds=self.handler_config.concurrency_min_delay_ms),
            max_delay=timedelta(milliseconds=self.handler_config.concurrency_max_delay_ms),
            factor=int(self.handler_config.concurrency_backoff_factor),
            jitter=self.handler_config.concurrency_jitter,
        )

        # Create retry policy with ConcurrencyError handling
        retry_policy = RetryWithBackoffPolicy(
            max_retries=max_retries,
            backoff=concurrency_backoff,
            should_handle=lambda e: isinstance(e, ConcurrencyError),
        )

        @retry_policy
        def with_retry() -> None:
            func()

        try:
            with_retry()
        except RetryLimitReached as e:
            logger.error(
                "Failed %s after %d attempts due to contention",
                context,
                max_retries,
            )
            # Re-raise the original ConcurrencyError when the policy attached it
            # as the cause; otherwise raise a fresh one so callers always see
            # ConcurrencyError rather than RetryLimitReached.
            if e.__cause__ and isinstance(e.__cause__, ConcurrencyError):
                raise e.__cause__ from e
            raise ConcurrencyError(f"Max retries exceeded for {context}") from e

    # ========== State Transition Helpers ==========

    def set_stage_status(
        self,
        stage: StageExecution,
        new_status: WorkflowStatus,
    ) -> None:
        """Set stage status with validation.

        Validates the transition is allowed before setting the status; the
        stage is left untouched if validation raises.

        Args:
            stage: The stage to update
            new_status: The new status to set

        Raises:
            InvalidStateTransitionError: If the transition is not allowed
        """
        validate_transition(
            stage.status,
            new_status,
            entity_type="stage",
            entity_id=stage.id,
        )
        stage.status = new_status

    def set_task_status(
        self,
        task: TaskExecution,
        new_status: WorkflowStatus,
    ) -> None:
        """Set task status with validation.

        Validates the transition is allowed before setting the status; the
        task is left untouched if validation raises.

        Args:
            task: The task to update
            new_status: The new status to set

        Raises:
            InvalidStateTransitionError: If the transition is not allowed
        """
        validate_transition(
            task.status,
            new_status,
            entity_type="task",
            entity_id=task.id,
        )
        task.status = new_status

    def set_workflow_status(
        self,
        workflow: Workflow,
        new_status: WorkflowStatus,
    ) -> None:
        """Set workflow status with validation.

        Validates the transition is allowed before setting the status; the
        workflow is left untouched if validation raises.

        Args:
            workflow: The workflow to update
            new_status: The new status to set

        Raises:
            InvalidStateTransitionError: If the transition is not allowed
        """
        validate_transition(
            workflow.status,
            new_status,
            entity_type="workflow",
            entity_id=workflow.id,
        )
        workflow.status = new_status