diff --git a/pyproject.toml b/pyproject.toml index f6de0e4..e561870 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,6 @@ ignore = [ "EM102", # Exception f-strings "G004", # Logging f-strings "T201", # print() used for user output - "TRY003", # Raise with inline message strings # Backwards-compatibility suppressions for existing code "A001", # Variable shadows built-in @@ -253,3 +252,5 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*" = ["S101", "PLR2004"] +# Vendored third-party code — do not modify to suppress linting issues +"src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/*" = ["TRY003"] diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index a7966f2..980fc1f 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -27,6 +27,7 @@ from openai import AsyncOpenAI from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token +from .exceptions import TokenEnvVarNotSetError __all__ = [ "DEFAULT_MODEL", @@ -182,7 +183,7 @@ def __init__( if token: resolved_token = os.getenv(token, "") if not resolved_token: - raise RuntimeError(f"Token env var {token!r} is not set") + raise TokenEnvVarNotSetError(token) else: resolved_token = get_AI_token() diff --git a/src/seclab_taskflow_agent/available_tools.py b/src/seclab_taskflow_agent/available_tools.py index 577ae06..8e41663 100644 --- a/src/seclab_taskflow_agent/available_tools.py +++ b/src/seclab_taskflow_agent/available_tools.py @@ -40,6 +40,44 @@ class FileTypeException(Exception): pass +class InvalidToolFormatError(BadToolNameError): + def __init__(self, toolname: str) -> None: + super().__init__( + f'Not a valid toolname: "{toolname}". ' + f'Expected format: "packagename.filename"' + ) + + +class ToolDirNotFoundError(BadToolNameError): + def __init__(self, toolname: str, pkg_dir: object) -> None: + super().__init__(f"Cannot load {toolname} because {pkg_dir} is not a valid directory.") + + +class FiletypeMismatchError(FileTypeException): + def __init__(self, filepath: object, expected: str, got: str) -> None: + super().__init__(f"Error in {filepath}: expected filetype {expected!r}, got {got!r}.") + + +class UnknownFiletypeError(BadToolNameError): + def __init__(self, filetype: str, toolname: str) -> None: + super().__init__(f"Unknown filetype {filetype!r} in {toolname}") + + +class ToolValidationError(BadToolNameError): + def __init__(self, toolname: str, exc: Exception) -> None: + super().__init__(f"Validation error loading {toolname}: {exc}") + + +class ToolLoadError(BadToolNameError): + def __init__(self, toolname: str, exc: Exception) -> None: + super().__init__(f"Cannot load {toolname}: {exc}") + + +class ToolFileNotFoundError(BadToolNameError): + def __init__(self, toolname: str, filepath: object) -> None: + super().__init__(f"Cannot load {toolname} because {filepath} is not a valid file.") + + class AvailableToolType(Enum): Personality = "personality" Taskflow = "taskflow" @@ -108,18 +146,13 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: # Resolve package and filename from dotted path components = toolname.rsplit(".", 1) if len(components) != 2: - raise BadToolNameError( - f'Not a valid toolname: "{toolname}". ' - f'Expected format: "packagename.filename"' - ) + raise InvalidToolFormatError(toolname) package, filename = components try: pkg_dir = importlib.resources.files(package) if not pkg_dir.is_dir(): - raise BadToolNameError( - f"Cannot load {toolname} because {pkg_dir} is not a valid directory." - ) + raise ToolDirNotFoundError(toolname, pkg_dir) filepath = pkg_dir.joinpath(filename + ".yaml") with filepath.open() as fh: raw = yaml.safe_load(fh) @@ -128,17 +161,12 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: header = raw.get("seclab-taskflow-agent", {}) filetype = header.get("filetype", "") if filetype != tooltype.value: - raise FileTypeException( - f"Error in {filepath}: expected filetype {tooltype.value!r}, " - f"got {filetype!r}." - ) + raise FiletypeMismatchError(filepath, tooltype.value, filetype) # Parse into the appropriate Pydantic model model_cls = DOCUMENT_MODELS.get(filetype) if model_cls is None: - raise BadToolNameError( - f"Unknown filetype {filetype!r} in {toolname}" - ) + raise UnknownFiletypeError(filetype, toolname) try: doc = model_cls(**raw) @@ -147,9 +175,7 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: for err in exc.errors(): if "Unsupported version" in str(err.get("msg", "")): raise VersionException(str(err["msg"])) from exc - raise BadToolNameError( - f"Validation error loading {toolname}: {exc}" - ) from exc + raise ToolValidationError(toolname, exc) from exc # Cache and return if tooltype not in self._cache: @@ -158,10 +184,8 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: return doc except ModuleNotFoundError as exc: - raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc + raise ToolLoadError(toolname, exc) from exc except FileNotFoundError: - raise BadToolNameError( - f"Cannot load {toolname} because {filepath} is not a valid file." - ) + raise ToolFileNotFoundError(toolname, filepath) except ValueError as exc: - raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc + raise ToolLoadError(toolname, exc) from exc diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 4ed8dcd..75ed9f6 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -11,6 +11,8 @@ import httpx from strenum import StrEnum +from .exceptions import AITokenNotFoundError, UnsupportedEndpointError + __all__ = [ "AI_API_ENDPOINT_ENUM", "COPILOT_INTEGRATION_ID", @@ -38,7 +40,7 @@ def to_url(self) -> str: case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: return f"https://{self}/v1" case _: - raise ValueError(f"Unsupported endpoint: {self}") + raise UnsupportedEndpointError(self) COPILOT_INTEGRATION_ID = "vscode-chat" @@ -61,7 +63,7 @@ def get_AI_token() -> str: token = os.getenv("COPILOT_TOKEN") if token: return token - raise RuntimeError("AI_API_TOKEN environment variable is not set.") + raise AITokenNotFoundError() # assume we are >= python 3.9 for our type hints diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 7569431..396dd3d 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -23,6 +23,7 @@ from .available_tools import AvailableTools from .banner import get_banner from .capi import get_AI_token, list_tool_call_models +from .exceptions import InvalidGlobalVariableError from .path_utils import log_file_name app = typer.Typer( @@ -36,7 +37,7 @@ def _parse_global(value: str) -> tuple[str, str]: """Parse a ``KEY=VALUE`` string into a (key, value) pair.""" if "=" not in value: - raise typer.BadParameter(f"Invalid global variable format: {value!r}. Expected KEY=VALUE.") + raise InvalidGlobalVariableError(value) key, _, val = value.partition("=") return key.strip(), val.strip() diff --git a/src/seclab_taskflow_agent/exceptions.py b/src/seclab_taskflow_agent/exceptions.py new file mode 100644 index 0000000..0f7b528 --- /dev/null +++ b/src/seclab_taskflow_agent/exceptions.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Custom exception classes for seclab-taskflow-agent. + +Defines project-specific exception types so that error messages are +encapsulated inside the exception class (satisfying TRY003) and are +therefore discoverable and reusable across the codebase. +""" + +from __future__ import annotations + +__all__ = [ + "AITokenNotFoundError", + "ExecutableNotFoundError", + "InvalidGlobalVariableError", + "MaxRateLimitReachedError", + "MCPConnectionTimeoutError", + "MissingHostPortError", + "MutuallyExclusiveTaskFieldsError", + "NoAgentsResolvedError", + "PersonalityNotFoundError", + "ProcessThreadTimeoutError", + "PromptTemplateRenderingError", + "RequiredEnvVarNotFoundError", + "ResultTextNotJSONError", + "ReusableTaskflowNotFoundError", + "ReusableTaskflowTooManyTasksError", + "SessionNotFoundError", + "ShellAndPromptMutuallyExclusiveError", + "ShellCommandError", + "TaskModelSettingsTypeError", + "TemplateRenderingError", + "TokenEnvVarNotSetError", + "ToolResultNotJSONError", + "UnknownModelSettingsError", + "UnsupportedEndpointError", + "UnsupportedMCPTransportError", + "UnsupportedVersionError", +] + +from typing import Any + +import typer +from openai import APITimeoutError + +# --------------------------------------------------------------------------- +# API / token errors +# --------------------------------------------------------------------------- + + +class UnsupportedEndpointError(ValueError): + def __init__(self, endpoint: str) -> None: + super().__init__(f"Unsupported endpoint: {endpoint}") + + +class AITokenNotFoundError(RuntimeError): + def __init__(self) -> None: + super().__init__("AI_API_TOKEN environment variable is not set.") + + +class TokenEnvVarNotSetError(RuntimeError): + def __init__(self, token: str) -> None: + super().__init__(f"Token env var {token!r} is not set") + + +# --------------------------------------------------------------------------- +# CLI errors +# --------------------------------------------------------------------------- + + +class InvalidGlobalVariableError(typer.BadParameter): + def __init__(self, value: str) -> None: + super().__init__(f"Invalid global variable format: {value!r}. Expected KEY=VALUE.") + + +# --------------------------------------------------------------------------- +# MCP transport errors +# --------------------------------------------------------------------------- + + +class UnsupportedMCPTransportError(ValueError): + def __init__(self, kind: str) -> None: + super().__init__(f"Unsupported MCP transport: {kind}") + + +class MissingHostPortError(ValueError): + def __init__(self, url: str) -> None: + super().__init__(f"URL must include a host and port: {url}") + + +class MCPConnectionTimeoutError(TimeoutError): + def __init__(self, host: str, port: int, timeout: float) -> None: + super().__init__(f"Could not connect to {host}:{port} after {timeout} seconds") + + +class ProcessThreadTimeoutError(RuntimeError): + def __init__(self) -> None: + super().__init__("Process thread did not exit within timeout.") + + +class ExecutableNotFoundError(FileNotFoundError): + def __init__(self, command: str) -> None: + super().__init__(f"Could not resolve path to {command}") + + +# --------------------------------------------------------------------------- +# Model / config validation errors +# --------------------------------------------------------------------------- + + +class UnsupportedVersionError(ValueError): + def __init__(self, version: str, supported: str) -> None: + super().__init__(f"Unsupported version: {version}. Only version {supported} is supported.") + + +class MutuallyExclusiveTaskFieldsError(ValueError): + def __init__(self) -> None: + super().__init__("shell task ('run') and prompt task ('user_prompt') are mutually exclusive") + + +# --------------------------------------------------------------------------- +# Runner errors +# --------------------------------------------------------------------------- + + +class UnknownModelSettingsError(ValueError): + def __init__(self, config_ref: str, unknown: Any) -> None: + super().__init__( + f"Settings section of model_config file {config_ref} contains models not in the model section: {unknown}" + ) + + +class ReusableTaskflowNotFoundError(ValueError): + def __init__(self, taskflow_name: str) -> None: + super().__init__(f"No such reusable taskflow: {taskflow_name}") + + +class ReusableTaskflowTooManyTasksError(ValueError): + def __init__(self) -> None: + super().__init__("Reusable taskflows can only contain 1 task") + + +class TaskModelSettingsTypeError(ValueError): + def __init__(self, task_name: str) -> None: + super().__init__(f"model_settings in task {task_name} needs to be a dictionary") + + +class ToolResultNotJSONError(ValueError): + def __init__(self) -> None: + super().__init__("Tool result is not valid JSON") + + +class ResultTextNotJSONError(ValueError): + def __init__(self) -> None: + super().__init__("Result text is not valid JSON") + + +class TemplateRenderingError(ValueError): + def __init__(self, error: Exception) -> None: + super().__init__(f"Template rendering failed: {error}") + + +class MaxRateLimitReachedError(APITimeoutError): + def __init__(self) -> None: + super().__init__("Max rate limit backoff reached") + + +class ShellAndPromptMutuallyExclusiveError(ValueError): + def __init__(self) -> None: + super().__init__("shell task and prompt task are mutually exclusive!") + + +class PromptTemplateRenderingError(ValueError): + def __init__(self, error: Exception) -> None: + super().__init__(f"Failed to render prompt template: {error}") + + +class PersonalityNotFoundError(ValueError): + def __init__(self, agent_name: str) -> None: + super().__init__(f"No such personality: {agent_name}") + + +class NoAgentsResolvedError(ValueError): + def __init__(self) -> None: + super().__init__( + "No agents resolved for this task. " + "Specify a personality with -p or provide an agents list." + ) + + +# --------------------------------------------------------------------------- +# Session errors +# --------------------------------------------------------------------------- + + +class SessionNotFoundError(FileNotFoundError): + def __init__(self, session_id: str) -> None: + super().__init__(f"No session checkpoint found: {session_id}") + + +# --------------------------------------------------------------------------- +# Shell / process errors +# --------------------------------------------------------------------------- + + +class ShellCommandError(RuntimeError): + def __init__(self, cmd: str, stderr: str) -> None: + super().__init__(f"Command {cmd} failed: {stderr}") + + +# --------------------------------------------------------------------------- +# Template / environment errors +# --------------------------------------------------------------------------- + + +class RequiredEnvVarNotFoundError(LookupError): + def __init__(self, var_name: str) -> None: + super().__init__(f"Required environment variable {var_name} not found!") diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index 117f52a..55a3613 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -23,6 +23,7 @@ MCPNamespaceWrap, mcp_client_params, ) +from .exceptions import UnsupportedMCPTransportError if TYPE_CHECKING: from .available_tools import AvailableTools @@ -116,7 +117,7 @@ def _print_err(line: str) -> None: client_session_timeout_seconds=client_session_timeout, ) case _: - raise ValueError(f"Unsupported MCP transport: {params['kind']}") + raise UnsupportedMCPTransportError(params['kind']) entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb)) diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py index af7d03d..ffe3f1c 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py @@ -19,6 +19,16 @@ # this is a local fork of https://github.com/riga/jsonrpyc modified for our purposes from . import jsonrpyc +from .exceptions import ( + LegacyServerNotSupportedError, + NoActiveConnectionError, + NoActiveDatabaseError, + NonAbsoluteURIError, + NotFileURIError, + QueryRunError, + QuickEvalTargetNotFoundError, + UnsupportedOutputFormatError, +) WAIT_INTERVAL = 0.1 @@ -194,10 +204,10 @@ def _server_request_run( template_values: dict | None = None, ): if not self.active_database: - raise RuntimeError("No Active Database") + raise NoActiveDatabaseError() if not self.active_connection: - raise RuntimeError("No Active Connection") + raise NoActiveConnectionError() if isinstance(quick_eval_pos, dict): # A quick eval position contains: @@ -302,7 +312,7 @@ def _format(self, query): def _resolve_query_server(self): help_msg = shell_command_to_string(self.codeql_cli + ["excute", "--help"]) if not re.search("query-server2", help_msg): - raise RuntimeError("Legacy server not supported!") + raise LegacyServerNotSupportedError() return "query-server2" def _resolve_library_paths(self, query_path): @@ -463,11 +473,11 @@ def _file_uri_to_path(uri): # internally the codeql client will resolve both relative and full paths # regardless of root directory differences if not uri.startswith("file:///"): - raise ValueError("URI path should be formatted as absolute") + raise NonAbsoluteURIError() # note: don't try to parse paths like "file://a/b" because that returns "/b", should be "file:///a/b" parsed = urlparse(uri) if parsed.scheme != "file": - raise ValueError(f"Not a file:// uri: {uri}") + raise NotFileURIError(uri) path = unquote(parsed.path) region = None if ":" in path: @@ -605,7 +615,7 @@ def run_query( if target: target_pos = get_query_position(query_path, target) if not target_pos: - raise ValueError(f"Could not resolve quick eval target for {target}") + raise QuickEvalTargetNotFoundError(target) try: with ( QueryServer(database, keep_alive=keep_alive, log_stderr=log_stderr) as server, @@ -633,7 +643,7 @@ def run_query( case "sarif": result = server._bqrs_to_sarif(bqrs_path, server._query_info(query_path)) case _: - raise ValueError("Unsupported output format {fmt}") + raise UnsupportedOutputFormatError(fmt) except Exception as e: - raise RuntimeError(f"Error in run_query: {e}") from e + raise QueryRunError(e) from e return result diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/exceptions.py b/src/seclab_taskflow_agent/mcp_servers/codeql/exceptions.py new file mode 100644 index 0000000..c10c56d --- /dev/null +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/exceptions.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Custom exception classes for the CodeQL MCP server.""" + +from __future__ import annotations + +__all__ = [ + "DatabaseNotFoundError", + "LegacyServerNotSupportedError", + "NoActiveConnectionError", + "NoActiveDatabaseError", + "NonAbsoluteURIError", + "NotFileURIError", + "QueryRunError", + "QuickEvalTargetNotFoundError", + "UnsupportedLanguageError", + "UnsupportedOutputFormatError", + "UnsupportedQueryError", +] + + +class NoActiveDatabaseError(RuntimeError): + def __init__(self) -> None: + super().__init__("No Active Database") + + +class NoActiveConnectionError(RuntimeError): + def __init__(self) -> None: + super().__init__("No Active Connection") + + +class LegacyServerNotSupportedError(RuntimeError): + def __init__(self) -> None: + super().__init__("Legacy server not supported!") + + +class NonAbsoluteURIError(ValueError): + def __init__(self) -> None: + super().__init__("URI path should be formatted as absolute") + + +class NotFileURIError(ValueError): + def __init__(self, uri: str) -> None: + super().__init__(f"Not a file:// uri: {uri}") + + +class QuickEvalTargetNotFoundError(ValueError): + def __init__(self, target: str) -> None: + super().__init__(f"Could not resolve quick eval target for {target}") + + +class UnsupportedOutputFormatError(ValueError): + def __init__(self, fmt: str) -> None: + super().__init__(f"Unsupported output format {fmt}") + + +class QueryRunError(RuntimeError): + def __init__(self, error: Exception) -> None: + super().__init__(f"Error in run_query: {error}") + + +class UnsupportedLanguageError(RuntimeError): + def __init__(self, language: str) -> None: + super().__init__(f"Error: Language `{language}` not supported!") + + +class UnsupportedQueryError(RuntimeError): + def __init__(self, query: str, language: str) -> None: + super().__init__(f"Error: query `{query}` not supported for `{language}`!") + + +class DatabaseNotFoundError(RuntimeError): + def __init__(self, path: str) -> None: + super().__init__(f"Error: Database not found at {path}!") diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py index d245666..79e1aec 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py @@ -14,6 +14,7 @@ from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir from .client import _debug_log, file_from_uri, list_src_files, run_query, search_in_src_archive +from .exceptions import DatabaseNotFoundError, UnsupportedLanguageError, UnsupportedQueryError logging.basicConfig( level=logging.DEBUG, @@ -53,10 +54,10 @@ def _resolve_query_path(language: str, query: str) -> Path: global TEMPLATED_QUERY_PATHS if language not in TEMPLATED_QUERY_PATHS: - raise RuntimeError(f"Error: Language `{language}` not supported!") + raise UnsupportedLanguageError(language) query_path = TEMPLATED_QUERY_PATHS[language].get(query) if not query_path: - raise RuntimeError(f"Error: query `{query}` not supported for `{language}`!") + raise UnsupportedQueryError(query, language) return Path(query_path) @@ -69,7 +70,7 @@ def _resolve_db_path(relative_db_path: str | Path): absolute_path = CODEQL_DBS_BASE_PATH / relative_db_path if not absolute_path.is_dir(): _debug_log(f"Database path not found: {absolute_path}") - raise RuntimeError(f"Error: Database not found at {absolute_path}!") + raise DatabaseNotFoundError(absolute_path) return absolute_path diff --git a/src/seclab_taskflow_agent/mcp_transport.py b/src/seclab_taskflow_agent/mcp_transport.py index 8632fd8..f574ae4 100644 --- a/src/seclab_taskflow_agent/mcp_transport.py +++ b/src/seclab_taskflow_agent/mcp_transport.py @@ -35,6 +35,8 @@ from agents.mcp import MCPServerStdio +from .exceptions import MCPConnectionTimeoutError, MissingHostPortError, ProcessThreadTimeoutError + # Exit codes that are considered normal termination. _EXPECTED_EXIT_CODES: frozenset[int] = frozenset({0, -signal.SIGTERM}) @@ -109,7 +111,7 @@ async def async_wait_for_connection( host = parsed.hostname port = parsed.port if host is None or port is None: - raise ValueError(f"URL must include a host and port: {self.url}") + raise MissingHostPortError(self.url) deadline = asyncio.get_event_loop().time() + timeout while True: try: @@ -119,7 +121,7 @@ async def async_wait_for_connection( return except (OSError, ConnectionRefusedError): if asyncio.get_event_loop().time() > deadline: - raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds") + raise MCPConnectionTimeoutError(host, port, timeout) await asyncio.sleep(poll_interval) def wait_for_connection( @@ -139,7 +141,7 @@ def wait_for_connection( host = parsed.hostname port = parsed.port if host is None or port is None: - raise ValueError(f"URL must include a host and port: {self.url}") + raise MissingHostPortError(self.url) deadline = time.time() + timeout while True: try: @@ -147,7 +149,7 @@ def wait_for_connection( return except OSError: if time.time() > deadline: - raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds") + raise MCPConnectionTimeoutError(host, port, timeout) time.sleep(poll_interval) def run(self) -> None: @@ -216,7 +218,7 @@ def join_and_raise(self, timeout: float | None = None) -> None: """ self.join(timeout) if self.is_alive(): - raise RuntimeError("Process thread did not exit within timeout.") + raise ProcessThreadTimeoutError() if self.exception is not None: raise self.exception diff --git a/src/seclab_taskflow_agent/mcp_utils.py b/src/seclab_taskflow_agent/mcp_utils.py index a186bee..7f70721 100644 --- a/src/seclab_taskflow_agent/mcp_utils.py +++ b/src/seclab_taskflow_agent/mcp_utils.py @@ -27,6 +27,7 @@ from .available_tools import AvailableTools from .env_utils import swap_env +from .exceptions import ExecutableNotFoundError, UnsupportedMCPTransportError # Re-export transport classes and prompt builder so that existing # ``from .mcp_utils import …`` statements continue to work. @@ -202,7 +203,7 @@ def mcp_client_params( logging.debug(f"Initializing streamable toolbox: {tb}\nargs:\n{args}\nenv:\n{env}\n") exe = shutil.which(sp.command) if exe is None: - raise FileNotFoundError(f"Could not resolve path to {sp.command}") + raise ExecutableNotFoundError(sp.command) start_cmd = [exe] if args: for i, v in enumerate(args): @@ -220,7 +221,7 @@ def mcp_client_params( server_params["env"] = env case _: - raise ValueError(f"Unsupported MCP transport {kind}") + raise UnsupportedMCPTransportError(kind) client_params[tb] = ( server_params, diff --git a/src/seclab_taskflow_agent/models.py b/src/seclab_taskflow_agent/models.py index 6445b64..d8f4794 100644 --- a/src/seclab_taskflow_agent/models.py +++ b/src/seclab_taskflow_agent/models.py @@ -29,6 +29,8 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from .exceptions import MutuallyExclusiveTaskFieldsError, UnsupportedVersionError + # Valid API type values for model configuration. ApiType = Literal["chat_completions", "responses"] @@ -62,9 +64,7 @@ def _normalise_version(cls, v: Any) -> str: @classmethod def _validate_version(cls, v: str) -> str: if v != SUPPORTED_VERSION: - raise ValueError( - f"Unsupported version: {v}. Only version {SUPPORTED_VERSION} is supported." - ) + raise UnsupportedVersionError(v, SUPPORTED_VERSION) return v @@ -106,7 +106,7 @@ class TaskDefinition(BaseModel): @model_validator(mode="after") def _run_xor_prompt(self) -> TaskDefinition: if self.run and self.user_prompt: - raise ValueError("shell task ('run') and prompt task ('user_prompt') are mutually exclusive") + raise MutuallyExclusiveTaskFieldsError() return self diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 5869385..ea845f6 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -37,6 +37,20 @@ from .agent import DEFAULT_MODEL, TaskAgent, TaskAgentHooks, TaskRunHooks from .available_tools import AvailableTools from .env_utils import TmpEnv +from .exceptions import ( + MaxRateLimitReachedError, + NoAgentsResolvedError, + PersonalityNotFoundError, + PromptTemplateRenderingError, + ResultTextNotJSONError, + ReusableTaskflowNotFoundError, + ReusableTaskflowTooManyTasksError, + ShellAndPromptMutuallyExclusiveError, + TaskModelSettingsTypeError, + TemplateRenderingError, + ToolResultNotJSONError, + UnknownModelSettingsError, +) from .mcp_lifecycle import MCP_CLEANUP_TIMEOUT, build_mcp_servers, mcp_session_task from .models import ModelConfigDocument, PersonalityDocument, TaskDefinition from .mcp_prompt import mcp_system_prompt @@ -78,9 +92,7 @@ def _resolve_model_config( models_params: dict[str, dict[str, Any]] = m_config.model_settings or {} unknown = set(models_params) - set(model_keys) if unknown: - raise ValueError( - f"Settings section of model_config file {model_config_ref} contains models not in the model section: {unknown}" - ) + raise UnknownModelSettingsError(model_config_ref, unknown) return model_keys, model_dict, models_params, m_config.api_type @@ -103,9 +115,9 @@ def _merge_reusable_task( """ reusable_doc = available_tools.get_taskflow(task.uses) if reusable_doc is None: - raise ValueError(f"No such reusable taskflow: {task.uses}") + raise ReusableTaskflowNotFoundError(task.uses) if len(reusable_doc.taskflow) > 1: - raise ValueError("Reusable taskflows can only contain 1 task") + raise ReusableTaskflowTooManyTasksError() parent_task = reusable_doc.taskflow[0].task merged: dict[str, Any] = parent_task.model_dump(by_alias=True, exclude_defaults=True) current: dict[str, Any] = task.model_dump(by_alias=True, exclude_defaults=True) @@ -147,7 +159,7 @@ def _resolve_task_model( task_model_settings: dict[str, Any] | Any = task.model_settings or {} if not isinstance(task_model_settings, dict): - raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary") + raise TaskModelSettingsTypeError(task.name or '') # Task-level overrides can also set engine keys task_settings = dict(task_model_settings) @@ -198,14 +210,14 @@ async def _build_prompts_to_run( raise except json.JSONDecodeError as exc: logging.critical(f"Could not parse tool result as JSON: {last_mcp_tool_results[-1][:200]}") - raise ValueError("Tool result is not valid JSON") from exc + raise ToolResultNotJSONError() from exc text = last_result.get("text", "") try: iterable_result = json.loads(text) except json.JSONDecodeError as exc: logging.critical(f"Could not parse result text: {text}") - raise ValueError("Result text is not valid JSON") from exc + raise ResultTextNotJSONError() from exc try: iter(iterable_result) except TypeError: @@ -228,7 +240,7 @@ async def _build_prompts_to_run( prompts_to_run.append(rendered_prompt) except jinja2.TemplateError as e: logging.error(f"Error rendering template for result {value}: {e}") - raise ValueError(f"Template rendering failed: {e}") + raise TemplateRenderingError(e) # Consume only after all prompts rendered successfully so that # the result remains available for retry/resume on failure. @@ -403,7 +415,7 @@ async def _run_streamed() -> None: max_retry -= 1 except RateLimitError: if rate_limit_backoff == MAX_RATE_LIMIT_BACKOFF: - raise APITimeoutError("Max rate limit backoff reached") + raise MaxRateLimitReachedError() if rate_limit_backoff > MAX_RATE_LIMIT_BACKOFF: rate_limit_backoff = MAX_RATE_LIMIT_BACKOFF else: @@ -556,7 +568,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo inputs = task.inputs or {} task_prompt = task.user_prompt or "" if run and task_prompt: - raise ValueError("shell task and prompt task are mutually exclusive!") + raise ShellAndPromptMutuallyExclusiveError() must_complete = task.must_complete max_turns = task.max_steps or DEFAULT_MAX_TURNS toolboxes_override = task.toolboxes or [] @@ -577,7 +589,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo ) except jinja2.TemplateError as e: logging.error(f"Template rendering error: {e}") - raise ValueError(f"Failed to render prompt template: {e}") from e + raise PromptTemplateRenderingError(e) from e with TmpEnv(env): prompts_to_run: list[str] = await _build_prompts_to_run( @@ -611,14 +623,11 @@ async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) - for agent_name in current_agents: personality = available_tools.get_personality(agent_name) if personality is None: - raise ValueError(f"No such personality: {agent_name}") + raise PersonalityNotFoundError(agent_name) resolved_agents[agent_name] = personality if not resolved_agents: - raise ValueError( - "No agents resolved for this task. " - "Specify a personality with -p or provide an agents list." - ) + raise NoAgentsResolvedError() async def _deploy(ra: dict, pp: str) -> bool: async with semaphore: diff --git a/src/seclab_taskflow_agent/session.py b/src/seclab_taskflow_agent/session.py index 9b77151..180e80d 100644 --- a/src/seclab_taskflow_agent/session.py +++ b/src/seclab_taskflow_agent/session.py @@ -23,6 +23,7 @@ from pydantic import BaseModel, Field +from .exceptions import SessionNotFoundError from .path_utils import _data_dir @@ -121,7 +122,7 @@ def load(cls, session_id: str) -> TaskflowSession: """ path = session_dir() / f"{session_id}.json" if not path.exists(): - raise FileNotFoundError(f"No session checkpoint found: {session_id}") + raise SessionNotFoundError(session_id) return cls.model_validate_json(path.read_text()) @classmethod diff --git a/src/seclab_taskflow_agent/shell_utils.py b/src/seclab_taskflow_agent/shell_utils.py index 75175ec..5d7a26a 100644 --- a/src/seclab_taskflow_agent/shell_utils.py +++ b/src/seclab_taskflow_agent/shell_utils.py @@ -9,6 +9,8 @@ from mcp.types import CallToolResult, TextContent +from .exceptions import ShellCommandError + __all__ = ["shell_command_to_string", "shell_exec_with_temporary_file", "shell_tool_call"] @@ -23,7 +25,7 @@ def shell_command_to_string(cmd: list[str]) -> str: stdout, stderr = p.communicate() p.wait() if p.returncode: - raise RuntimeError(f"Command {cmd} failed: {stderr}") + raise ShellCommandError(cmd, stderr) return stdout diff --git a/src/seclab_taskflow_agent/template_utils.py b/src/seclab_taskflow_agent/template_utils.py index 2f21d4a..818b117 100644 --- a/src/seclab_taskflow_agent/template_utils.py +++ b/src/seclab_taskflow_agent/template_utils.py @@ -14,6 +14,7 @@ from .available_tools import AvailableTools from .available_tools import BadToolNameError +from .exceptions import RequiredEnvVarNotFoundError class PromptLoader(jinja2.BaseLoader): @@ -77,7 +78,7 @@ def env_function(var_name: str, default: Optional[str] = None, required: bool = """ value = os.getenv(var_name, default) if value is None and required: - raise LookupError(f"Required environment variable {var_name} not found!") + raise RequiredEnvVarNotFoundError(var_name) return value or ""