Skip to content

Commit 6a4dc2d

Browse files
authored
Merge pull request #142 from WecoAI/dev
Add auto-resume feature and bump version (0.3.31)
2 parents 491dd31 + 6b3ee46 commit 6a4dc2d

7 files changed

Lines changed: 429 additions & 25 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ name = "weco"
88
authors = [{ name = "Weco AI Team", email = "contact@weco.ai" }]
99
description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer."
1010
readme = "README.md"
11-
version = "0.3.30"
11+
version = "0.3.31"
1212
license = { file = "LICENSE" }
1313
requires-python = ">=3.9"
1414
dependencies = [

tests/test_auto_resume.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from weco.optimizer import AutoResumePolicy, OptimizationResult, _is_transient, _run_loop_with_auto_resume
6+
7+
8+
class FakeUI:
    """Recording stand-in for the optimizer UI.

    Each callback stores what it was given so a test can inspect the
    sequence of notifications after the code under test has run.
    """

    def __init__(self):
        # One (attempt, max_attempts, backoff_s) tuple per on_reconnecting call.
        self.reconnecting: list[tuple[int, int, float]] = []
        # Number of on_reconnected calls received so far.
        self.reconnected: int = 0
        # Messages passed to on_error / on_warning, in call order.
        self.errors: list[str] = []
        self.warnings: list[str] = []

    def on_reconnecting(self, attempt: int, max_attempts: int, backoff_s: float) -> None:
        """Record one reconnect notification as a tuple."""
        event = (attempt, max_attempts, backoff_s)
        self.reconnecting.append(event)

    def on_reconnected(self) -> None:
        """Count one reconnected notification."""
        self.reconnected = self.reconnected + 1

    def on_error(self, message: str) -> None:
        """Record an error message."""
        self.errors.append(message)

    def on_warning(self, message: str) -> None:
        """Record a warning message."""
        self.warnings.append(message)
26+
27+
28+
def _make_result(reason: str, *, status: str = "error", success: bool = False, final_step: int = 0) -> OptimizationResult:
    """Build an ``OptimizationResult`` for a test.

    Only ``reason`` is required; the remaining fields default to a failed
    run (``status="error"``, ``success=False``) stopped at step 0.
    """
    fields = {"success": success, "final_step": final_step, "status": status, "reason": reason}
    return OptimizationResult(**fields)
30+
31+
32+
@pytest.fixture
def stub_sleep(monkeypatch):
    """Replace ``weco.optimizer.time.sleep`` with a recorder.

    Returns the list that collects every requested sleep duration, in call
    order, so tests can assert on backoff values without actually waiting.
    """
    recorded: list[float] = []

    def _record(seconds):
        recorded.append(seconds)

    monkeypatch.setattr("weco.optimizer.time.sleep", _record)
    return recorded
37+
38+
39+
@pytest.fixture
def stub_resume(monkeypatch):
    """Provide an installer that scripts ``weco.optimizer._silent_resume``.

    The fixture value is a callable taking the list of outcomes the fake
    should return, one per call. Calling it patches ``_silent_resume`` and
    returns the list that records the ``run_id`` of every call.
    """
    seen_run_ids: list[str] = []

    def _install(outcomes: list[bool]):
        remaining = iter(outcomes)

        def _scripted_resume(run_id, auth_headers, api_keys):
            seen_run_ids.append(run_id)
            # Running past the scripted outcomes raises StopIteration.
            return next(remaining)

        monkeypatch.setattr("weco.optimizer._silent_resume", _scripted_resume)
        return seen_run_ids

    return _install
54+
55+
56+
def _drive(results: list[OptimizationResult], *, policy: "AutoResumePolicy | None" = None, initial_start_step: int = 0):
    """Run ``_run_loop_with_auto_resume`` against a scripted loop factory.

    Args:
        results: Results the mocked loop factory returns, one per invocation
            (wired up through ``MagicMock(side_effect=...)``).
        policy: Auto-resume policy to use; a default ``AutoResumePolicy()``
            is created when omitted.
        initial_start_step: Forwarded unchanged to the function under test.

    Returns:
        A ``(returned, factory, ui)`` tuple: the value returned by
        ``_run_loop_with_auto_resume``, the ``MagicMock`` factory (for call
        inspection), and the ``FakeUI`` that recorded the callbacks.
    """
    # The ``policy`` annotation is a string on purpose: ``X | None`` between
    # classes is only evaluable at runtime on Python 3.10+, and this module has
    # no ``from __future__ import annotations``.
    factory = MagicMock(side_effect=results)
    ui = FakeUI()
    # Explicit None check so a caller-supplied policy is never replaced just
    # because it happens to be falsy.
    effective_policy = AutoResumePolicy() if policy is None else policy
    returned = _run_loop_with_auto_resume(
        factory,
        ui=ui,
        run_id="run-1",
        auth_headers={},
        api_keys=None,
        policy=effective_policy,
        initial_start_step=initial_start_step,
    )
    return returned, factory, ui
69+
70+
71+
@pytest.mark.parametrize(
    "reason,transient",
    [
        # Reasons expected to be classified as transient.
        *(
            (retryable, True)
            for retryable in ("transient_network_error", "http_502", "http_503", "http_504")
        ),
        # Reasons expected to be classified as non-transient.
        *(
            (terminal, False)
            for terminal in (
                "http_401",
                "http_402",
                "http_500",
                "user_terminated_sigint",
                "completed_successfully",
                "user_requested_stop",
                "timeout_waiting_for_tasks",
                "unknown",
            )
        ),
    ],
)
def test_classifies_transient_reasons(reason, transient):
    """``_is_transient`` returns exactly the expected boolean for each reason."""
    classified = _is_transient(_make_result(reason))
    assert classified is transient
90+
91+
92+
def test_returns_verbatim_when_not_transient(stub_sleep, stub_resume):
    """A non-transient result comes back as the same object after one loop call."""
    stub_resume([])
    finished = _make_result("completed_successfully", status="completed", success=True, final_step=7)

    outcome, loop_factory, fake_ui = _drive([finished])

    assert outcome is finished
    assert loop_factory.call_count == 1
    # No sleeps and no UI notifications were recorded.
    assert stub_sleep == []
    assert fake_ui.reconnecting == []
    assert fake_ui.reconnected == 0
    assert fake_ui.errors == []
104+
105+
106+
def test_resumes_once_then_continues_from_final_step(stub_sleep, stub_resume):
    """One transient failure triggers one resume, then the loop restarts at the failed step."""
    resumed_run_ids = stub_resume([True])
    dropped = _make_result("transient_network_error", final_step=4)
    finished = _make_result("completed_successfully", status="completed", success=True, final_step=9)

    outcome, loop_factory, fake_ui = _drive([dropped, finished])

    assert outcome is finished
    assert loop_factory.call_count == 2
    # First invocation starts at step 0, the second at the transient result's final_step.
    first_call, second_call = loop_factory.call_args_list[0], loop_factory.call_args_list[1]
    assert first_call.args == (0,)
    assert second_call.args == (4,)
    assert resumed_run_ids == ["run-1"]
    assert len(stub_sleep) == 1
    assert len(fake_ui.reconnecting) == 1
    assert fake_ui.reconnected == 1
    assert fake_ui.errors == []
122+
123+
124+
def test_exhausts_after_max_attempts_and_returns_original_result(stub_sleep, stub_resume):
    """With every run ending in a transient failure, the transient result is returned after three resumes."""
    resumed_run_ids = stub_resume([True, True, True])
    dropped = _make_result("transient_network_error", final_step=2)
    three_tries = AutoResumePolicy(max_attempts=3)

    # One initial run plus one re-run per successful resume: four loop invocations.
    outcome, loop_factory, fake_ui = _drive([dropped] * 4, policy=three_tries)

    assert outcome.reason == "transient_network_error"
    assert loop_factory.call_count == 4
    assert len(resumed_run_ids) == 3
    assert len(fake_ui.reconnecting) == 3
    assert fake_ui.reconnected == 3
    assert len(fake_ui.errors) == 1
    exhaustion_message = fake_ui.errors[0]
    assert "exhausted after 3" in exhaustion_message
138+
139+
140+
def test_disabled_policy_skips_resume_on_transient(stub_sleep, stub_resume):
    """With ``enabled=False`` a transient result is returned without any resume activity."""
    resumed_run_ids = stub_resume([])
    dropped = _make_result("transient_network_error", final_step=2)
    disabled = AutoResumePolicy(enabled=False)

    outcome, loop_factory, fake_ui = _drive([dropped], policy=disabled)

    assert outcome is dropped
    assert loop_factory.call_count == 1
    # No resume call, sleep, or UI notification was recorded.
    assert resumed_run_ids == []
    assert stub_sleep == []
    assert fake_ui.reconnecting == []
    assert fake_ui.reconnected == 0
    assert fake_ui.errors == []
153+
154+
155+
def test_silent_resume_failure_retries_without_reinvoking_loop(stub_sleep, stub_resume):
    """A failed resume is retried; the loop only runs again after a resume succeeds."""
    resumed_run_ids = stub_resume([False, True])
    dropped = _make_result("transient_network_error", final_step=3)
    finished = _make_result("completed_successfully", status="completed", success=True)
    three_tries = AutoResumePolicy(max_attempts=3)

    outcome, loop_factory, fake_ui = _drive([dropped, finished], policy=three_tries)

    assert outcome is finished
    # Two loop invocations: the initial run and the run after the successful resume.
    assert loop_factory.call_count == 2
    # Two resume calls, each with its own sleep and reconnecting notification.
    assert len(resumed_run_ids) == 2
    assert len(stub_sleep) == 2
    assert len(fake_ui.reconnecting) == 2
    assert fake_ui.reconnected == 1
    assert fake_ui.errors == []
169+
170+
171+
def test_silent_resume_exhaustion_without_reinvoking_loop(stub_sleep, stub_resume):
    """When every resume fails, the original transient result is returned and the loop is never re-run."""
    resumed_run_ids = stub_resume([False, False])
    dropped = _make_result("transient_network_error", final_step=3)
    two_tries = AutoResumePolicy(max_attempts=2)

    outcome, loop_factory, fake_ui = _drive([dropped], policy=two_tries)

    assert outcome is dropped
    assert loop_factory.call_count == 1
    assert len(resumed_run_ids) == 2
    assert len(fake_ui.reconnecting) == 2
    assert fake_ui.reconnected == 0
    assert len(fake_ui.errors) == 1
    exhaustion_message = fake_ui.errors[0]
    assert "exhausted after 2" in exhaustion_message
184+
185+
186+
def test_backoff_is_exponential_and_capped(stub_sleep, stub_resume):
    """Successive sleeps double from the initial backoff and stop growing at ``backoff_max_s``."""
    stub_resume([True, True, True, True, True])
    dropped = _make_result("transient_network_error")
    finished = _make_result("completed_successfully", status="completed", success=True)
    capped_policy = AutoResumePolicy(
        max_attempts=5,
        backoff_initial_s=1.0,
        backoff_factor=2.0,
        backoff_max_s=5.0,
    )

    _drive([dropped, dropped, dropped, dropped, finished], policy=capped_policy)

    # 1 -> 2 -> 4, then the fourth sleep is capped at backoff_max_s (5.0).
    expected_sleeps = [1.0, 2.0, 4.0, 5.0]
    assert stub_sleep == expected_sleeps
195+
196+
197+
def test_keyboard_interrupt_result_propagates_untouched(stub_sleep, stub_resume):
    """A ``user_terminated_sigint`` result is returned as-is with no resume activity."""
    resumed_run_ids = stub_resume([])
    sigint_result = _make_result("user_terminated_sigint", status="terminated")

    outcome, loop_factory, fake_ui = _drive([sigint_result])

    assert outcome is sigint_result
    assert loop_factory.call_count == 1
    # No resume call, sleep, or UI notification was recorded.
    assert resumed_run_ids == []
    assert stub_sleep == []
    assert fake_ui.reconnecting == []
    assert fake_ui.reconnected == 0
    assert fake_ui.errors == []
210+
211+
212+
def test_reconnecting_event_carries_attempt_and_backoff(stub_sleep, stub_resume):
    """``on_reconnecting`` receives the attempt number, the attempt limit and the backoff delay."""
    stub_resume([True])
    dropped = _make_result("transient_network_error", final_step=2)
    finished = _make_result("completed_successfully", status="completed", success=True)
    slow_policy = AutoResumePolicy(
        max_attempts=5,
        backoff_initial_s=3.0,
        backoff_factor=2.0,
        backoff_max_s=30.0,
    )

    _outcome, _loop_factory, fake_ui = _drive([dropped, finished], policy=slow_policy)

    # First attempt of five, waiting the initial backoff.
    expected_event = (1, 5, 3.0)
    assert fake_ui.reconnecting == [expected_event]

weco/cli.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,17 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
165165
help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.",
166166
)
167167
run_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS)
168+
run_parser.add_argument(
169+
"--no-auto-resume",
170+
action="store_true",
171+
help="Disable automatic reconnection/resume on transient network errors (default: enabled).",
172+
)
173+
run_parser.add_argument(
174+
"--auto-resume-max-attempts",
175+
type=int,
176+
default=5,
177+
help="Max auto-resume attempts before giving up and printing the manual resume command (default: 5).",
178+
)
168179

169180
# --- Eval backend integration ---
170181
run_parser.add_argument(
@@ -370,6 +381,17 @@ def configure_resume_parser(resume_parser: argparse.ArgumentParser) -> None:
370381
help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.",
371382
)
372383
resume_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS)
384+
resume_parser.add_argument(
385+
"--no-auto-resume",
386+
action="store_true",
387+
help="Disable automatic reconnection/resume on transient network errors (default: enabled).",
388+
)
389+
resume_parser.add_argument(
390+
"--auto-resume-max-attempts",
391+
type=int,
392+
default=5,
393+
help="Max auto-resume attempts before giving up and printing the manual resume command (default: 5).",
394+
)
373395

374396

375397
def _dispatch_run_subcommand(sub: str, args: argparse.Namespace) -> None:
@@ -431,7 +453,7 @@ def _collect_source_paths() -> list[str] | None:
431453

432454
def execute_run_command(args: argparse.Namespace) -> None:
433455
"""Execute the 'weco run' command with all its logic."""
434-
from .optimizer import optimize
456+
from .optimizer import AutoResumePolicy, optimize
435457

436458
ctx = get_event_context()
437459

@@ -505,6 +527,10 @@ def execute_run_command(args: argparse.Namespace) -> None:
505527
ctx,
506528
)
507529

530+
auto_resume_policy = AutoResumePolicy(
531+
enabled=not getattr(args, "no_auto_resume", False), max_attempts=getattr(args, "auto_resume_max_attempts", 5)
532+
)
533+
508534
success = optimize(
509535
source=source_arg,
510536
eval_command=args.eval_command,
@@ -521,6 +547,7 @@ def execute_run_command(args: argparse.Namespace) -> None:
521547
require_review=args.require_review,
522548
output_mode=args.output,
523549
submit_timeout=getattr(args, "submit_timeout", None),
550+
auto_resume_policy=auto_resume_policy,
524551
)
525552

526553
exit_code = 0 if success else 1
@@ -529,20 +556,25 @@ def execute_run_command(args: argparse.Namespace) -> None:
529556

530557
def execute_resume_command(args: argparse.Namespace) -> None:
531558
"""Execute the 'weco resume' command with all its logic."""
532-
from .optimizer import resume_optimization
559+
from .optimizer import AutoResumePolicy, resume_optimization
533560

534561
try:
535562
api_keys = parse_api_keys(args.api_key)
536563
except ValueError as e:
537564
console.print(f"[bold red]Error parsing API keys: {e}[/]")
538565
sys.exit(1)
539566

567+
auto_resume_policy = AutoResumePolicy(
568+
enabled=not getattr(args, "no_auto_resume", False), max_attempts=getattr(args, "auto_resume_max_attempts", 5)
569+
)
570+
540571
success = resume_optimization(
541572
run_id=args.run_id,
542573
api_keys=api_keys,
543574
apply_change=args.apply_change,
544575
output_mode=args.output,
545576
submit_timeout=getattr(args, "submit_timeout", None),
577+
auto_resume_policy=auto_resume_policy,
546578
)
547579

548580
sys.exit(0 if success else 1)

0 commit comments

Comments
 (0)