Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "weco"
authors = [{ name = "Weco AI Team", email = "contact@weco.ai" }]
description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer."
readme = "README.md"
version = "0.3.30"
version = "0.3.31"
license = { file = "LICENSE" }
requires-python = ">=3.9"
dependencies = [
Expand Down
220 changes: 220 additions & 0 deletions tests/test_auto_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from unittest.mock import MagicMock

import pytest

from weco.optimizer import AutoResumePolicy, OptimizationResult, _is_transient, _run_loop_with_auto_resume


class FakeUI:
    """In-memory UI double that records every callback it receives.

    Attributes are plain containers so tests can compare them directly:
    ``reconnecting`` holds ``(attempt, max_attempts, backoff_s)`` tuples,
    ``reconnected`` counts ``on_reconnected`` calls, and ``errors`` /
    ``warnings`` collect the messages passed to the matching callbacks.
    """

    def __init__(self):
        self.errors: list[str] = []
        self.warnings: list[str] = []
        self.reconnected: int = 0
        self.reconnecting: list[tuple[int, int, float]] = []

    def on_reconnecting(self, attempt: int, max_attempts: int, backoff_s: float) -> None:
        """Record one reconnect attempt as an ``(attempt, max_attempts, backoff_s)`` tuple."""
        event = (attempt, max_attempts, backoff_s)
        self.reconnecting.append(event)

    def on_reconnected(self) -> None:
        """Count one successful reconnection."""
        self.reconnected = self.reconnected + 1

    def on_error(self, message: str) -> None:
        """Store an error message in arrival order."""
        self.errors.append(message)

    def on_warning(self, message: str) -> None:
        """Store a warning message in arrival order."""
        self.warnings.append(message)


def _make_result(reason: str, *, status: str = "error", success: bool = False, final_step: int = 0) -> OptimizationResult:
    """Build an ``OptimizationResult`` for tests; defaults describe a failed run at step 0."""
    result = OptimizationResult(
        reason=reason,
        status=status,
        success=success,
        final_step=final_step,
    )
    return result


@pytest.fixture
def stub_sleep(monkeypatch):
    """Replace ``weco.optimizer.time.sleep`` with a recorder and return the recorded durations."""
    recorded: list[float] = []

    def _record(seconds):
        # Capture the requested delay instead of actually sleeping.
        recorded.append(seconds)

    monkeypatch.setattr("weco.optimizer.time.sleep", _record)
    return recorded


@pytest.fixture
def stub_resume(monkeypatch):
    """Return an installer that patches ``weco.optimizer._silent_resume`` with scripted outcomes.

    Calling the installer with a list of booleans swaps in a fake that
    returns those values one per call, and hands back the list in which
    the fake logs every ``run_id`` it was invoked with.
    """
    seen_run_ids: list[str] = []

    def _install(outcomes: list[bool]):
        pending = iter(outcomes)

        def _fake_resume(run_id, auth_headers, api_keys):
            seen_run_ids.append(run_id)
            # Each call consumes the next scripted outcome.
            return next(pending)

        monkeypatch.setattr("weco.optimizer._silent_resume", _fake_resume)
        return seen_run_ids

    return _install


def _drive(results: list[OptimizationResult], *, policy: "AutoResumePolicy | None" = None, initial_start_step: int = 0):
    """Run ``_run_loop_with_auto_resume`` against a scripted loop factory.

    Args:
        results: Values the mocked factory returns, one per invocation, in order.
        policy: Auto-resume policy to use; a default ``AutoResumePolicy()`` when ``None``.
        initial_start_step: Forwarded unchanged to ``_run_loop_with_auto_resume``.

    Returns:
        A ``(returned, factory, ui)`` tuple: the value returned by
        ``_run_loop_with_auto_resume``, the ``MagicMock`` factory (for call
        inspection) and the ``FakeUI`` that recorded the callbacks.
    """
    # NOTE: the union annotation above is quoted on purpose. Unquoted, `X | None`
    # is evaluated when the function is defined and raises TypeError on Python 3.9,
    # which the project still supports.
    factory = MagicMock(side_effect=results)
    ui = FakeUI()
    effective_policy = policy if policy is not None else AutoResumePolicy()
    returned = _run_loop_with_auto_resume(
        factory,
        ui=ui,
        run_id="run-1",
        auth_headers={},
        api_keys=None,
        policy=effective_policy,
        initial_start_step=initial_start_step,
    )
    return returned, factory, ui


@pytest.mark.parametrize(
    "reason,transient",
    [
        ("transient_network_error", True),
        ("http_502", True),
        ("http_503", True),
        ("http_504", True),
        ("http_401", False),
        ("http_402", False),
        ("http_500", False),
        ("user_terminated_sigint", False),
        ("completed_successfully", False),
        ("user_requested_stop", False),
        ("timeout_waiting_for_tasks", False),
        ("unknown", False),
    ],
)
def test_classifies_transient_reasons(reason, transient):
    """``_is_transient`` returns exactly the expected boolean for each reason string."""
    result = _make_result(reason)
    verdict = _is_transient(result)
    assert verdict is transient


def test_returns_verbatim_when_not_transient(stub_sleep, stub_resume):
    """A completed result comes back as the same object, with no sleep or UI activity."""
    stub_resume([])
    finished = _make_result("completed_successfully", status="completed", success=True, final_step=7)

    outcome, loop_factory, fake_ui = _drive([finished])

    assert outcome is finished
    assert loop_factory.call_count == 1
    assert stub_sleep == []
    assert fake_ui.reconnecting == []
    assert fake_ui.reconnected == 0
    assert fake_ui.errors == []


def test_resumes_once_then_continues_from_final_step(stub_sleep, stub_resume):
    """After one transient failure the loop is re-invoked with the failed run's ``final_step``."""
    resumed_ids = stub_resume([True])
    dropped = _make_result("transient_network_error", final_step=4)
    finished = _make_result("completed_successfully", status="completed", success=True, final_step=9)

    outcome, loop_factory, fake_ui = _drive([dropped, finished])

    assert outcome is finished
    assert loop_factory.call_count == 2
    first_call, second_call = loop_factory.call_args_list
    assert first_call.args == (0,)
    assert second_call.args == (4,)
    assert resumed_ids == ["run-1"]
    assert len(stub_sleep) == 1
    assert len(fake_ui.reconnecting) == 1
    assert fake_ui.reconnected == 1
    assert fake_ui.errors == []


def test_exhausts_after_max_attempts_and_returns_original_result(stub_sleep, stub_resume):
    """With ``max_attempts=3`` and only transient results, the loop runs 4 times and reports one error."""
    resumed_ids = stub_resume([True] * 3)
    dropped = _make_result("transient_network_error", final_step=2)
    three_tries = AutoResumePolicy(max_attempts=3)

    outcome, loop_factory, fake_ui = _drive([dropped] * 4, policy=three_tries)

    assert outcome.reason == "transient_network_error"
    assert loop_factory.call_count == 4
    assert len(resumed_ids) == 3
    assert len(fake_ui.reconnecting) == 3
    assert fake_ui.reconnected == 3
    assert len(fake_ui.errors) == 1
    (only_error,) = fake_ui.errors
    assert "exhausted after 3" in only_error


def test_disabled_policy_skips_resume_on_transient(stub_sleep, stub_resume):
    """With ``enabled=False`` a transient result is returned untouched and nothing is retried."""
    resumed_ids = stub_resume([])
    dropped = _make_result("transient_network_error", final_step=2)
    disabled = AutoResumePolicy(enabled=False)

    outcome, loop_factory, fake_ui = _drive([dropped], policy=disabled)

    assert outcome is dropped
    assert loop_factory.call_count == 1
    assert resumed_ids == []
    assert stub_sleep == []
    assert fake_ui.reconnecting == []
    assert fake_ui.reconnected == 0
    assert fake_ui.errors == []


def test_silent_resume_failure_retries_without_reinvoking_loop(stub_sleep, stub_resume):
    """A failed silent resume is retried; the loop factory is only called again after a success."""
    resumed_ids = stub_resume([False, True])
    dropped = _make_result("transient_network_error", final_step=3)
    finished = _make_result("completed_successfully", status="completed", success=True)
    three_tries = AutoResumePolicy(max_attempts=3)

    outcome, loop_factory, fake_ui = _drive([dropped, finished], policy=three_tries)

    assert outcome is finished
    assert loop_factory.call_count == 2
    assert len(resumed_ids) == 2
    assert len(stub_sleep) == 2
    assert len(fake_ui.reconnecting) == 2
    assert fake_ui.reconnected == 1
    assert fake_ui.errors == []


def test_silent_resume_exhaustion_without_reinvoking_loop(stub_sleep, stub_resume):
    """When every silent resume fails, the original result is returned after ``max_attempts`` tries."""
    resumed_ids = stub_resume([False] * 2)
    dropped = _make_result("transient_network_error", final_step=3)
    two_tries = AutoResumePolicy(max_attempts=2)

    outcome, loop_factory, fake_ui = _drive([dropped], policy=two_tries)

    assert outcome is dropped
    assert loop_factory.call_count == 1
    assert len(resumed_ids) == 2
    assert len(fake_ui.reconnecting) == 2
    assert fake_ui.reconnected == 0
    assert len(fake_ui.errors) == 1
    (only_error,) = fake_ui.errors
    assert "exhausted after 2" in only_error


def test_backoff_is_exponential_and_capped(stub_sleep, stub_resume):
    """Sleep durations double from ``backoff_initial_s`` and are clamped at ``backoff_max_s``."""
    stub_resume([True] * 5)
    dropped = _make_result("transient_network_error")
    finished = _make_result("completed_successfully", status="completed", success=True)
    capped_policy = AutoResumePolicy(
        max_attempts=5,
        backoff_initial_s=1.0,
        backoff_factor=2.0,
        backoff_max_s=5.0,
    )

    # Four transient failures followed by success -> four backoff sleeps.
    scripted = [dropped] * 4 + [finished]
    _drive(scripted, policy=capped_policy)

    assert stub_sleep == [1.0, 2.0, 4.0, 5.0]


def test_keyboard_interrupt_result_propagates_untouched(stub_sleep, stub_resume):
    """A ``user_terminated_sigint`` result is returned as-is without any resume attempt."""
    resumed_ids = stub_resume([])
    sigint_result = _make_result("user_terminated_sigint", status="terminated")

    outcome, loop_factory, fake_ui = _drive([sigint_result])

    assert outcome is sigint_result
    assert loop_factory.call_count == 1
    assert resumed_ids == []
    assert stub_sleep == []
    assert fake_ui.reconnecting == []
    assert fake_ui.reconnected == 0
    assert fake_ui.errors == []


def test_reconnecting_event_carries_attempt_and_backoff(stub_sleep, stub_resume):
    """The UI receives ``(attempt, max_attempts, backoff_s)`` for the first reconnect."""
    stub_resume([True])
    dropped = _make_result("transient_network_error", final_step=2)
    finished = _make_result("completed_successfully", status="completed", success=True)
    slow_policy = AutoResumePolicy(
        max_attempts=5,
        backoff_initial_s=3.0,
        backoff_factor=2.0,
        backoff_max_s=30.0,
    )

    outcome = _drive([dropped, finished], policy=slow_policy)
    fake_ui = outcome[2]

    assert fake_ui.reconnecting == [(1, 5, 3.0)]
36 changes: 34 additions & 2 deletions weco/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.",
)
run_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS)
run_parser.add_argument(
"--no-auto-resume",
action="store_true",
help="Disable automatic reconnection/resume on transient network errors (default: enabled).",
)
run_parser.add_argument(
"--auto-resume-max-attempts",
type=int,
default=5,
help="Max auto-resume attempts before giving up and printing the manual resume command (default: 5).",
)

# --- Eval backend integration ---
run_parser.add_argument(
Expand Down Expand Up @@ -370,6 +381,17 @@ def configure_resume_parser(resume_parser: argparse.ArgumentParser) -> None:
help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.",
)
resume_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS)
resume_parser.add_argument(
"--no-auto-resume",
action="store_true",
help="Disable automatic reconnection/resume on transient network errors (default: enabled).",
)
resume_parser.add_argument(
"--auto-resume-max-attempts",
type=int,
default=5,
help="Max auto-resume attempts before giving up and printing the manual resume command (default: 5).",
)


def _dispatch_run_subcommand(sub: str, args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -431,7 +453,7 @@ def _collect_source_paths() -> list[str] | None:

def execute_run_command(args: argparse.Namespace) -> None:
"""Execute the 'weco run' command with all its logic."""
from .optimizer import optimize
from .optimizer import AutoResumePolicy, optimize

ctx = get_event_context()

Expand Down Expand Up @@ -505,6 +527,10 @@ def execute_run_command(args: argparse.Namespace) -> None:
ctx,
)

auto_resume_policy = AutoResumePolicy(
enabled=not getattr(args, "no_auto_resume", False), max_attempts=getattr(args, "auto_resume_max_attempts", 5)
)

success = optimize(
source=source_arg,
eval_command=args.eval_command,
Expand All @@ -521,6 +547,7 @@ def execute_run_command(args: argparse.Namespace) -> None:
require_review=args.require_review,
output_mode=args.output,
submit_timeout=getattr(args, "submit_timeout", None),
auto_resume_policy=auto_resume_policy,
)

exit_code = 0 if success else 1
Expand All @@ -529,20 +556,25 @@ def execute_run_command(args: argparse.Namespace) -> None:

def execute_resume_command(args: argparse.Namespace) -> None:
    """Execute the 'weco resume' command with all its logic.

    Parses the API keys from ``args.api_key``, builds the auto-resume policy
    from the CLI flags, calls ``resume_optimization`` and exits the process
    with status 0 on success or 1 on failure (including API-key parse errors).
    """
    # Imported inside the function, as in the original; the superseded
    # single-name import of `resume_optimization` is dropped as redundant.
    from .optimizer import AutoResumePolicy, resume_optimization

    try:
        api_keys = parse_api_keys(args.api_key)
    except ValueError as e:
        console.print(f"[bold red]Error parsing API keys: {e}[/]")
        sys.exit(1)

    # getattr with defaults keeps this working for Namespace objects that
    # were built without the auto-resume options.
    auto_resume_policy = AutoResumePolicy(
        enabled=not getattr(args, "no_auto_resume", False),
        max_attempts=getattr(args, "auto_resume_max_attempts", 5),
    )

    success = resume_optimization(
        run_id=args.run_id,
        api_keys=api_keys,
        apply_change=args.apply_change,
        output_mode=args.output,
        submit_timeout=getattr(args, "submit_timeout", None),
        auto_resume_policy=auto_resume_policy,
    )

    sys.exit(0 if success else 1)
Expand Down
Loading
Loading