"""
PostgreSQL execution repository.
Production-grade persistence using native psycopg3 with connection pooling.
Uses singleton ConnectionManager for efficient connection pool sharing.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from stabilize.models.workflow import Workflow
from stabilize.persistence.postgres.converters import (
execution_to_dict,
paused_to_dict,
row_to_execution,
row_to_stage,
row_to_task,
)
from stabilize.persistence.postgres.helpers import insert_stage, upsert_tasks_bulk
from stabilize.persistence.postgres.operations import (
cancel_execution,
pause_execution,
resume_execution,
)
from stabilize.persistence.postgres.operations import (
cleanup_old_processed_messages as _cleanup_old_processed_messages,
)
from stabilize.persistence.postgres.operations import is_message_processed as _is_message_processed
from stabilize.persistence.postgres.operations import (
mark_message_processed as _mark_message_processed,
)
from stabilize.persistence.postgres.queries import get_downstream_stages as _get_downstream_stages
from stabilize.persistence.postgres.queries import (
get_merged_ancestor_outputs as _get_merged_ancestor_outputs,
)
from stabilize.persistence.postgres.queries import get_synthetic_stages as _get_synthetic_stages
from stabilize.persistence.postgres.queries import get_upstream_stages as _get_upstream_stages
from stabilize.persistence.postgres.queries import (
load_tasks_for_stages,
)
from stabilize.persistence.postgres.queries import (
retrieve_by_application as _retrieve_by_application,
)
from stabilize.persistence.postgres.queries import (
retrieve_by_pipeline_config_id as _retrieve_by_pipeline_config_id,
)
from stabilize.persistence.store import (
StoreTransaction,
WorkflowCriteria,
WorkflowNotFoundError,
WorkflowStore,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
class PostgresWorkflowStore(WorkflowStore):
    """
    PostgreSQL implementation of WorkflowStore.

    Uses native psycopg3 with connection pooling for database operations.
    Supports concurrent access and provides efficient queries for pipeline
    execution tracking.

    The pool is obtained from the shared connection manager using
    ``connection_string`` and is closed through the same manager.
    """

    def __init__(self, connection_string: str) -> None:
        """Initialize the repository.

        Args:
            connection_string: PostgreSQL connection string used to look up
                (and later close) a pool in the shared connection manager.
        """
        # Local import kept from the original; NOTE(review): presumably avoids
        # an import cycle — confirm before hoisting.
        from stabilize.persistence.connection import get_connection_manager

        self.connection_string = connection_string
        self._manager = get_connection_manager()
        self._pool = self._manager.get_postgres_pool(connection_string)

    def close(self) -> None:
        """Close the connection pool via the connection manager.

        NOTE(review): the manager shares pools by connection string, so this
        may affect other stores created with the same string — confirm.
        """
        self._manager.close_postgres_pool(self.connection_string)

    def store(self, execution: Workflow) -> None:
        """Insert a complete execution together with all of its stages.

        The execution row and every stage (written via ``insert_stage``) use a
        single connection and are committed together.
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO pipeline_executions (
                        id, type, application, name, status, context, start_time, end_time,
                        start_time_expiry, trigger, is_canceled, canceled_by,
                        cancellation_reason, paused, pipeline_config_id,
                        is_limit_concurrent, max_concurrent_executions,
                        keep_waiting_pipelines, origin
                    ) VALUES (
                        %(id)s, %(type)s, %(application)s, %(name)s, %(status)s,
                        %(context)s::jsonb, %(start_time)s, %(end_time)s, %(start_time_expiry)s,
                        %(trigger)s::jsonb, %(is_canceled)s, %(canceled_by)s,
                        %(cancellation_reason)s, %(paused)s::jsonb, %(pipeline_config_id)s,
                        %(is_limit_concurrent)s, %(max_concurrent_executions)s,
                        %(keep_waiting_pipelines)s, %(origin)s
                    )
                    """,
                    execution_to_dict(execution),
                )
                for stage in execution.stages:
                    insert_stage(cur, stage, execution.id)
            conn.commit()

    def retrieve(self, execution_id: str) -> Workflow:
        """Retrieve an execution by ID, including its stages and their tasks.

        Raises:
            WorkflowNotFoundError: If no execution with ``execution_id`` exists.
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM pipeline_executions WHERE id = %(id)s",
                    {"id": execution_id},
                )
                row = cur.fetchone()
                if not row:
                    raise WorkflowNotFoundError(execution_id)
                execution = row_to_execution(row)

                cur.execute(
                    "SELECT * FROM stage_executions WHERE execution_id = %(execution_id)s",
                    {"execution_id": execution_id},
                )
                stages: list[Any] = []
                for stage_row in cur.fetchall():
                    stage = row_to_stage(stage_row)
                    stage.execution = execution
                    stages.append(stage)

                if stages:
                    # Tasks for every stage are loaded on the same cursor.
                    load_tasks_for_stages(cur, stages)
                execution.stages = stages
        return execution

    def retrieve_execution_summary(self, execution_id: str) -> Workflow:
        """Retrieve execution metadata only; stages are not loaded.

        Raises:
            WorkflowNotFoundError: If no execution with ``execution_id`` exists.
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM pipeline_executions WHERE id = %(id)s",
                    {"id": execution_id},
                )
                row = cur.fetchone()
                if not row:
                    raise WorkflowNotFoundError(execution_id)
                return row_to_execution(row)

    def update_status(self, execution: Workflow) -> None:
        """Persist the execution's status, timing, cancellation and pause fields."""
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE pipeline_executions SET
                        status = %(status)s,
                        start_time = %(start_time)s,
                        end_time = %(end_time)s,
                        is_canceled = %(is_canceled)s,
                        canceled_by = %(canceled_by)s,
                        cancellation_reason = %(cancellation_reason)s,
                        paused = %(paused)s::jsonb
                    WHERE id = %(id)s
                    """,
                    {
                        "id": execution.id,
                        "status": execution.status.name,
                        "start_time": execution.start_time,
                        "end_time": execution.end_time,
                        "is_canceled": execution.is_canceled,
                        "canceled_by": execution.canceled_by,
                        "cancellation_reason": execution.cancellation_reason,
                        # A falsy pause record is stored as SQL NULL.
                        "paused": (json.dumps(paused_to_dict(execution.paused)) if execution.paused else None),
                    },
                )
            conn.commit()

    def delete(self, execution_id: str) -> None:
        """Delete the execution row with ``execution_id``.

        NOTE(review): only ``pipeline_executions`` is touched here; stage and
        task rows are presumably removed by ON DELETE CASCADE — confirm schema.
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM pipeline_executions WHERE id = %(id)s",
                    {"id": execution_id},
                )
            conn.commit()

    def store_stage(
        self,
        stage: Any,
        expected_phase: str | None = None,
        connection: Any | None = None,
    ) -> None:
        """Store or update a stage.

        Args:
            stage: The stage to store.
            expected_phase: If provided, adds a status check to the WHERE clause
                for phase-aware optimistic locking.
            connection: Optional existing connection to use. When given, no
                commit is issued here; the caller is responsible for committing.

        Raises:
            ConcurrencyError: If an existing stage row no longer matches the
                stage's version (and ``expected_phase``, when provided).
        """
        if connection is not None:
            with connection.cursor() as cur:
                self._store_stage_impl(cur, stage, expected_phase)
            return

        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                self._store_stage_impl(cur, stage, expected_phase)
            conn.commit()

    def _store_stage_impl(
        self,
        cur: Any,
        stage: Any,
        expected_phase: str | None = None,
    ) -> None:
        """Insert a new stage, or update an existing one under optimistic locking.

        For an existing row the UPDATE only matches when the stored version
        equals ``stage.version`` (and the stored status equals
        ``expected_phase`` when given). On success ``stage.version`` is set to
        the incremented version and the stage's tasks are upserted.
        """
        from stabilize.errors import ConcurrencyError

        cur.execute(
            "SELECT id FROM stage_executions WHERE id = %(id)s",
            {"id": stage.id},
        )
        if cur.fetchone() is None:
            insert_stage(cur, stage, stage.execution.id)
            return

        params: dict[str, Any] = {
            "id": stage.id,
            "status": stage.status.name,
            "context": json.dumps(stage.context, default=str),
            "outputs": json.dumps(stage.outputs),
            "start_time": stage.start_time,
            "end_time": stage.end_time,
            "version": stage.version,
        }
        query = (
            "UPDATE stage_executions SET "
            "status = %(status)s, "
            "context = %(context)s::jsonb, "
            "outputs = %(outputs)s::jsonb, "
            "start_time = %(start_time)s, "
            "end_time = %(end_time)s, "
            "version = version + 1 "
            "WHERE id = %(id)s AND version = %(version)s"
        )
        if expected_phase is not None:
            # Phase-aware lock: the row must also still be in the expected status.
            query += " AND status = %(expected_phase)s"
            params["expected_phase"] = expected_phase
        query += " RETURNING version"
        cur.execute(query, params)

        result = cur.fetchone()
        if not result:
            # No row matched: the version (or phase) changed underneath us.
            phase_note = f", expected_phase {expected_phase}" if expected_phase is not None else ""
            raise ConcurrencyError(
                f"Optimistic lock failed for stage {stage.id} "
                f"(version {stage.version}{phase_note}). "
                f"Another process has modified this stage."
            )

        # Rows may be tuples or mappings depending on the cursor's row factory.
        stage.version = result[0] if isinstance(result, tuple) else result.get("version", stage.version + 1)
        if stage.tasks:
            upsert_tasks_bulk(cur, stage.tasks, stage.id)

    def add_stage(self, stage: Any) -> None:
        """Add a new stage (delegates to ``store_stage``)."""
        self.store_stage(stage)

    def remove_stage(self, execution: Workflow, stage_id: str) -> None:
        """Delete the stage row with ``stage_id``.

        ``execution`` is accepted for interface compatibility but is not used
        here; the in-memory ``execution.stages`` list is not modified.
        """
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM stage_executions WHERE id = %(id)s",
                    {"id": stage_id},
                )
            conn.commit()

    def retrieve_stage(self, stage_id: str) -> Any:
        """Retrieve a single stage by ID, with its tasks and execution context.

        When the owning execution row exists, the stage is attached to it and
        ``execution.stages`` is set to this stage, its requisite (upstream)
        stages, and its synthetic stages — not the execution's full stage list.

        Raises:
            ValueError: If no stage with ``stage_id`` exists.
        """
        execution = None
        related_stages: list[Any] = []
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM stage_executions WHERE id = %(id)s",
                    {"id": stage_id},
                )
                stage_row = cur.fetchone()
                if not stage_row:
                    raise ValueError(f"Stage {stage_id} not found")
                stage = row_to_stage(stage_row)
                related_stages.append(stage)

                cur.execute(
                    "SELECT * FROM pipeline_executions WHERE id = %(id)s",
                    {"id": stage_row["execution_id"]},
                )
                exec_row = cur.fetchone()
                if exec_row:
                    execution = row_to_execution(exec_row)
                    stage.set_execution_strong(execution)
                    requisites = list(stage.requisite_stage_ref_ids)
                    if requisites:
                        cur.execute(
                            """
                            SELECT * FROM stage_executions
                            WHERE execution_id = %(execution_id)s
                            AND ref_id = ANY(%(requisites)s)
                            """,
                            {"execution_id": execution.id, "requisites": requisites},
                        )
                        for upstream_row in cur.fetchall():
                            upstream = row_to_stage(upstream_row)
                            upstream.set_execution_strong(execution)
                            related_stages.append(upstream)

                # ORDER BY id ensures consistent task sequencing (ULID encodes creation time)
                cur.execute(
                    "SELECT * FROM task_executions WHERE stage_id = %(stage_id)s ORDER BY id ASC",
                    {"stage_id": stage.id},
                )
                for task_row in cur.fetchall():
                    task = row_to_task(task_row)
                    task.stage = stage
                    stage.tasks.append(task)

        if execution is not None:
            # Done after the pooled connection is released: get_synthetic_stages
            # checks out its own connection, and requesting a second one while
            # still holding the first can exhaust the pool under concurrent load.
            for synthetic in self.get_synthetic_stages(execution.id, stage.id):
                synthetic.set_execution_strong(execution)
                related_stages.append(synthetic)
            execution.stages = related_stages
        return stage

    def get_upstream_stages(self, execution_id: str, stage_ref_id: str) -> list[Any]:
        """Get upstream stages with tasks loaded."""
        return _get_upstream_stages(self._pool, execution_id, stage_ref_id)

    def get_downstream_stages(self, execution_id: str, stage_ref_id: str) -> list[Any]:
        """Get downstream stages with tasks loaded."""
        return _get_downstream_stages(self._pool, execution_id, stage_ref_id)

    def get_synthetic_stages(self, execution_id: str, parent_stage_id: str) -> list[Any]:
        """Get synthetic stages with tasks loaded."""
        return _get_synthetic_stages(self._pool, execution_id, parent_stage_id)

    def get_merged_ancestor_outputs(self, execution_id: str, stage_ref_id: str) -> dict[str, Any]:
        """Get merged outputs from all ancestor stages."""
        return _get_merged_ancestor_outputs(self._pool, execution_id, stage_ref_id)

    def retrieve_by_pipeline_config_id(
        self,
        pipeline_config_id: str,
        criteria: WorkflowCriteria | None = None,
    ) -> Iterator[Workflow]:
        """Retrieve executions by pipeline config ID.

        ``self.retrieve`` is passed to the query helper; NOTE(review):
        presumably used to load each full execution — confirm.
        """
        return _retrieve_by_pipeline_config_id(self._pool, pipeline_config_id, criteria, self.retrieve)

    def retrieve_by_application(
        self,
        application: str,
        criteria: WorkflowCriteria | None = None,
    ) -> Iterator[Workflow]:
        """Retrieve executions by application.

        ``self.retrieve`` is passed to the query helper; NOTE(review):
        presumably used to load each full execution — confirm.
        """
        return _retrieve_by_application(self._pool, application, criteria, self.retrieve)

    def pause(self, execution_id: str, paused_by: str) -> None:
        """Pause an execution."""
        pause_execution(self._pool, execution_id, paused_by)

    def resume(self, execution_id: str) -> None:
        """Resume a paused execution."""
        resume_execution(self._pool, execution_id)

    def cancel(self, execution_id: str, canceled_by: str, reason: str) -> None:
        """Cancel an execution."""
        cancel_execution(self._pool, execution_id, canceled_by, reason)

    def is_message_processed(self, message_id: str) -> bool:
        """Check if a message has already been processed."""
        return _is_message_processed(self._pool, message_id)

    def mark_message_processed(
        self,
        message_id: str,
        handler_type: str | None = None,
        execution_id: str | None = None,
    ) -> None:
        """Mark a message as successfully processed."""
        _mark_message_processed(self._pool, message_id, handler_type, execution_id)

    def cleanup_old_processed_messages(self, max_age_hours: float = 24.0) -> int:
        """Clean up processed-message records older than ``max_age_hours``.

        Returns:
            The count reported by the cleanup helper.
        """
        return _cleanup_old_processed_messages(self._pool, max_age_hours)

    def is_healthy(self) -> bool:
        """Check if the database connection is healthy.

        Returns:
            True if ``SELECT 1`` runs on a pooled connection, False on any error.
        """
        try:
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")
        except Exception:
            # Deliberately best-effort: report unhealthy rather than raise,
            # but keep the underlying cause visible in the logs.
            logger.warning("PostgreSQL health check failed", exc_info=True)
            return False
        return True

    @contextmanager
    def transaction(self, queue: Any | None = None) -> Iterator[StoreTransaction]:
        """Create an atomic transaction for store + queue operations.

        Commits when the ``with`` body finishes normally. On an exception
        (including a failed commit) the connection is rolled back,
        ``txn.rollback_versions()`` is called, and the exception is re-raised.
        """
        from stabilize.persistence.postgres.transaction import PostgresTransaction

        with self._pool.connection() as conn:
            txn = PostgresTransaction(conn, self, queue)
            try:
                yield txn
                conn.commit()
            except Exception:
                conn.rollback()
                # NOTE(review): presumably reverts in-memory stage versions
                # bumped inside the failed transaction — confirm.
                txn.rollback_versions()
                raise