Source code for stabilize.handlers.signal_stage

"""
SignalStageHandler - handles external signals for suspended stages.

Implements:
- WCP-23: Transient Trigger - signal is lost if stage not SUSPENDED
- WCP-24: Persistent Trigger - signal is buffered if stage not ready
"""

from __future__ import annotations

import logging
from datetime import timedelta
from typing import TYPE_CHECKING

from stabilize.handlers.base import StabilizeHandler
from stabilize.models.status import WorkflowStatus
from stabilize.queue.messages import RunTask, SignalStage, StartStage
from stabilize.resilience.config import HandlerConfig

if TYPE_CHECKING:
    from stabilize.events.recorder import EventRecorder
    from stabilize.models.stage import StageExecution
    from stabilize.persistence.store import WorkflowStore
    from stabilize.queue import Queue

logger = logging.getLogger(__name__)


[docs] class SignalStageHandler(StabilizeHandler[SignalStage]): """ Handler for SignalStage messages. Execution flow: 1. Check if stage is SUSPENDED - If SUSPENDED: transition to RUNNING, merge signal data, push StartStage - If not SUSPENDED and transient: discard signal (WCP-23) - If not SUSPENDED and persistent: buffer signal for later (WCP-24) """ def __init__( self, queue: Queue, repository: WorkflowStore, retry_delay: timedelta | None = None, handler_config: HandlerConfig | None = None, event_recorder: EventRecorder | None = None, ) -> None: super().__init__(queue, repository, retry_delay, handler_config, event_recorder=event_recorder) @property def message_type(self) -> type[SignalStage]: return SignalStage
[docs] def handle(self, message: SignalStage) -> None: """Handle the SignalStage message.""" self.retry_on_concurrency_error( lambda: self._handle_with_retry(message), f"signaling stage {message.stage_id}", )
def _handle_with_retry(self, message: SignalStage) -> None: """Inner handle logic to be retried.""" def on_stage(stage: StageExecution) -> None: if stage.status == WorkflowStatus.SUSPENDED: # Stage is waiting for a signal - deliver it logger.info( "Delivering signal '%s' to suspended stage %s", message.signal_name, stage.name, ) # Merge signal data into stage context stage.context["_signal_name"] = message.signal_name stage.context["_signal_data"] = message.signal_data # Transition back to RUNNING self.set_stage_status(stage, WorkflowStatus.RUNNING) # Find the suspended task and set it back to RUNNING suspended_task = None for task in stage.tasks: if task.status == WorkflowStatus.SUSPENDED: suspended_task = task task.status = WorkflowStatus.RUNNING break with self.repository.transaction(self.queue) as txn: txn.store_stage(stage) if message.message_id: txn.mark_message_processed( message_id=message.message_id, handler_type="SignalStage", execution_id=message.execution_id, ) if suspended_task: # Push RunTask to re-execute the suspended task txn.push_message( RunTask( execution_type=message.execution_type, execution_id=message.execution_id, stage_id=message.stage_id, task_id=suspended_task.id, ) ) else: # No suspended task found - re-start the stage txn.push_message( StartStage( execution_type=message.execution_type, execution_id=message.execution_id, stage_id=message.stage_id, ) ) return # Stage is not SUSPENDED if message.persistent: # WCP-24: Buffer the signal for later consumption logger.info( "Buffering persistent signal '%s' for stage %s (current status: %s)", message.signal_name, stage.name, stage.status, ) # Store signal in stage context buffer buffered = stage.context.get("_buffered_signals", []) buffered.append( { "signal_name": message.signal_name, "signal_data": message.signal_data, } ) stage.context["_buffered_signals"] = buffered with self.repository.transaction(self.queue) as txn: txn.store_stage(stage) if message.message_id: txn.mark_message_processed( message_id=message.message_id, handler_type="SignalStage", execution_id=message.execution_id, ) else: # WCP-23: Transient signal - discard logger.debug( "Discarding transient signal '%s' for stage %s (not SUSPENDED, status: %s)", message.signal_name, stage.name, stage.status, ) if message.message_id: with self.repository.transaction(self.queue) as txn: txn.mark_message_processed( message_id=message.message_id, handler_type="SignalStage", execution_id=message.execution_id, ) self.with_stage(message, on_stage)