Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "weco"
authors = [{ name = "Weco AI Team", email = "contact@weco.ai" }]
description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer."
readme = "README.md"
version = "0.3.27"
version = "0.3.28"
license = { file = "LICENSE" }
requires-python = ">=3.9"
dependencies = [
Expand Down
24 changes: 17 additions & 7 deletions weco/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
WecoClient,
RunSummary,
ExecutionTasksResult,
format_api_error,
handle_api_error,
_truncate_output,
)
Expand Down Expand Up @@ -189,15 +190,24 @@ def submit_execution_result(
task_id: str,
execution_output: str,
auth_headers: dict = {},
timeout: Union[int, Tuple[int, int]] = (10, 3650),
timeout: Optional[Union[int, Tuple[int, int]]] = None,
api_keys: Optional[Dict[str, str]] = None,
) -> Optional[Dict[str, Any]]:
"""Submit execution result for a task."""
) -> Dict[str, Any]:
"""Submit execution result for a task.

Args:
timeout: Optional override for the HTTP ``(connect, read)`` timeout.
``None`` keeps the existing default of ``(10, 3650)`` so callers
that don't opt in see no behavior change.

Raises:
requests.exceptions.HTTPError: On non-2xx responses (e.g. 402 insufficient
credits, 503 candidate generation failed). Callers should format the
error via :func:`format_api_error` and surface it through the UI.
requests.exceptions.RequestException: On network errors.
"""
client = WecoClient(auth_headers)
try:
return client.suggest(run_id, execution_output=execution_output, task_id=task_id, api_keys=api_keys)
except Exception:
return None
return client.suggest(run_id, execution_output=execution_output, task_id=task_id, api_keys=api_keys, timeout=timeout)


# --- Share API Functions ---
Expand Down
9 changes: 8 additions & 1 deletion weco/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
default="rich",
help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.",
)
run_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS)

# --- Eval backend integration ---
run_parser.add_argument(
Expand Down Expand Up @@ -344,6 +345,7 @@ def configure_resume_parser(resume_parser: argparse.ArgumentParser) -> None:
default="rich",
help="Output mode: 'rich' for interactive terminal UI (default), 'plain' for machine-readable text output suitable for LLM agents.",
)
resume_parser.add_argument("--submit-timeout", type=int, default=None, help=argparse.SUPPRESS)


def _dispatch_run_subcommand(sub: str, args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -480,6 +482,7 @@ def execute_run_command(args: argparse.Namespace) -> None:
apply_change=args.apply_change,
require_review=args.require_review,
output_mode=args.output,
submit_timeout=getattr(args, "submit_timeout", None),
)

exit_code = 0 if success else 1
Expand All @@ -497,7 +500,11 @@ def execute_resume_command(args: argparse.Namespace) -> None:
sys.exit(1)

success = resume_optimization(
run_id=args.run_id, api_keys=api_keys, apply_change=args.apply_change, output_mode=args.output
run_id=args.run_id,
api_keys=api_keys,
apply_change=args.apply_change,
output_mode=args.output,
submit_timeout=getattr(args, "submit_timeout", None),
)

sys.exit(0 if success else 1)
Expand Down
130 changes: 126 additions & 4 deletions weco/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,50 @@ def _truncate_output(output: str) -> str:
return f"{first}\n ... [{truncated_len} characters truncated] ... \n{last}"


def format_api_error(e: requests.exceptions.HTTPError) -> str:
    """Render an API error response as a plain multi-line string.

    Text-returning counterpart to :func:`handle_api_error`: instead of
    printing to a Rich console it builds the message as a string, so UI
    handlers that only accept plain text (the ``on_error`` protocol used by
    both the Rich Live panel and the plain-text UI) can surface it.
    """
    status = getattr(e.response, "status_code", None)
    generic = f"HTTP {status} Error"

    # Pull the JSON body; FastAPI-style payloads nest details under "detail".
    # AttributeError also covers e.response being None or the payload being a
    # bare list (no .get) — fall back to raw response text, then a generic line.
    try:
        payload = e.response.json()
        detail = payload.get("detail", payload)
    except (ValueError, AttributeError):
        return getattr(e.response, "text", "") or generic

    def render(obj: Any) -> list[str]:
        # A bare string is already the full message.
        if isinstance(obj, str):
            return [obj]
        if isinstance(obj, dict):
            # First truthy value among the conventional message keys wins.
            message_keys = ("message", "error", "msg", "detail")
            message = None
            for key in message_keys:
                if obj.get(key):
                    message = obj.get(key)
                    break
            out: list[str] = [message or generic]
            suggestion = obj.get("suggestion")
            if suggestion:
                out.append(str(suggestion))
            # Remaining non-empty fields are rendered as "key: value" extras.
            consumed = {"message", "error", "msg", "detail", "suggestion"}
            for key, value in obj.items():
                if key in consumed or value in (None, ""):
                    continue
                out.append(f"{key}: {value}")
            return out
        if isinstance(obj, list) and obj:
            # Fully render the first entry; later entries are stringified.
            out = list(render(obj[0]))
            out.extend(str(item) for item in obj[1:])
            return out
        # Anything else (None, numbers, empty containers): best-effort text.
        return [str(obj) if obj else generic]

    return "\n".join(render(detail))


def handle_api_error(e: requests.exceptions.HTTPError, console) -> None:
"""Extract and display error messages from API responses in a structured format."""
status = getattr(e.response, "status_code", None)
Expand Down Expand Up @@ -272,11 +316,20 @@ def suggest(
step: int | None = None,
task_id: str | None = None,
api_keys: dict[str, str] | None = None,
timeout: tuple[int, int] | int | None = None,
) -> dict:
"""``POST /runs/{run_id}/suggest`` — submit execution output, get next candidate.

If *step* is provided, transport errors (ReadTimeout, 502, ConnectionError)
trigger an automatic recovery attempt via ``get_run_status``.
If *step* is provided (legacy flow), transport errors (ReadTimeout, 502,
ConnectionError) trigger recovery via ``get_run_status``. If *task_id* is
provided (queue flow), recovery instead checks ``/execution-tasks/`` and
the run status, so a dropped response doesn't hang the CLI for up to
``timeout[1]`` seconds waiting on a socket the backend has already replied on.

Args:
timeout: Optional ``(connect, read)`` tuple or int override for the
HTTP request. Defaults to ``(10, 3650)`` to preserve existing
behavior; pass a smaller value to exercise the recovery path.

Raises:
requests.exceptions.HTTPError: On non-recoverable HTTP errors.
Expand All @@ -289,8 +342,10 @@ def suggest(
if api_keys:
body["api_keys"] = api_keys

request_timeout = timeout if timeout is not None else (10, 3650)

try:
resp = self._post(f"/runs/{run_id}/suggest", json=body, timeout=(10, 3650))
resp = self._post(f"/runs/{run_id}/suggest", json=body, timeout=request_timeout)
resp.raise_for_status()
result = resp.json()
if result.get("plan") is None:
Expand All @@ -303,12 +358,21 @@ def suggest(
recovered = self._recover_suggest(run_id, step)
if recovered is not None:
return recovered
elif task_id is not None:
recovered = self._recover_queue_suggest(run_id)
if recovered is not None:
return recovered
raise type(exc)(exc) from exc
except requests.exceptions.HTTPError as exc:
if step is not None and getattr(exc.response, "status_code", None) == 502:
status_code = getattr(exc.response, "status_code", None)
if step is not None and status_code == 502:
recovered = self._recover_suggest(run_id, step)
if recovered is not None:
return recovered
elif task_id is not None and status_code in (502, 503, 504):
recovered = self._recover_queue_suggest(run_id)
if recovered is not None:
return recovered
raise

def heartbeat(self, run_id: str) -> bool:
Expand Down Expand Up @@ -482,6 +546,64 @@ def log_external_step(
# Internal
# ------------------------------------------------------------------

def _recover_queue_suggest(self, run_id: str) -> dict | None:
"""Try to reconstruct a ``/suggest`` response for queue-mode clients.

Called after a transport error (ReadTimeout / 5xx / ConnectionError) when
a ``task_id`` was supplied. The backend marks the submitted task as
completed early in ``/suggest`` and, if everything succeeds, atomically
creates the next node + revision + execution task before returning.

If we can observe either (a) a ready execution task queued for this run,
or (b) the run transitioning to ``completed``, the submit effectively
landed — we can synthesize a success response and let the main loop
proceed to its next poll/claim iteration. Also recover the previous
step's metric from the run history so the UI's ``on_metric`` still fires
for the step whose response we missed. Otherwise return ``None`` and let
the caller surface the transport error.
"""
try:
run_data = self.get_run_status(run_id, include_history=True)
except Exception:
run_data = None

run_status = (run_data or {}).get("status")
if run_status in ("terminated", "error"):
return None

# Latest node that has an execution output is the one we just evaluated
# — its metric is the ``previous_solution_metric_value`` the caller expects.
previous_metric = None
if run_data is not None:
evaluated_nodes = [
n
for n in (run_data.get("nodes") or [])
if n.get("execution_output") is not None and n.get("metric_value") is not None
]
if evaluated_nodes:
latest_evaluated = max(evaluated_nodes, key=lambda n: n.get("step", 0))
previous_metric = latest_evaluated.get("metric_value")

def _with_metric(payload: dict) -> dict:
if previous_metric is not None:
payload["previous_solution_metric_value"] = previous_metric
return payload

if run_status == "completed":
return _with_metric({"run_id": run_id, "is_done": True})

# Verify the next candidate task is queued — that's the signal that the
# submit landed end-to-end (not just that the previous node was updated).
try:
tasks_result = self.get_execution_tasks(run_id)
except Exception:
tasks_result = None

if tasks_result is not None and tasks_result.tasks:
return _with_metric({"run_id": run_id, "is_done": False})

return None

def _recover_suggest(self, run_id: str, step: int) -> dict | None:
"""Try to reconstruct a ``/suggest`` response after a transport error.

Expand Down
45 changes: 33 additions & 12 deletions weco/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from dataclasses import dataclass
from typing import Optional

from requests.exceptions import HTTPError
from rich.console import Console
from rich.prompt import Confirm

from . import __dashboard_url__
from .api import (
claim_execution_task,
format_api_error,
get_execution_tasks,
get_optimization_run_status,
report_termination,
Expand Down Expand Up @@ -82,6 +84,7 @@ def _run_optimization_loop(
poll_interval: float = 2.0,
max_poll_attempts: int = 300,
api_keys: Optional[dict] = None,
submit_timeout: Optional[int] = None,
) -> OptimizationResult:
"""
Shared queue-based execution loop for optimize and resume.
Expand All @@ -103,6 +106,9 @@ def _run_optimization_loop(
poll_interval: Seconds between polling attempts.
max_poll_attempts: Max polls before timeout (~10 min with 2s interval).
api_keys: Optional API keys for LLM providers.
submit_timeout: Optional read-timeout override (seconds) for the
``/suggest`` call made when submitting a step's result. ``None``
preserves the existing ~61-minute default.

Returns:
OptimizationResult with success status and termination info.
Expand Down Expand Up @@ -189,22 +195,20 @@ def _run_optimization_loop(

ui.on_output(term_out)

# Submit result
# Submit result. HTTP errors (insufficient credits, candidate generation
# failures, etc.) propagate and are handled centrally below so the real
# backend detail reaches the user and the run's termination record.
ui.on_submitting()
submit_timeout_tuple = (10, submit_timeout) if submit_timeout is not None else None
result = submit_execution_result(
run_id=run_id, task_id=task_id, execution_output=term_out, auth_headers=auth_headers, api_keys=api_keys
run_id=run_id,
task_id=task_id,
execution_output=term_out,
auth_headers=auth_headers,
api_keys=api_keys,
timeout=submit_timeout_tuple,
)

if result is None:
ui.on_error("Failed to submit result")
return OptimizationResult(
success=False,
final_step=step,
status="error",
reason="submit_failed",
details="Failed to submit execution result",
)

is_done = result.get("is_done", False)
prev_metric = result.get("previous_solution_metric_value")

Expand All @@ -220,6 +224,19 @@ def _run_optimization_loop(
except KeyboardInterrupt:
ui.on_interrupted()
return OptimizationResult(success=False, final_step=step, status="terminated", reason="user_terminated_sigint")
except HTTPError as e:
# Surface structured API error details (insufficient credits, auth failures, candidate
# generation failures, etc.) through the UI rather than a generic exception string.
error_message = format_api_error(e)
ui.on_error(error_message)
status_code = getattr(e.response, "status_code", None)
return OptimizationResult(
success=False,
final_step=step,
status="error",
reason=f"http_{status_code}" if status_code else "http_error",
details=error_message,
)
except Exception as e:
ui.on_error(f"Error: {e}")
return OptimizationResult(success=False, final_step=step, status="error", reason="unknown", details=str(e))
Expand Down Expand Up @@ -300,6 +317,7 @@ def resume_optimization(
poll_interval: float = 2.0,
apply_change: bool = False,
output_mode: str = "rich",
submit_timeout: Optional[int] = None,
) -> bool:
"""
Resume an interrupted run using the queue-based optimization loop.
Expand Down Expand Up @@ -444,6 +462,7 @@ def resume_optimization(
start_step=current_step,
poll_interval=poll_interval,
api_keys=api_keys,
submit_timeout=submit_timeout,
)

# Stop heartbeat immediately after loop completes
Expand Down Expand Up @@ -503,6 +522,7 @@ def optimize(
apply_change: bool = False,
require_review: bool = False,
output_mode: str = "rich",
submit_timeout: Optional[int] = None,
) -> bool:
"""
Simplified queue-based optimization loop.
Expand Down Expand Up @@ -628,6 +648,7 @@ def optimize(
start_step=0,
poll_interval=poll_interval,
api_keys=api_keys,
submit_timeout=submit_timeout,
)

# Stop heartbeat immediately after loop completes
Expand Down
Loading