"""
RunTaskHandler - executes tasks.
This is the handler that actually runs task implementations.
It handles execution, retries, timeouts, and result processing.
Uses bulkman for bulkhead pattern (per-task-type isolation) and
resilient_circuit for circuit breaker protection.
"""
from __future__ import annotations
import logging
import os
import threading
import time
from datetime import timedelta
from typing import TYPE_CHECKING
from resilient_circuit import ExponentialDelay
from stabilize.errors import (
TaskTimeoutError,
TransientVerificationError,
VerificationError,
is_transient,
)
from stabilize.handlers.base import StabilizeHandler
from stabilize.handlers.run_task.error import (
complete_with_error,
handle_cancellation,
handle_exception,
)
from stabilize.handlers.run_task.execution import execute_with_timeout, resolve_task
from stabilize.handlers.run_task.result import get_backoff_period, process_result
from stabilize.handlers.run_task.verification import verify_task_outputs
from stabilize.metrics import Timer
from stabilize.models.status import WorkflowStatus
from stabilize.persistence.transaction import TransactionHelper
from stabilize.queue.messages import CompleteTask, PauseTask, RunTask
from stabilize.resilience.bulkheads import TaskBulkheadManager
from stabilize.resilience.circuits import WorkflowCircuitFactory
from stabilize.resilience.config import HandlerConfig, ResilienceConfig
from stabilize.resilience.process_executor import ProcessIsolatedTaskExecutor
from stabilize.resilience.timeouts import TimeoutManager
from stabilize.tasks.interface import RetryableTask
from stabilize.tasks.registry import TaskNotFoundError, TaskRegistry
if TYPE_CHECKING:
from stabilize.events.recorder import EventRecorder
from stabilize.models.stage import StageExecution
from stabilize.models.task import TaskExecution
from stabilize.persistence.store import WorkflowStore
from stabilize.queue import Queue
from stabilize.tasks.result import TaskResult
logger = logging.getLogger(__name__)
[docs]
class RunTaskHandler(StabilizeHandler[RunTask]):
"""
Handler for RunTask messages.
This is where tasks are actually executed. The handler:
1. Resolves the task implementation
2. Checks for cancellation/pause
3. Checks for timeout
4. Executes the task
5. Processes the result
"""
# Class-level lock to track tasks currently being executed (prevents duplicate execution)
# Maps task_id -> start_time (monotonic) for staleness detection
_executing_tasks: dict[str, float] = {}
_executing_lock = threading.Lock()
def __init__(
self,
queue: Queue,
repository: WorkflowStore,
task_registry: TaskRegistry,
retry_delay: timedelta | None = None,
bulkhead_manager: TaskBulkheadManager | None = None,
circuit_factory: WorkflowCircuitFactory | None = None,
handler_config: HandlerConfig | None = None,
event_recorder: EventRecorder | None = None,
) -> None:
super().__init__(queue, repository, retry_delay, handler_config, event_recorder=event_recorder)
self.task_registry = task_registry
self.txn_helper = TransactionHelper(repository, queue)
self.timeout_manager = TimeoutManager(self.handler_config.default_task_timeout_seconds)
# Create backoff calculator from config
self._task_backoff = ExponentialDelay(
min_delay=timedelta(milliseconds=self.handler_config.task_backoff_min_delay_ms),
max_delay=timedelta(milliseconds=self.handler_config.task_backoff_max_delay_ms),
factor=int(self.handler_config.concurrency_backoff_factor),
jitter=self.handler_config.concurrency_jitter,
)
# Check isolation mode
self.isolation_mode = os.environ.get("STABILIZE_ISOLATION_MODE", "thread").lower()
self.process_executor = ProcessIsolatedTaskExecutor() if self.isolation_mode == "process" else None
# Initialize resilience components with defaults if not provided
if bulkhead_manager is None or circuit_factory is None:
config = ResilienceConfig.from_env()
self.bulkhead_manager = bulkhead_manager or TaskBulkheadManager(config)
self.circuit_factory = circuit_factory or WorkflowCircuitFactory(config)
else:
self.bulkhead_manager = bulkhead_manager
self.circuit_factory = circuit_factory
@property
def message_type(self) -> type[RunTask]:
return RunTask
[docs]
def handle(self, message: RunTask) -> None:
"""Handle the RunTask message."""
def on_task(stage: StageExecution, task_model: TaskExecution) -> None:
execution = stage.execution
# IDEMPOTENCY CHECK: Only run tasks that are in RUNNING state
if task_model.status != WorkflowStatus.RUNNING:
if task_model.status.is_complete:
logger.debug(
"Ignoring RunTask for %s - already completed with status %s",
task_model.name,
task_model.status,
)
else:
logger.warning(
"Ignoring RunTask for %s - unexpected status %s (expected RUNNING)",
task_model.name,
task_model.status,
)
# Mark message as processed to prevent redelivery
if message.message_id:
with self.repository.transaction(self.queue) as txn:
txn.mark_message_processed(message.message_id)
return
# Resolve task implementation
try:
task = resolve_task(message.task_type, task_model, self.task_registry)
except TaskNotFoundError as e:
logger.error("Task type not found: %s", message.task_type)
complete_with_error(
stage,
task_model,
message,
str(e),
self.repository,
self.txn_helper,
self.retry_on_concurrency_error,
)
return
# Check execution state
if execution.is_canceled:
handle_cancellation(
stage,
task_model,
task,
message,
self.repository,
self.txn_helper,
self.retry_on_concurrency_error,
)
return
if execution.status.is_complete:
# Atomic: mark message processed + push CompleteTask
self.txn_helper.execute_atomic(
source_message=message,
messages_to_push=[
(
CompleteTask(
execution_type=message.execution_type,
execution_id=message.execution_id,
stage_id=message.stage_id,
task_id=message.task_id,
status=WorkflowStatus.CANCELED,
),
None,
)
],
handler_name="RunTask",
)
return
if execution.status == WorkflowStatus.PAUSED:
# Atomic: mark message processed + push PauseTask
self.txn_helper.execute_atomic(
source_message=message,
messages_to_push=[
(
PauseTask(
execution_type=message.execution_type,
execution_id=message.execution_id,
stage_id=message.stage_id,
task_id=message.task_id,
),
None,
)
],
handler_name="RunTask",
)
return
# Check for manual skip
if stage.context.get("manualSkip"):
# Atomic: mark message processed + push CompleteTask
self.txn_helper.execute_atomic(
source_message=message,
messages_to_push=[
(
CompleteTask(
execution_type=message.execution_type,
execution_id=message.execution_id,
stage_id=message.stage_id,
task_id=message.task_id,
status=WorkflowStatus.SKIPPED,
),
None,
)
],
handler_name="RunTask",
)
return
# Stage-level elapsed time check
stage_timeout_ms = stage.context.get("stageTimeoutMs")
if stage_timeout_ms is not None and stage.start_time:
stage_elapsed_ms = self.current_time_millis() - stage.start_time
if stage_elapsed_ms > stage_timeout_ms:
logger.info(
"Stage %s exceeded stage timeout (%dms > %dms)",
stage.name,
stage_elapsed_ms,
stage_timeout_ms,
)
complete_with_error(
stage,
task_model,
message,
f"Stage exceeded timeout ({stage_elapsed_ms}ms > {stage_timeout_ms}ms)",
self.repository,
self.txn_helper,
self.retry_on_concurrency_error,
)
return
# Total-lifecycle timeout for RetryableTask (execute_with_timeout only caps single calls)
if isinstance(task, RetryableTask) and task_model.start_time:
elapsed_ms = self.current_time_millis() - task_model.start_time
allowed_ms = task.get_dynamic_timeout(stage).total_seconds() * 1000
if elapsed_ms > allowed_ms:
logger.info(
"RetryableTask %s exceeded total timeout (%dms > %dms)",
task_model.name,
elapsed_ms,
int(allowed_ms),
)
timeout_result = task.on_timeout(stage) if hasattr(task, "on_timeout") else None
if timeout_result is None:
complete_with_error(
stage,
task_model,
message,
f"Task exceeded total timeout ({elapsed_ms}ms > {int(allowed_ms)}ms)",
self.repository,
self.txn_helper,
self.retry_on_concurrency_error,
)
else:
self._process_result_safely(message.stage_id, message.task_id, timeout_result, message)
return
# CONCURRENT EXECUTION CHECK: Prevent duplicate execution of same task
stale_threshold_s = 3600
with RunTaskHandler._executing_lock:
existing_start = RunTaskHandler._executing_tasks.get(task_model.id)
if existing_start is not None:
elapsed = time.monotonic() - existing_start
if elapsed < stale_threshold_s:
logger.debug(
"Ignoring duplicate RunTask for %s - already executing (%.1fs)",
task_model.name,
elapsed,
)
return
else:
logger.warning(
"Stale execution lock for task %s (%.1fs > %ds), allowing re-execution",
task_model.name,
elapsed,
stale_threshold_s,
)
RunTaskHandler._executing_tasks[task_model.id] = time.monotonic()
# Execute the task with timeout enforcement
try:
timeout = self.timeout_manager.get_task_timeout(stage, task)
with Timer(
"task_execution_seconds",
task_type=message.task_type,
task_name=task_model.name,
):
result = execute_with_timeout(
task,
stage,
timeout,
message,
self.bulkhead_manager,
self.circuit_factory,
self.process_executor,
)
# Verify outputs before processing/persisting
verify_task_outputs(stage, result, self.task_registry)
self._process_result_safely(message.stage_id, message.task_id, result, message)
except TaskTimeoutError as e:
logger.info("Task %s timed out: %s", task_model.name, e)
timeout_result = task.on_timeout(stage) if hasattr(task, "on_timeout") else None
if timeout_result is None:
complete_with_error(
stage,
task_model,
message,
str(e),
self.repository,
self.txn_helper,
self.retry_on_concurrency_error,
)
else:
self._process_result_safely(message.stage_id, message.task_id, timeout_result, message)
except TransientVerificationError as e:
logger.info(
"Verification pending for task %s, will retry: %s",
task_model.name,
e,
)
handle_exception(
stage,
task_model,
task,
message,
e,
self.repository,
self.txn_helper,
self._get_backoff_period,
self.retry_on_concurrency_error,
)
except VerificationError as e:
logger.error("Verification failed for task %s: %s", task_model.name, e)
complete_with_error(
stage,
task_model,
message,
str(e),
self.repository,
self.txn_helper,
self.retry_on_concurrency_error,
)
except Exception as e:
if is_transient(e):
logger.debug(
"Transient error executing task %s: %s",
task_model.name,
e,
)
else:
logger.error(
"Error executing task %s: %s",
task_model.name,
e,
exc_info=True,
)
handle_exception(
stage,
task_model,
task,
message,
e,
self.repository,
self.txn_helper,
self._get_backoff_period,
self.retry_on_concurrency_error,
)
finally:
# Release the execution lock so other messages can be processed
with RunTaskHandler._executing_lock:
RunTaskHandler._executing_tasks.pop(task_model.id, None)
self.with_task(message, on_task)
def _process_result_safely(
self,
stage_id: str,
task_id: str,
result: TaskResult,
message: RunTask,
) -> None:
"""Process result with retry on concurrency error."""
def do_process() -> None:
# Always reload stage to get latest version
stage = self.repository.retrieve_stage(stage_id)
if not stage:
logger.error("Stage %s not found processing result", stage_id)
return
# Find task model
task_model = next((t for t in stage.tasks if t.id == task_id), None)
if not task_model:
logger.error("Task %s not found in stage %s", task_id, stage_id)
return
process_result(
stage,
task_model,
result,
message,
self.txn_helper,
self._get_backoff_period,
)
self.retry_on_concurrency_error(do_process, f"processing result for task {task_id}")
def _get_backoff_period(
self,
stage: StageExecution,
task_model: TaskExecution,
message: RunTask,
attempt: int = 1,
) -> timedelta:
"""Calculate backoff period for retry."""
return get_backoff_period(
stage,
task_model,
message,
attempt,
self.task_registry,
self._task_backoff,
self.current_time_millis,
)