# Source code for stabilize.tasks.shell

"""
Enterprise-ready ShellTask for executing shell commands.

This module provides a production-ready ShellTask with:
- Working directory support
- Environment variable injection
- Shell selection (bash, sh, custom)
- Stdin input support
- Output size limits (prevent OOM)
- Expected exit codes (for tools that return non-zero on success)
- Secret masking in logs
- Binary output mode
- {key} placeholder substitution with upstream outputs
- Cross-platform process tree cleanup (uses psutil when available)
- Linux-specific PR_SET_PDEATHSIG for auto-cleanup when parent dies
"""

from __future__ import annotations

import base64
import ctypes
import logging
import os
import re
import shlex
import signal
import subprocess
import sys
from typing import TYPE_CHECKING, Any

from stabilize.errors import TransientError
from stabilize.tasks.interface import Task
from stabilize.tasks.result import TaskResult

if TYPE_CHECKING:
    from stabilize.models.stage import StageExecution

logger = logging.getLogger(__name__)

# Context keys that configure ShellTask itself. ShellTask never expands these
# as {key} placeholders inside the command string.
RESERVED_KEYS = frozenset(
    [
        "binary",
        "command",
        "continue_on_failure",
        "cwd",
        "env",
        "expected_codes",
        "max_output_size",
        "restart_on_failure",
        "secrets",
        "shell",
        "stdin",
        "timeout",
    ]
)


class ShellTask(Task):
    """
    Enterprise-ready shell command execution.

    Executes shell commands with full control over execution environment,
    input/output handling, and error management.

    Context Parameters:
        command (str): The shell command to execute (required)
        timeout (int): Command timeout in seconds (default: 60)
        cwd (str): Working directory for command execution
        env (dict): Additional environment variables to set
        shell (bool|str): True for default shell, or absolute path to a shell executable
        stdin (str|bytes): Input to send to command's stdin (str is UTF-8 encoded)
        max_output_size (int): Max bytes for stdout/stderr (default: 10MB)
        expected_codes (list[int]): Exit codes to treat as success (default: [0])
        secrets (list[str]): Context keys whose values should be masked in the
            logged command, text outputs and error messages
        binary (bool): If True, capture output as bytes (default: False)
        continue_on_failure (bool): If True, return failed_continue instead of terminal
        restart_on_failure (bool): If True, raise TransientError on failure to
            trigger automatic retry with backoff (for long-running services)

    Outputs:
        stdout (str|bytes): Command standard output (stripped if text mode)
        stderr (str|bytes): Command standard error (stripped if text mode)
        returncode (int): Command exit code (-1 on timeout)
        truncated (bool): True if output was truncated due to size limit
        stdout_b64 (str): Base64-encoded stdout (only if binary=True)
        On timeout, stdout/stderr (and stdout_b64) are only present when the
        command produced output on that stream.

    Placeholder Substitution:
        Any {key} in the command is replaced with the shell-quoted value of
        stage.context[key]. This includes outputs from upstream stages. Keys
        listed in RESERVED_KEYS and keys whose value is None are left as-is.
        Substitution is a single pass: text inserted from one value is never
        scanned again for further placeholders.

    Examples:
        # Basic command
        context={"command": "ls -la"}

        # With working directory
        context={"command": "npm install", "cwd": "/app/frontend"}

        # With environment variables
        context={"command": "./deploy.sh", "env": {"AWS_REGION": "us-east-1"}}

        # With custom shell
        context={"command": "source venv/bin/activate && pytest", "shell": "/bin/bash"}

        # With stdin input
        context={"command": "cat", "stdin": "Hello World"}

        # Allow grep's exit code 1 (no match)
        context={"command": "grep pattern file.txt", "expected_codes": [0, 1]}

        # Mask secrets in logs
        context={
            "command": "curl -H 'Authorization: Bearer {token}' https://api.example.com",
            "token": "secret123",
            "secrets": ["token"],
        }

        # Binary output
        context={"command": "cat image.png", "binary": True}

        # Using upstream output
        context={"command": "echo 'Previous output: {stdout}'"}
    """

    # "{key}" placeholders: a brace pair around a run of non-brace characters.
    _PLACEHOLDER_RE = re.compile(r"\{([^{}]*)\}")

    # prctl(2) option number for PR_SET_PDEATHSIG on Linux.
    _PR_SET_PDEATHSIG = 1

    # libc handle used by _preexec_fn, looked up once in the parent process by
    # _resolve_libc(). _libc_resolved separates "not looked up yet" from
    # "looked up but unavailable" (handle left as None).
    _libc_handle: Any = None
    _libc_resolved: bool = False

    def execute(self, stage: StageExecution) -> TaskResult:
        """Execute the shell command with all configured options.

        Args:
            stage: Stage whose ``context`` holds the command, the options
                documented on the class, and placeholder values.

        Returns:
            ``TaskResult.success`` when the exit code is in ``expected_codes``.
            On a timeout or an unexpected exit code: ``failed_continue`` when
            ``continue_on_failure`` is set, else ``terminal``. Setup/OS errors
            are returned as ``terminal``.

        Raises:
            TransientError: On timeout or unexpected exit code when
                ``restart_on_failure`` is set.
        """
        context = stage.context
        command = context.get("command")
        if not command:
            return TaskResult.terminal(error="No 'command' specified in context")

        # Extract options with defaults. An explicit None for env / secrets /
        # expected_codes behaves like the option being absent instead of
        # failing on iteration or membership tests.
        timeout: int | None = context.get("timeout", 60)
        cwd: str | None = context.get("cwd")
        env: dict[str, Any] = context.get("env") or {}
        shell: bool | str = context.get("shell", True)
        stdin_input: str | bytes | None = context.get("stdin")
        max_output_size: int = context.get("max_output_size", 10 * 1024 * 1024)
        expected_codes = context.get("expected_codes")
        if expected_codes is None:
            expected_codes = [0]
        secrets: list[str] = context.get("secrets") or []
        binary: bool = context.get("binary", False)
        continue_on_failure: bool = context.get("continue_on_failure", False)
        restart_on_failure: bool = context.get("restart_on_failure", False)

        if isinstance(command, str):
            command = self._substitute_placeholders(command, context)

        # Build full environment: inherit ours, then overlay custom values
        # (stringified, because environment values must be strings).
        full_env = os.environ.copy()
        for env_key, env_value in env.items():
            full_env[env_key] = env_value if isinstance(env_value, str) else str(env_value)

        secret_values = self._collect_secret_values(context, secrets)
        log_command = self._mask(command, secret_values) if isinstance(command, str) else command
        logger.debug("ShellTask executing: %s", log_command)

        try:
            popen_kwargs: dict[str, Any] = {
                "stdout": subprocess.PIPE,
                "stderr": subprocess.PIPE,
                "preexec_fn": self._preexec_fn,  # Process isolation + Linux auto-cleanup
            }
            if isinstance(shell, str):
                shell_path, shell_error = self._resolve_shell(shell)
                if shell_error is not None:
                    return TaskResult.terminal(error=shell_error)
                popen_kwargs["shell"] = True
                popen_kwargs["executable"] = shell_path
            else:
                popen_kwargs["shell"] = shell

            # Encode stdin BEFORE spawning, so an encoding failure cannot leave
            # an already-started child process behind.
            stdin_data: bytes | None = None
            if stdin_input is not None:
                popen_kwargs["stdin"] = subprocess.PIPE
                stdin_data = self._encode_stdin(stdin_input)
            if cwd:
                popen_kwargs["cwd"] = cwd
            if full_env:
                popen_kwargs["env"] = full_env

            # Look up libc here, in the parent, so that _preexec_fn does not
            # need to dlopen() inside the forked child.
            self._resolve_libc()

            # Popen (not run) so the whole process group can be killed on timeout.
            proc: subprocess.Popen[bytes] = subprocess.Popen(command, **popen_kwargs)
            try:
                stdout_bytes, stderr_bytes = proc.communicate(input=stdin_data, timeout=timeout)
            except subprocess.TimeoutExpired:
                stdout_bytes, stderr_bytes = self._collect_after_timeout(proc)
                timeout_outputs = self._build_outputs(
                    stdout_bytes,
                    stderr_bytes,
                    -1,
                    max_output_size=max_output_size,
                    binary=binary,
                    secret_values=secret_values,
                    omit_empty=True,
                )
                return self._failure(
                    f"Command timed out after {timeout}s",
                    timeout_outputs,
                    restart_on_failure=restart_on_failure,
                    continue_on_failure=continue_on_failure,
                )

            outputs = self._build_outputs(
                stdout_bytes or b"",
                stderr_bytes or b"",
                proc.returncode,
                max_output_size=max_output_size,
                binary=binary,
                secret_values=secret_values,
            )
            if proc.returncode in expected_codes:
                return TaskResult.success(outputs=outputs)

            error_msg = f"Command exited with code {proc.returncode} (expected: {expected_codes})"
            stderr_preview = outputs.get("stderr")
            if isinstance(stderr_preview, bytes):
                # Binary mode: show masked text rather than an unmasked bytes repr.
                stderr_preview = self._mask(self._decode(stderr_preview), secret_values)
            if stderr_preview:
                error_msg += f": {stderr_preview[:200]}"
            return self._failure(
                error_msg,
                outputs,
                restart_on_failure=restart_on_failure,
                continue_on_failure=continue_on_failure,
            )
        except TransientError:
            # Let TransientError propagate for retry handling
            raise
        # OS error messages can embed the command line (e.g. the missing
        # executable's name when shell=False), so secrets are masked in them.
        except FileNotFoundError as e:
            return TaskResult.terminal(error=f"Command or shell not found: {self._mask(str(e), secret_values)}")
        except PermissionError as e:
            return TaskResult.terminal(error=f"Permission denied: {self._mask(str(e), secret_values)}")
        except OSError as e:
            return TaskResult.terminal(error=f"OS error: {self._mask(str(e), secret_values)}")
        except Exception as e:
            return TaskResult.terminal(error=f"Unexpected error: {self._mask(str(e), secret_values)}")

    @classmethod
    def _substitute_placeholders(cls, command: str, context: Any) -> str:
        """Replace every ``{key}`` in *command* with the shell-quoted context value.

        Done in a single ``re.sub`` pass: a substituted value is never scanned
        again, so a value containing ``{other}`` cannot pull another value
        inside its own quotes and break out of the quoting.

        Args:
            command: Command template.
            context: Mapping supporting ``.get`` (the stage context).

        Returns:
            The command with placeholders replaced. RESERVED_KEYS, unknown keys
            and keys whose value is None are left untouched.
        """

        def replace(match: re.Match[str]) -> str:
            key = match.group(1)
            if key in RESERVED_KEYS:
                return match.group(0)
            value = context.get(key)
            if value is None:
                return match.group(0)
            # Quote everything (not just strings): str() of arbitrary objects
            # may contain shell metacharacters.
            return shlex.quote(value if isinstance(value, str) else str(value))

        return cls._PLACEHOLDER_RE.sub(replace, command)

    @staticmethod
    def _collect_secret_values(context: Any, secret_keys: list[str]) -> list[str]:
        """Return the strings to mask, longest first.

        Empty strings are excluded: ``str.replace("", "***")`` would corrupt the
        whole text. Longest-first prevents partial masking of overlapping
        secrets. When a value contains a single quote, its ``shlex.quote`` form
        (as it appears in the substituted command) no longer contains the raw
        value, so that form is masked too.
        """
        values: list[str] = []
        for secret_key in secret_keys:
            raw = context.get(secret_key)
            if raw is None:
                continue
            text = str(raw)
            if text == "":
                continue
            values.append(text)
            quoted = shlex.quote(text)
            if text not in quoted:
                values.append(quoted)
        values.sort(key=len, reverse=True)
        return values

    @staticmethod
    def _mask(text: str, secret_values: list[str]) -> str:
        """Replace every occurrence of each secret in *text* with ``***``."""
        for secret in secret_values:
            text = text.replace(secret, "***")
        return text

    @staticmethod
    def _decode(data: bytes) -> str:
        """Decode process output as UTF-8 (invalid bytes replaced) and strip it."""
        return data.decode("utf-8", errors="replace").strip()

    @staticmethod
    def _truncate(data: bytes, limit: int) -> tuple[bytes, bool]:
        """Cut *data* to *limit* bytes; also report whether anything was cut."""
        if len(data) > limit:
            return data[:limit], True
        return data, False

    @staticmethod
    def _encode_stdin(stdin_input: Any) -> bytes:
        """Return stdin payload as bytes: bytes-like passes through, else UTF-8."""
        if isinstance(stdin_input, (bytes, bytearray, memoryview)):
            return bytes(stdin_input)
        return str(stdin_input).encode("utf-8")

    @staticmethod
    def _resolve_shell(shell: str) -> tuple[str | None, str | None]:
        """Validate a custom shell path.

        Returns:
            ``(resolved_path, None)`` on success, ``(None, error_message)`` otherwise.
            The absolute-path check runs on the given path: after ``realpath``
            every path is absolute, so checking the resolved path never fires.
        """
        if not os.path.isabs(shell):
            return None, f"Shell must be an absolute path: {shell}"
        real_path = os.path.realpath(shell)
        if not os.path.isfile(real_path):
            return None, f"Shell not found: {shell}"
        if not os.access(real_path, os.X_OK):
            return None, f"Shell not executable: {shell}"
        return real_path, None

    def _build_outputs(
        self,
        stdout_bytes: bytes,
        stderr_bytes: bytes,
        returncode: int | None,
        *,
        max_output_size: int,
        binary: bool,
        secret_values: list[str],
        omit_empty: bool = False,
    ) -> dict[str, Any]:
        """Assemble the task outputs dict from raw process output.

        Each stream is truncated to *max_output_size* bytes (setting
        ``truncated``), then kept as bytes (binary mode) or decoded, stripped
        and secret-masked (text mode). With *omit_empty*, empty streams are
        left out entirely (timeout path).
        """
        outputs: dict[str, Any] = {"returncode": returncode, "truncated": False}
        for name, data in (("stdout", stdout_bytes), ("stderr", stderr_bytes)):
            if omit_empty and not data:
                continue
            data, was_cut = self._truncate(data, max_output_size)
            if was_cut:
                outputs["truncated"] = True
            outputs[name] = data if binary else self._mask(self._decode(data), secret_values)
        if binary and "stdout" in outputs:
            outputs["stdout_b64"] = base64.b64encode(outputs["stdout"]).decode("ascii")
        return outputs

    @staticmethod
    def _failure(
        error_msg: str,
        outputs: dict[str, Any],
        *,
        restart_on_failure: bool,
        continue_on_failure: bool,
    ) -> TaskResult:
        """Turn a failure into a retry, a failed_continue or a terminal result.

        Raises:
            TransientError: When *restart_on_failure* is set; the outputs are
                attached as ``_last_outputs``.
        """
        if restart_on_failure:
            raise TransientError(error_msg, context_update={"_last_outputs": outputs})
        if continue_on_failure:
            return TaskResult.failed_continue(error=error_msg, outputs=outputs)
        return TaskResult.terminal(error=error_msg, context=outputs)

    def _collect_after_timeout(self, proc: subprocess.Popen[Any]) -> tuple[bytes, bytes]:
        """Kill the timed-out process tree and gather the output it produced.

        Returns:
            ``(stdout, stderr)`` as bytes (possibly partial or empty).
        """
        self._kill_process_tree(proc)
        try:
            stdout_bytes, stderr_bytes = proc.communicate(timeout=2)
        except subprocess.TimeoutExpired as exc:
            # The pipes are still held open (e.g. by a descendant that escaped
            # the kill). Keep whatever partial output the exception carries
            # (it may be None), make sure the direct child is gone, and close
            # our pipe ends so they do not leak.
            stdout_bytes = exc.output or b""
            stderr_bytes = exc.stderr or b""
            proc.kill()
            proc.wait()
            for stream in (proc.stdout, proc.stderr):
                if stream is not None:
                    stream.close()
        return stdout_bytes or b"", stderr_bytes or b""

    @classmethod
    def _resolve_libc(cls) -> Any:
        """Look up (once) and return the libc handle used for prctl, or None.

        Only attempted on Linux; a failed lookup is remembered as None.
        """
        if not cls._libc_resolved:
            handle = None
            if sys.platform == "linux":
                try:
                    handle = ctypes.CDLL("libc.so.6", use_errno=True)
                except OSError:
                    handle = None
            # Publish the handle before the flag so readers never see
            # "resolved" with a stale handle.
            cls._libc_handle = handle
            cls._libc_resolved = True
        return cls._libc_handle

    def _preexec_fn(self) -> None:
        """Pre-exec function for subprocess to set up process isolation.

        Runs in the forked child. Creates a new session (process group) and,
        on Linux, sets PR_SET_PDEATHSIG to SIGKILL so the child is killed
        automatically when its parent dies. The libc handle normally comes from
        _resolve_libc() (called in the parent) to avoid dlopen() after fork.

        NOTE(review): per prctl(2), "parent" here is the *thread* that forked
        the child; if that thread exits while the child is still running, the
        child is killed.
        """
        os.setsid()  # Create new session (same as start_new_session=True)

        # Linux-only: auto-kill child when parent dies
        if sys.platform == "linux":
            try:
                if self._libc_resolved:
                    libc = self._libc_handle
                else:
                    libc = ctypes.CDLL("libc.so.6", use_errno=True)
                if libc is not None:
                    libc.prctl(self._PR_SET_PDEATHSIG, signal.SIGKILL)
            except (OSError, AttributeError):
                pass  # Not available, continue without it

    def _kill_process_tree(self, proc: subprocess.Popen[Any]) -> None:
        """Kill process and all its descendants using psutil (cross-platform).

        Sends SIGTERM leaf-to-root, waits up to 5 s, then SIGKILLs survivors.
        Falls back to process-group killing if psutil is not installed or
        reports an error other than "process gone" (e.g. AccessDenied).

        Args:
            proc: The Popen process object
        """
        try:
            import psutil
        except ImportError:
            # psutil not installed, fall back to process group killing
            self._kill_process_group_fallback(proc)
            return

        try:
            parent = psutil.Process(proc.pid)
            children = parent.children(recursive=True)
            targets = [*reversed(children), parent]  # children first, leaf to root

            for target in targets:
                try:
                    target.send_signal(signal.SIGTERM)
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    pass

            # Wait for processes to exit, then SIGKILL survivors
            _, alive = psutil.wait_procs(targets, timeout=5)
            for survivor in alive:
                try:
                    survivor.kill()
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    pass
        except psutil.NoSuchProcess:
            pass  # Process already dead
        except psutil.Error:
            # psutil could not walk/signal the tree; the process group created
            # by _preexec_fn can still be signalled directly.
            self._kill_process_group_fallback(proc)

    def _kill_process_group_fallback(self, proc: subprocess.Popen[Any]) -> None:
        """Kill the entire process group for cleanup on timeout (fallback).

        Uses SIGTERM first, then SIGKILL if the leader has not exited within
        5 seconds. Child processes spawned by the shell command share the
        group and are signalled too. This is the fallback when psutil is not
        available or fails.

        Args:
            proc: The Popen process object
        """
        try:
            # Get the process group ID (same as PID since we used setsid)
            pgid = os.getpgid(proc.pid)

            # Send SIGTERM to the process group
            os.killpg(pgid, signal.SIGTERM)

            # Give processes time to exit gracefully
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Force kill with SIGKILL if still running
                try:
                    os.killpg(pgid, signal.SIGKILL)
                    proc.wait(timeout=2)
                except (ProcessLookupError, OSError):
                    pass  # Process already dead
        except (ProcessLookupError, OSError):
            # Process or process group already dead
            pass