Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from agents.run import DEFAULT_MAX_TURNS
from dotenv import find_dotenv, load_dotenv
from openai import AsyncOpenAI
import httpx

from .capi import get_AI_endpoint, get_AI_token, get_provider

Expand Down Expand Up @@ -182,6 +183,7 @@ def __init__(
base_url=resolved_endpoint,
api_key=resolved_token,
default_headers=provider.extra_headers or None,
timeout=httpx.Timeout(connect=10.0, read=300.0, write=300.0, pool=60.0),
)
set_tracing_disabled(True)
self.run_hooks = run_hooks or TaskRunHooks()
Expand All @@ -198,6 +200,7 @@ def _ToolsToFinalOutputFunction(
else:
model_impl = OpenAIChatCompletionsModel(model=model, openai_client=client)

self._openai_client = client
self.agent = Agent(
name=name,
instructions=instructions,
Expand All @@ -209,6 +212,11 @@ def _ToolsToFinalOutputFunction(
hooks=agent_hooks or TaskAgentHooks(),
)

async def close(self) -> None:
    """Release the underlying AsyncOpenAI client and its httpx connection pool."""
    client = self._openai_client
    if client is not None:
        await client.close()

async def run(self, prompt: str, max_turns: int = DEFAULT_MAX_TURNS) -> result.RunResult:
    """Drive the agent to completion with Runner and hand back the RunResult."""
    return await Runner.run(
        starting_agent=self.agent,
        input=prompt,
        max_turns=max_turns,
        hooks=self.run_hooks,
    )
Expand Down
180 changes: 155 additions & 25 deletions src/seclab_taskflow_agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import json
import logging
import os
import sys
import threading
import time
import uuid
from typing import Any

Expand Down Expand Up @@ -51,6 +54,57 @@
MAX_API_RETRY = 5 # Maximum number of consecutive API error retries
TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task
TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between task retries
# Application-level backstop: kill a streaming run if no events yielded for 30 min.
# Complements the TCP-level httpx.Timeout(read=300s) in agent.py which catches
# dead sockets; this catches subtler hangs where the connection stays open but
# the server (or async generator) stops producing events.
STREAM_IDLE_TIMEOUT = 1800

# Watchdog: a non-asyncio thread that force-kills the process if the event
# loop stops making progress. Covers every hang variant (dead connections,
# asyncio cleanup spin, MCP cleanup, etc.) because it runs outside asyncio.
WATCHDOG_IDLE_TIMEOUT = int(os.environ.get("WATCHDOG_IDLE_TIMEOUT", "2100")) # 35 min default

# Monotonic timestamp of the most recent observed activity, plus the lock that
# guards it.  Shared between asyncio callbacks (writers) and the watchdog
# thread (reader).
_watchdog_last_activity = time.monotonic()
_watchdog_lock = threading.Lock()


def watchdog_ping() -> None:
    """Record that the process is making progress.

    Safe to call from any coroutine, callback, or thread; simply refreshes the
    shared activity timestamp under the lock.
    """
    global _watchdog_last_activity
    with _watchdog_lock:
        _watchdog_last_activity = time.monotonic()


def _watchdog_thread(timeout: int) -> None:
    """Background thread body: force-exit if no activity for *timeout* seconds.

    Runs forever as a daemon thread.  It polls the shared activity timestamp
    and, once the idle interval exceeds *timeout*, flushes stdio and calls
    ``os._exit(2)`` — bypassing asyncio entirely so it works even when the
    event loop itself is the thing that hung.
    """
    # Poll at roughly a fifth of the timeout, clamped to [1s, 60s].
    check_interval = min(60, max(1, timeout // 5))
    while True:
        time.sleep(check_interval)
        with _watchdog_lock:
            idle = time.monotonic() - _watchdog_last_activity
        if idle > timeout:
            # Lazy %-style args: matches the logging convention used elsewhere
            # in this module and defers formatting until the record is emitted.
            logging.error(
                "Watchdog: no activity for %.0fs (limit %ss) — force-exiting to prevent hang",
                idle,
                timeout,
            )
            sys.stderr.flush()
            sys.stdout.flush()
            os._exit(2)


# One-shot guard so repeated start_watchdog() calls spawn only a single thread.
_watchdog_started = False


def start_watchdog(timeout: int = WATCHDOG_IDLE_TIMEOUT) -> None:
    """Start the watchdog daemon thread (idempotent; later calls are no-ops)."""
    global _watchdog_started
    if _watchdog_started:
        return
    _watchdog_started = True
    # Reset the activity timestamp first: the module-level default was set at
    # import time, so a late start must not trip the watchdog immediately.
    watchdog_ping()
    threading.Thread(target=_watchdog_thread, args=(timeout,), daemon=True).start()
Comment on lines +99 to +107
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_watchdog is documented as idempotent, but the current implementation always spawns a new daemon thread on every call. Also, _watchdog_last_activity is initialized at module import time and start_watchdog() does not reset it, so if this module is imported and start_watchdog() is invoked later than timeout seconds, the watchdog can force-exit almost immediately. Consider (1) tracking a module-level started flag/thread to make this truly idempotent, and (2) calling watchdog_ping() (or otherwise resetting the timestamp) inside start_watchdog() before starting the thread.

Copilot uses AI. Check for mistakes.


def _resolve_model_config(
Expand Down Expand Up @@ -321,6 +375,9 @@
await servers_connected.wait()
logging.debug("All mcp servers are connected!")

agent0: TaskAgent | None = None
handoff_agents: list[TaskAgent] = []

try:
important_guidelines = [
"Do not prompt the user with questions.",
Expand All @@ -334,29 +391,29 @@
agent_names = list(agents.keys())
for handoff_name in agent_names[1:]:
personality = agents[handoff_name]
handoffs.append(
TaskAgent(
name=compress_name(handoff_name),
instructions=prompt_with_handoff_instructions(
mcp_system_prompt(
personality.personality,
personality.task,
server_prompts=server_prompts,
important_guidelines=important_guidelines,
)
),
handoffs=[],
exclude_from_context=exclude_from_context,
mcp_servers=[e.server for e in entries],
model=model,
model_settings=model_settings,
api_type=api_type,
endpoint=endpoint,
token=token,
run_hooks=run_hooks,
agent_hooks=agent_hooks,
).agent
ta = TaskAgent(
name=compress_name(handoff_name),
instructions=prompt_with_handoff_instructions(
mcp_system_prompt(
personality.personality,
personality.task,
server_prompts=server_prompts,
important_guidelines=important_guidelines,
)
),
handoffs=[],
exclude_from_context=exclude_from_context,
mcp_servers=[e.server for e in entries],
model=model,
model_settings=model_settings,
api_type=api_type,
endpoint=endpoint,
token=token,
run_hooks=run_hooks,
agent_hooks=agent_hooks,
)
handoff_agents.append(ta)
handoffs.append(ta.agent)

# Create primary agent
primary_name = agent_names[0]
Expand Down Expand Up @@ -389,11 +446,44 @@
max_retry = MAX_API_RETRY
rate_limit_backoff = RATE_LIMIT_BACKOFF
while rate_limit_backoff:
result = None
try:
result = agent0.run_streamed(prompt, max_turns=max_turns)
async for event in result.stream_events():
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
await render_model_output(event.data.delta, async_task=async_task, task_id=task_id)
stream = None
try:
stream = result.stream_events()
async_iter = stream.__aiter__()
while True:
try:
event = await asyncio.wait_for(
async_iter.__anext__(),
timeout=STREAM_IDLE_TIMEOUT,
)
except StopAsyncIteration:
break
except asyncio.TimeoutError:
logging.error(
f"Stream idle for {STREAM_IDLE_TIMEOUT}s — "
"connection likely dead, raising APITimeoutError"
)
raise APITimeoutError("Stream idle timeout exceeded")
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
watchdog_ping()
await render_model_output(event.data.delta, async_task=async_task, task_id=task_id)
finally:
if stream is not None:
aclose = getattr(stream, "aclose", None)
if aclose is not None:
try:
await aclose()
except Exception:
logging.exception("Failed to close streamed response")
# Cancel the RunResultStreaming background tasks.
# aclose() on the stream_events() async generator throws
# GeneratorExit which skips _cleanup_tasks(), so we must
# cancel explicitly to avoid leaking _run_impl_task.
if result is not None:
result.cancel()
await render_model_output("\n\n", async_task=async_task, task_id=task_id)
return
except APITimeoutError:
Expand Down Expand Up @@ -433,6 +523,19 @@
return complete

finally:
# Close all AsyncOpenAI clients to release httpx connection pools.
# Dead CLOSE_WAIT sockets in the pool cause kqueue CPU spin if left open.
watchdog_ping()
for ta in handoff_agents:
try:
await ta.close()
except Exception:
logging.exception("Exception closing handoff agent client")
if agent0 is not None:
try:
await agent0.close()
except Exception:
logging.exception("Exception closing primary agent client")
start_cleanup.set()
cleanup_attempts_left = len(entries)
while cleanup_attempts_left and entries:
Expand All @@ -443,6 +546,21 @@
continue
except Exception:
logging.exception("Exception in mcp server cleanup task")
# Cancel the MCP session task if it's still running to prevent
# the asyncio event loop from spinning on a dangling task.
if not mcp_sessions.done():
mcp_sessions.cancel()
try:
await asyncio.wait_for(mcp_sessions, timeout=MCP_CLEANUP_TIMEOUT)
except asyncio.TimeoutError:
logging.warning(
"Timed out waiting for MCP session task cancellation after %s seconds",
MCP_CLEANUP_TIMEOUT,
)
except asyncio.CancelledError:
pass
except Exception:
logging.exception("Exception while waiting for MCP session task cancellation")


async def run_main(
Expand All @@ -465,12 +583,18 @@
"""
from .session import TaskflowSession

# Start the watchdog thread — if the process hangs for any reason
# (asyncio spin, dead connections, MCP cleanup), this kills it.
start_watchdog()

last_mcp_tool_results: list[str] = []

async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None:
watchdog_ping()
last_mcp_tool_results.append(result)

async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None:
watchdog_ping()
await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n")

async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]) -> None:
Expand Down Expand Up @@ -731,3 +855,9 @@
if session is not None and not session.error:
session.mark_finished()
await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n")

# Force-exit to prevent asyncio event loop spin on dangling
# tasks/connections from the responses API path. Flush first.
sys.stdout.flush()
sys.stderr.flush()
os._exit(0 if (session is None or session.finished) else 1)
Loading
Loading