diff --git a/pyproject.toml b/pyproject.toml index f3446f3..7d3efc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ name = "weco" authors = [{ name = "Weco AI Team", email = "contact@weco.ai" }] description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer." readme = "README.md" -version = "0.3.30" +version = "0.3.31" license = { file = "LICENSE" } requires-python = ">=3.9" dependencies = [ diff --git a/tests/test_auto_resume.py b/tests/test_auto_resume.py new file mode 100644 index 0000000..d83c025 --- /dev/null +++ b/tests/test_auto_resume.py @@ -0,0 +1,220 @@ +from unittest.mock import MagicMock + +import pytest + +from weco.optimizer import AutoResumePolicy, OptimizationResult, _is_transient, _run_loop_with_auto_resume + + +class FakeUI: + def __init__(self): + self.reconnecting: list[tuple[int, int, float]] = [] + self.reconnected: int = 0 + self.errors: list[str] = [] + self.warnings: list[str] = [] + + def on_reconnecting(self, attempt: int, max_attempts: int, backoff_s: float) -> None: + self.reconnecting.append((attempt, max_attempts, backoff_s)) + + def on_reconnected(self) -> None: + self.reconnected += 1 + + def on_error(self, message: str) -> None: + self.errors.append(message) + + def on_warning(self, message: str) -> None: + self.warnings.append(message) + + +def _make_result(reason: str, *, status: str = "error", success: bool = False, final_step: int = 0) -> OptimizationResult: + return OptimizationResult(success=success, final_step=final_step, status=status, reason=reason) + + +@pytest.fixture +def stub_sleep(monkeypatch): + sleeps: list[float] = [] + monkeypatch.setattr("weco.optimizer.time.sleep", lambda s: sleeps.append(s)) + return sleeps + + +@pytest.fixture +def stub_resume(monkeypatch): + calls: list[str] = [] + + def _install(outcomes: list[bool]): + iterator = iter(outcomes) + + def _fake(run_id, auth_headers, api_keys): + calls.append(run_id) + return next(iterator) + + monkeypatch.setattr("weco.optimizer._silent_resume", _fake) + return calls + + return _install + + +def _drive(results: list[OptimizationResult], *, policy: AutoResumePolicy | None = None, initial_start_step: int = 0): + factory = MagicMock(side_effect=results) + ui = FakeUI() + returned = _run_loop_with_auto_resume( + factory, + ui=ui, + run_id="run-1", + auth_headers={}, + api_keys=None, + policy=policy or AutoResumePolicy(), + initial_start_step=initial_start_step, + ) + return returned, factory, ui + + +@pytest.mark.parametrize( + "reason,transient", + [ + ("transient_network_error", True), + ("http_502", True), + ("http_503", True), + ("http_504", True), + ("http_401", False), + ("http_402", False), + ("http_500", False), + ("user_terminated_sigint", False), + ("completed_successfully", False), + ("user_requested_stop", False), + ("timeout_waiting_for_tasks", False), + ("unknown", False), + ], +) +def test_classifies_transient_reasons(reason, transient): + assert _is_transient(_make_result(reason)) is transient + + +def test_returns_verbatim_when_not_transient(stub_sleep, stub_resume): + stub_resume([]) + completed = _make_result("completed_successfully", status="completed", success=True, final_step=7) + + returned, factory, ui = _drive([completed]) + + assert returned is completed + assert factory.call_count == 1 + assert stub_sleep == [] + assert ui.reconnecting == [] + assert ui.reconnected == 0 + assert ui.errors == [] + + +def test_resumes_once_then_continues_from_final_step(stub_sleep, stub_resume): + resume_calls = stub_resume([True]) + transient = _make_result("transient_network_error", final_step=4) + completed = _make_result("completed_successfully", status="completed", success=True, final_step=9) + + returned, factory, ui = _drive([transient, completed]) + + assert returned is completed + assert factory.call_count == 2 + assert factory.call_args_list[0].args == (0,) + assert factory.call_args_list[1].args == (4,) + assert resume_calls == ["run-1"] + assert len(stub_sleep) == 1 + assert len(ui.reconnecting) == 1 + assert ui.reconnected == 1 + assert ui.errors == [] + + +def test_exhausts_after_max_attempts_and_returns_original_result(stub_sleep, stub_resume): + resume_calls = stub_resume([True, True, True]) + transient = _make_result("transient_network_error", final_step=2) + policy = AutoResumePolicy(max_attempts=3) + + returned, factory, ui = _drive([transient, transient, transient, transient], policy=policy) + + assert returned.reason == "transient_network_error" + assert factory.call_count == 4 + assert len(resume_calls) == 3 + assert len(ui.reconnecting) == 3 + assert ui.reconnected == 3 + assert len(ui.errors) == 1 + assert "exhausted after 3" in ui.errors[0] + + +def test_disabled_policy_skips_resume_on_transient(stub_sleep, stub_resume): + resume_calls = stub_resume([]) + transient = _make_result("transient_network_error", final_step=2) + + returned, factory, ui = _drive([transient], policy=AutoResumePolicy(enabled=False)) + + assert returned is transient + assert factory.call_count == 1 + assert resume_calls == [] + assert stub_sleep == [] + assert ui.reconnecting == [] + assert ui.reconnected == 0 + assert ui.errors == [] + + +def test_silent_resume_failure_retries_without_reinvoking_loop(stub_sleep, stub_resume): + resume_calls = stub_resume([False, True]) + transient = _make_result("transient_network_error", final_step=3) + completed = _make_result("completed_successfully", status="completed", success=True) + + returned, factory, ui = _drive([transient, completed], policy=AutoResumePolicy(max_attempts=3)) + + assert returned is completed + assert factory.call_count == 2 + assert len(resume_calls) == 2 + assert len(stub_sleep) == 2 + assert len(ui.reconnecting) == 2 + assert ui.reconnected == 1 + assert ui.errors == [] + + +def test_silent_resume_exhaustion_without_reinvoking_loop(stub_sleep, stub_resume): + resume_calls = stub_resume([False, False]) + transient = _make_result("transient_network_error", final_step=3) + + returned, factory, ui = _drive([transient], policy=AutoResumePolicy(max_attempts=2)) + + assert returned is transient + assert factory.call_count == 1 + assert len(resume_calls) == 2 + assert len(ui.reconnecting) == 2 + assert ui.reconnected == 0 + assert len(ui.errors) == 1 + assert "exhausted after 2" in ui.errors[0] + + +def test_backoff_is_exponential_and_capped(stub_sleep, stub_resume): + stub_resume([True, True, True, True, True]) + transient = _make_result("transient_network_error") + completed = _make_result("completed_successfully", status="completed", success=True) + policy = AutoResumePolicy(max_attempts=5, backoff_initial_s=1.0, backoff_factor=2.0, backoff_max_s=5.0) + + _drive([transient, transient, transient, transient, completed], policy=policy) + + assert stub_sleep == [1.0, 2.0, 4.0, 5.0] + + +def test_keyboard_interrupt_result_propagates_untouched(stub_sleep, stub_resume): + resume_calls = stub_resume([]) + interrupted = _make_result("user_terminated_sigint", status="terminated") + + returned, factory, ui = _drive([interrupted]) + + assert returned is interrupted + assert factory.call_count == 1 + assert resume_calls == [] + assert stub_sleep == [] + assert ui.reconnecting == [] + assert ui.reconnected == 0 + assert ui.errors == [] + + +def test_reconnecting_event_carries_attempt_and_backoff(stub_sleep, stub_resume): + stub_resume([True]) + transient = _make_result("transient_network_error", final_step=2) + completed = _make_result("completed_successfully", status="completed", success=True) + policy = AutoResumePolicy(max_attempts=5, backoff_initial_s=3.0, backoff_factor=2.0, backoff_max_s=30.0) + + _, _, ui = _drive([transient, completed], policy=policy) + + assert ui.reconnecting == [(1, 5, 3.0)] diff --git a/weco/cli.py b/weco/cli.py index 4046be6..722d808 100644 --- a/weco/cli.py +++ b/weco/cli.py @@ -165,6 +165,17 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None: help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.", ) run_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS) + run_parser.add_argument( + "--no-auto-resume", + action="store_true", + help="Disable automatic reconnection/resume on transient network errors (default: enabled).", + ) + run_parser.add_argument( + "--auto-resume-max-attempts", + type=int, + default=5, + help="Max auto-resume attempts before giving up and printing the manual resume command (default: 5).", + ) # --- Eval backend integration --- run_parser.add_argument( @@ -370,6 +381,17 @@ def configure_resume_parser(resume_parser: argparse.ArgumentParser) -> None: help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.", ) resume_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS) + resume_parser.add_argument( + "--no-auto-resume", + action="store_true", + help="Disable automatic reconnection/resume on transient network errors (default: enabled).", + ) + resume_parser.add_argument( + "--auto-resume-max-attempts", + type=int, + default=5, + help="Max auto-resume attempts before giving up and printing the manual resume command (default: 5).", + ) def _dispatch_run_subcommand(sub: str, args: argparse.Namespace) -> None: @@ -431,7 +453,7 @@ def _collect_source_paths() -> list[str] | None: def execute_run_command(args: argparse.Namespace) -> None: """Execute the 'weco run' command with all its logic.""" - from .optimizer import optimize + from .optimizer import AutoResumePolicy, optimize ctx = get_event_context() @@ -505,6 +527,10 @@ def execute_run_command(args: argparse.Namespace) -> None: ctx, ) + auto_resume_policy = AutoResumePolicy( + enabled=not getattr(args, "no_auto_resume", False), max_attempts=getattr(args, "auto_resume_max_attempts", 5) + ) + success = optimize( source=source_arg, eval_command=args.eval_command, @@ -521,6 +547,7 @@ def execute_run_command(args: argparse.Namespace) -> None: require_review=args.require_review, output_mode=args.output, submit_timeout=getattr(args, "submit_timeout", None), + auto_resume_policy=auto_resume_policy, ) exit_code = 0 if success else 1 @@ -529,7 +556,7 @@ def execute_run_command(args: argparse.Namespace) -> None: def execute_resume_command(args: argparse.Namespace) -> None: """Execute the 'weco resume' command with all its logic.""" - from .optimizer import resume_optimization + from .optimizer import AutoResumePolicy, resume_optimization try: api_keys = parse_api_keys(args.api_key) @@ -537,12 +564,17 @@ def execute_resume_command(args: argparse.Namespace) -> None: console.print(f"[bold red]Error parsing API keys: {e}[/]") sys.exit(1) + auto_resume_policy = AutoResumePolicy( + enabled=not getattr(args, "no_auto_resume", False), max_attempts=getattr(args, "auto_resume_max_attempts", 5) + ) + success = resume_optimization( run_id=args.run_id, api_keys=api_keys, apply_change=args.apply_change, output_mode=args.output, submit_timeout=getattr(args, "submit_timeout", None), + auto_resume_policy=auto_resume_policy, ) sys.exit(0 if success else 1) diff --git a/weco/optimizer.py b/weco/optimizer.py index a24586d..a63ab35 100644 --- a/weco/optimizer.py +++ b/weco/optimizer.py @@ -5,9 +5,9 @@ import time import traceback from dataclasses import dataclass -from typing import Optional +from typing import Callable, Optional -from requests.exceptions import HTTPError +from requests.exceptions import ConnectionError as RequestsConnectionError, HTTPError, ReadTimeout from rich.console import Console from rich.prompt import Confirm @@ -23,6 +23,7 @@ start_optimization_run, submit_execution_result, ) +from .core.api import WecoClient from .artifacts import RunArtifacts from .auth import handle_authentication from .events import get_event_context @@ -42,6 +43,85 @@ class OptimizationResult: details: Optional[str] = None +@dataclass +class AutoResumePolicy: + """Policy for auto-resuming a run after transient errors.""" + + enabled: bool = True + max_attempts: int = 5 + backoff_initial_s: float = 5.0 + backoff_max_s: float = 60.0 + backoff_factor: float = 2.0 + + +# Reasons produced by run_optimization_loop that indicate a retryable failure. +# 5xx bursts imply layer-2 recovery already tried and gave up; waiting and +# resuming the run is the right response. 4xx (auth, validation, insufficient +# credits) and user-driven terminations must propagate. +_TRANSIENT_REASONS = frozenset({"transient_network_error", "http_502", "http_503", "http_504"}) + + +def _is_transient(result: OptimizationResult) -> bool: + return result.reason in _TRANSIENT_REASONS + + +def _silent_resume(run_id: str, auth_headers: dict, api_keys: Optional[dict]) -> bool: + """Flip a run back to 'running' without emitting any console output.""" + try: + WecoClient(auth_headers).resume_run(run_id, api_keys=api_keys) + return True + except Exception: + return False + + +def _run_loop_with_auto_resume( + loop_factory: Callable[[int], OptimizationResult], + *, + ui: "OptimizationUI", + run_id: str, + auth_headers: dict, + api_keys: Optional[dict], + policy: AutoResumePolicy, + initial_start_step: int, +) -> OptimizationResult: + """Invoke the optimization loop; on transient failure, resume and re-enter. + + ``loop_factory(start_step)`` runs one attempt of ``run_optimization_loop``. + If the attempt exits with a transient reason and auto-resume is enabled, + this sleeps with exponential backoff, calls the backend resume endpoint, + and re-enters. Non-transient outcomes (completed, user interrupt, HTTP 4xx) + are returned verbatim. + """ + start_step = initial_start_step + attempts_used = 0 + + while True: + result = loop_factory(start_step) + + if not policy.enabled or not _is_transient(result): + return result + + resumed = False + while attempts_used < policy.max_attempts: + attempts_used += 1 + backoff = min(policy.backoff_initial_s * (policy.backoff_factor ** (attempts_used - 1)), policy.backoff_max_s) + ui.on_reconnecting(attempts_used, policy.max_attempts, backoff) + time.sleep(backoff) + + if _silent_resume(run_id, auth_headers, api_keys): + ui.on_reconnected() + start_step = result.final_step + resumed = True + break + + if not resumed: + ui.on_error( + f"Auto-resume exhausted after {policy.max_attempts} attempt(s). " + f"Use 'weco resume {run_id}' to continue manually." + ) + return result + + # --- Heartbeat Sender Class --- class HeartbeatSender(threading.Thread): def __init__(self, run_id: str, auth_headers: dict, stop_event: threading.Event, interval: int = 30): @@ -224,6 +304,13 @@ def run_optimization_loop( except KeyboardInterrupt: ui.on_interrupted() return OptimizationResult(success=False, final_step=step, status="terminated", reason="user_terminated_sigint") + except (RequestsConnectionError, ReadTimeout) as e: + # Tagged separately so the outer auto-resume wrapper can distinguish + # transport failures from unrecoverable errors. + ui.on_warning(f"Network error during optimization: {e}") + return OptimizationResult( + success=False, final_step=step, status="error", reason="transient_network_error", details=str(e) + ) except HTTPError as e: # Surface structured API error details (insufficient credits, auth failures, candidate # generation failures, etc.) through the UI rather than a generic exception string. @@ -318,6 +405,7 @@ def resume_optimization( apply_change: bool = False, output_mode: str = "rich", submit_timeout: Optional[int] = None, + auto_resume_policy: Optional[AutoResumePolicy] = None, ) -> bool: """ Resume an interrupted run using the queue-based optimization loop. @@ -451,19 +539,30 @@ def resume_optimization( if best_metric_value is not None and best_step is not None: ui.on_metric(best_step, best_metric_value) - result = run_optimization_loop( + def _loop(start_step: int) -> OptimizationResult: + return run_optimization_loop( + ui=ui, + run_id=run_id, + auth_headers=auth_headers, + source_code=source_code, + eval_command=eval_command, + eval_timeout=eval_timeout, + artifacts=artifacts, + save_logs=save_logs, + start_step=start_step, + poll_interval=poll_interval, + api_keys=api_keys, + submit_timeout=submit_timeout, + ) + + result = _run_loop_with_auto_resume( + _loop, ui=ui, run_id=run_id, auth_headers=auth_headers, - source_code=source_code, - eval_command=eval_command, - eval_timeout=eval_timeout, - artifacts=artifacts, - save_logs=save_logs, - start_step=current_step, - poll_interval=poll_interval, api_keys=api_keys, - submit_timeout=submit_timeout, + policy=auto_resume_policy or AutoResumePolicy(), + initial_start_step=current_step, ) # Stop heartbeat immediately after loop completes @@ -524,6 +623,7 @@ def optimize( require_review: bool = False, output_mode: str = "rich", submit_timeout: Optional[int] = None, + auto_resume_policy: Optional[AutoResumePolicy] = None, ) -> bool: """ Simplified queue-based optimization loop. @@ -638,19 +738,31 @@ def optimize( with ui_instance as ui: ui.on_init() - result = run_optimization_loop( + + def _loop(start_step: int) -> OptimizationResult: + return run_optimization_loop( + ui=ui, + run_id=run_id, + auth_headers=auth_headers, + source_code=source_code, + eval_command=eval_command, + eval_timeout=eval_timeout, + artifacts=artifacts, + save_logs=save_logs, + start_step=start_step, + poll_interval=poll_interval, + api_keys=api_keys, + submit_timeout=submit_timeout, + ) + + result = _run_loop_with_auto_resume( + _loop, ui=ui, run_id=run_id, auth_headers=auth_headers, - source_code=source_code, - eval_command=eval_command, - eval_timeout=eval_timeout, - artifacts=artifacts, - save_logs=save_logs, - start_step=0, - poll_interval=poll_interval, api_keys=api_keys, - submit_timeout=submit_timeout, + policy=auto_resume_policy or AutoResumePolicy(), + initial_start_step=0, ) # Stop heartbeat immediately after loop completes diff --git a/weco/ui/base.py b/weco/ui/base.py index c92579c..659a9b0 100644 --- a/weco/ui/base.py +++ b/weco/ui/base.py @@ -60,6 +60,14 @@ def on_error(self, message: str) -> None: """Called for errors.""" ... + def on_reconnecting(self, attempt: int, max_attempts: int, backoff_s: float) -> None: + """Called when the client is waiting to auto-resume after a transient error.""" + ... + + def on_reconnected(self) -> None: + """Called after a successful auto-resume, before the loop resumes polling.""" + ... + @dataclass class UIState: @@ -67,8 +75,11 @@ class UIState: step: int = 0 total_steps: int = 0 - status: str = "initializing" # polling, executing, submitting, complete, stopped, error + status: str = "initializing" # polling, executing, submitting, reconnecting, complete, stopped, error plan_preview: str = "" output_preview: str = "" metrics: List[tuple] = field(default_factory=list) # (step, value) error: Optional[str] = None + reconnect_attempt: int = 0 + reconnect_max_attempts: int = 0 + reconnect_backoff_s: float = 0.0 diff --git a/weco/ui/live.py b/weco/ui/live.py index 14eaaae..b18a026 100644 --- a/weco/ui/live.py +++ b/weco/ui/live.py @@ -27,12 +27,13 @@ class LiveOptimizationUI: SPARKLINE_CHARS = "▁▂▃▄▅▆▇█" SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] # Statuses that show the spinner animation - ACTIVE_STATUSES = {"initializing", "polling", "executing", "submitting"} + ACTIVE_STATUSES = {"initializing", "polling", "executing", "submitting", "reconnecting"} STATUS_INDICATORS = { "initializing": ("⏳", "dim"), "polling": ("🔄", "cyan"), "executing": ("⚡", "yellow"), "submitting": ("🧠", "blue"), + "reconnecting": ("📡", "yellow"), "complete": ("✅", "green"), "stopped": ("⏹", "yellow"), "interrupted": ("⚠", "yellow"), @@ -120,6 +121,12 @@ def _render(self) -> Group: status_text = Text() status_text.append(f"{emoji} ", style=style) status_text.append(self.state.status.replace("_", " ").title(), style=f"bold {style}") + if self.state.status == "reconnecting" and self.state.reconnect_max_attempts > 0: + status_text.append( + f" (attempt {self.state.reconnect_attempt}/{self.state.reconnect_max_attempts}" + f", retry in {self.state.reconnect_backoff_s:.0f}s)", + style=f"bold {style}", + ) if self.state.status in self.ACTIVE_STATUSES: # Time-based frame calculation: ~10 fps spinner animation frame = int(time.time() * 10) % len(self.SPINNER_FRAMES) @@ -261,3 +268,19 @@ def on_error(self, message: str) -> None: self.state.error = message self.state.status = "error" self._update() + + def on_reconnecting(self, attempt: int, max_attempts: int, backoff_s: float) -> None: + self.state.status = "reconnecting" + self.state.reconnect_attempt = attempt + self.state.reconnect_max_attempts = max_attempts + self.state.reconnect_backoff_s = backoff_s + self._update() + + def on_reconnected(self) -> None: + self.state.reconnect_attempt = 0 + self.state.reconnect_max_attempts = 0 + self.state.reconnect_backoff_s = 0.0 + # Status will be overwritten by the next on_polling call; clear here so + # any error/status rendering in between reads a clean slate. + self.state.status = "polling" + self._update() diff --git a/weco/ui/plain.py b/weco/ui/plain.py index f356ade..1709fc6 100644 --- a/weco/ui/plain.py +++ b/weco/ui/plain.py @@ -128,3 +128,9 @@ def on_warning(self, message: str) -> None: def on_error(self, message: str) -> None: self._print(f"[ERROR] {message}") + + def on_reconnecting(self, attempt: int, max_attempts: int, backoff_s: float) -> None: + self._print(f"[RECONNECTING] attempt {attempt}/{max_attempts}, retry in {backoff_s:.0f}s") + + def on_reconnected(self) -> None: + self._print("[RECONNECTED] Auto-resume succeeded; continuing optimization.")