diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py
index a7966f2..a26222a 100644
--- a/src/seclab_taskflow_agent/agent.py
+++ b/src/seclab_taskflow_agent/agent.py
@@ -6,7 +6,6 @@
 import os
 from collections.abc import Callable
 from typing import Any
-from urllib.parse import urlparse
 
 from agents import (
     Agent,
@@ -26,7 +25,7 @@
 from dotenv import find_dotenv, load_dotenv
 from openai import AsyncOpenAI
 
-from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token
+from .capi import get_AI_endpoint, get_AI_token, get_provider
 
 __all__ = [
     "DEFAULT_MODEL",
@@ -39,17 +38,8 @@
 load_dotenv(find_dotenv(usecwd=True))
 
 api_endpoint = get_AI_endpoint()
-match urlparse(api_endpoint).netloc:
-    case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
-        default_model = "gpt-4.1"
-    case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
-        default_model = "openai/gpt-4.1"
-    case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
-        default_model = "gpt-4.1"
-    case _:
-        default_model = "please-set-default-model-via-env"
-
-DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=default_model)
+_default_provider = get_provider(api_endpoint)
+DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=_default_provider.default_model)
 
 
 class TaskRunHooks(RunHooks):
@@ -186,10 +176,12 @@ def __init__(
         else:
             resolved_token = get_AI_token()
 
+        # Only send provider-specific headers to matching endpoints
+        provider = get_provider(resolved_endpoint)
         client = AsyncOpenAI(
             base_url=resolved_endpoint,
             api_key=resolved_token,
-            default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
+            default_headers=provider.extra_headers or None,
         )
         set_tracing_disabled(True)
         self.run_hooks = run_hooks or TaskRunHooks()
diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py
index 4ed8dcd..8bb43cc 100644
--- a/src/seclab_taskflow_agent/capi.py
+++ b/src/seclab_taskflow_agent/capi.py
@@ -1,53 +1,159 @@
 # SPDX-FileCopyrightText: GitHub, Inc.
 # SPDX-License-Identifier: MIT
 
-"""AI API endpoint and token management (CAPI integration)."""
+"""AI API endpoint and token management.
+
+Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
+custom endpoints). All provider-specific behaviour is captured in a single
+``APIProvider`` dataclass so that adding a new provider only requires one
+registry entry instead of changes scattered across multiple match/case blocks.
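+
+Example (illustrative; the values match the provider registry below)::
+
+    >>> get_provider("https://api.githubcopilot.com").name
+    'copilot'
+    >>> get_provider("https://models.github.ai/inference").default_model
+    'openai/gpt-4.1'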
+""" + +from __future__ import annotations import json import logging import os +from collections.abc import Mapping +from dataclasses import dataclass, field +from types import MappingProxyType +from typing import Any from urllib.parse import urlparse import httpx -from strenum import StrEnum __all__ = [ - "AI_API_ENDPOINT_ENUM", "COPILOT_INTEGRATION_ID", + "APIProvider", "get_AI_endpoint", "get_AI_token", + "get_provider", "list_capi_models", "list_tool_call_models", "supports_tool_calls", ] +COPILOT_INTEGRATION_ID = os.getenv("COPILOT_INTEGRATION_ID", "vscode-chat") + + +# --------------------------------------------------------------------------- +# Provider abstraction +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class APIProvider: + """Encapsulates all endpoint-specific behaviour in one place.""" + + name: str + base_url: str + models_catalog: str = "/models" + default_model: str = "gpt-4.1" + extra_headers: Mapping[str, str] = field(default_factory=dict) + + def __post_init__(self) -> None: + # Ensure base_url ends with / so httpx URL.join() preserves the path + if self.base_url and not self.base_url.endswith("/"): + object.__setattr__(self, "base_url", self.base_url + "/") + # Freeze mutable headers so singleton providers can't be mutated + if isinstance(self.extra_headers, dict): + object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers)) + + # -- response parsing ----------------------------------------------------- -# Enumeration of currently supported API endpoints. -class AI_API_ENDPOINT_ENUM(StrEnum): - AI_API_MODELS_GITHUB = "models.github.ai" - AI_API_GITHUBCOPILOT = "api.githubcopilot.com" - AI_API_OPENAI = "api.openai.com" + def parse_models_list(self, body: Any) -> list[dict]: + """Extract the models list from a catalog response body.""" + if isinstance(body, list): + return body + if isinstance(body, dict): + data = body.get("data", []) + return data if isinstance(data, list) else [] + return [] - def to_url(self) -> str: - """Convert the endpoint to its full URL.""" - match self: - case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT: - return f"https://{self}" - case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: - return f"https://{self}/inference" - case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: - return f"https://{self}/v1" - case _: - raise ValueError(f"Unsupported endpoint: {self}") + # -- tool-call capability check ------------------------------------------- + def check_tool_calls(self, _model: str, model_info: dict) -> bool: + """Return True if *model* supports tool calls according to its catalog entry.""" + # Default: optimistically assume support when present in catalog + return bool(model_info) -COPILOT_INTEGRATION_ID = "vscode-chat" +class _CopilotProvider(APIProvider): + """GitHub Copilot API (api.githubcopilot.com).""" + + def check_tool_calls(self, _model: str, model_info: dict) -> bool: + return ( + model_info + .get("capabilities", {}) + .get("supports", {}) + .get("tool_calls", False) + ) + + +class _GitHubModelsProvider(APIProvider): + """GitHub Models API (models.github.ai).""" + + def parse_models_list(self, body: Any) -> list[dict]: + # Models API returns a bare list, not {"data": [...]} + if isinstance(body, list): + return body + return super().parse_models_list(body) + + def check_tool_calls(self, _model: str, model_info: dict) -> bool: + return "tool-calling" in model_info.get("capabilities", []) + + +class _OpenAIProvider(APIProvider): + """OpenAI API (api.openai.com). 
+
+    The OpenAI /v1/models catalog does not expose capability metadata, so
+    we maintain a prefix allowlist of known chat-completion model families.
+    """
+
+    _CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")
+
+    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
+        model_id = model_info.get("id", "").lower()
+        return any(model_id.startswith(p) for p in self._CHAT_PREFIXES)
+
+
+# ---------------------------------------------------------------------------
+# Provider registry — add new providers here
+# ---------------------------------------------------------------------------
+
+_PROVIDERS: dict[str, APIProvider] = {
+    "api.githubcopilot.com": _CopilotProvider(
+        name="copilot",
+        base_url="https://api.githubcopilot.com",
+        default_model="gpt-4.1",
+        extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
+    ),
+    "models.github.ai": _GitHubModelsProvider(
+        name="github-models",
+        base_url="https://models.github.ai/inference",
+        models_catalog="/catalog/models",
+        default_model="openai/gpt-4.1",
+    ),
+    "api.openai.com": _OpenAIProvider(
+        name="openai",
+        base_url="https://api.openai.com/v1",
+        models_catalog="/v1/models",
+        default_model="gpt-4.1",
+    ),
+}
+
+
+def get_provider(endpoint: str | None = None) -> APIProvider:
+    """Return the ``APIProvider`` for the given (or configured) endpoint URL."""
+    url = endpoint or get_AI_endpoint()
+    netloc = urlparse(url).netloc
+    provider = _PROVIDERS.get(netloc)
+    if provider is not None:
+        return provider
+    # Unknown endpoint — return a generic provider with the given base URL
+    return APIProvider(name="custom", base_url=url, default_model="please-set-default-model-via-env")
+
+
+# ---------------------------------------------------------------------------
+# Endpoint / token helpers
+# ---------------------------------------------------------------------------
 
-# you can also set https://api.githubcopilot.com if you prefer
-# but beware that your taskflows need to reference the correct model id
-# since different APIs use their own id schema, use -l with your desired
-# endpoint to retrieve the correct id names to use for your taskflow
 def get_AI_endpoint() -> str:
     """Return the configured AI API endpoint URL."""
     return os.getenv("AI_API_ENDPOINT", default="https://models.github.ai/inference")
@@ -64,82 +170,54 @@ def get_AI_token() -> str:
         raise RuntimeError("AI_API_TOKEN environment variable is not set.")
 
 
-# assume we are >= python 3.9 for our type hints
-def list_capi_models(token: str) -> dict[str, dict]:
-    """Retrieve a dictionary of available CAPI models"""
-    models = {}
+# ---------------------------------------------------------------------------
+# Model catalog
+# ---------------------------------------------------------------------------
+
+def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
+    """Retrieve available models from the configured API endpoint.
+
+    Args:
+        token: Bearer token for authentication.
+        endpoint: Optional endpoint URL override (defaults to env config).
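+
+    Returns:
+        Mapping of model id to its raw catalog entry; empty when the request fails.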
+    """
+    provider = get_provider(endpoint)
+    base = provider.base_url
+    models: dict[str, dict] = {}
     try:
-        api_endpoint = get_AI_endpoint()
-        netloc = urlparse(api_endpoint).netloc
-        match netloc:
-            case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
-                models_catalog = "models"
-            case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
-                models_catalog = "catalog/models"
-            case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
-                models_catalog = "models"
-            case _:
-                # Unknown endpoint — try the OpenAI-style models catalog
-                models_catalog = "models"
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {token}",
+            **provider.extra_headers,
+        }
         r = httpx.get(
-            httpx.URL(api_endpoint).join(models_catalog),
-            headers={
-                "Accept": "application/json",
-                "Authorization": f"Bearer {token}",
-                "Copilot-Integration-Id": COPILOT_INTEGRATION_ID,
-            },
+            httpx.URL(base).join(provider.models_catalog),
+            headers=headers,
         )
         r.raise_for_status()
-        # CAPI vs Models API
-        match netloc:
-            case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
-                models_list = r.json().get("data", [])
-            case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
-                models_list = r.json()
-            case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
-                models_list = r.json().get("data", [])
-            case _:
-                # Unknown endpoint — try common response shapes
-                body = r.json()
-                if isinstance(body, dict):
-                    models_list = body.get("data", [])
-                elif isinstance(body, list):
-                    models_list = body
-                else:
-                    models_list = []
-        for model in models_list:
+        for model in provider.parse_models_list(r.json()):
             models[model.get("id")] = dict(model)
-    except httpx.RequestError:
-        logging.exception("Request error")
-    except json.JSONDecodeError:
-        logging.exception("JSON error")
-    except httpx.HTTPStatusError:
-        logging.exception("HTTP error")
+    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
+        logging.exception("Failed to list models from %s", base)
     return models
 
 
-def supports_tool_calls(model: str, models: dict[str, dict]) -> bool:
-    """Check whether the given model supports tool calls."""
-    api_endpoint = get_AI_endpoint()
-    netloc = urlparse(api_endpoint).netloc
-    match netloc:
-        case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
-            return models.get(model, {}).get("capabilities", {}).get("supports", {}).get("tool_calls", False)
-        case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
-            return "tool-calling" in models.get(model, {}).get("capabilities", [])
-        case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
-            return "gpt-" in model.lower()
-        case _:
-            # Unknown endpoint — optimistically assume tool-call support
-            # if the model is present in the catalog.
-            return model in models
-
-
-def list_tool_call_models(token: str) -> dict[str, dict]:
+def supports_tool_calls(
+    model: str,
+    models: dict[str, dict],
+    endpoint: str | None = None,
+) -> bool:
+    """Check whether *model* supports tool calls."""
+    provider = get_provider(endpoint)
+    return provider.check_tool_calls(model, models.get(model, {}))
+
+
+def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
     """Return only models that support tool calls."""
-    models = list_capi_models(token)
-    tool_models: dict[str, dict] = {}
-    for model in models:
-        if supports_tool_calls(model, models) is True:
-            tool_models[model] = models[model]
-    return tool_models
+    models = list_capi_models(token, endpoint)
+    provider = get_provider(endpoint)
+    return {
+        mid: info
+        for mid, info in models.items()
+        if provider.check_tool_calls(mid, info)
+    }
diff --git a/tests/test_api_endpoint_config.py b/tests/test_api_endpoint_config.py
index 3912496..6abdfaf 100644
--- a/tests/test_api_endpoint_config.py
+++ b/tests/test_api_endpoint_config.py
@@ -10,7 +10,7 @@
 
 import pytest
 
-from seclab_taskflow_agent.capi import AI_API_ENDPOINT_ENUM, get_AI_endpoint, list_capi_models
+from seclab_taskflow_agent.capi import get_AI_endpoint, get_provider, list_capi_models
 
 
 class TestAPIEndpoint:
@@ -18,55 +18,37 @@
     def test_default_api_endpoint(self):
         """Test that default API endpoint is set to models.github.ai/inference."""
-        # When no env var is set, it should default to models.github.ai/inference
         try:
-            # Save original env
             original_env = os.environ.pop("AI_API_ENDPOINT", None)
             endpoint = get_AI_endpoint()
             assert endpoint is not None
             assert isinstance(endpoint, str)
-            assert urlparse(endpoint).netloc == AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
+            assert urlparse(endpoint).netloc == "models.github.ai"
         finally:
-            # Restore original env
             if original_env:
                 os.environ["AI_API_ENDPOINT"] = original_env
 
     def test_api_endpoint_env_override(self):
         """Test that AI_API_ENDPOINT can be overridden by environment variable."""
         try:
-            # Save original env
             original_env = os.environ.pop("AI_API_ENDPOINT", None)
-            # Set different endpoint
             test_endpoint = "https://api.githubcopilot.com"
             os.environ["AI_API_ENDPOINT"] = test_endpoint
-
             assert get_AI_endpoint() == test_endpoint
         finally:
-            # Restore original env
             if original_env:
                 os.environ["AI_API_ENDPOINT"] = original_env
 
-    def test_to_url_models_github(self):
-        """Test to_url method for models.github.ai endpoint."""
-        endpoint = AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
-        assert endpoint.to_url() == "https://models.github.ai/inference"
-
-    def test_to_url_githubcopilot(self):
-        """Test to_url method for GitHub Copilot endpoint."""
-        endpoint = AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT
-        assert endpoint.to_url() == "https://api.githubcopilot.com"
-
-    def test_to_url_openai(self):
-        """Test to_url method for OpenAI endpoint."""
-        endpoint = AI_API_ENDPOINT_ENUM.AI_API_OPENAI
-        assert endpoint.to_url() == "https://api.openai.com/v1"
+    def test_provider_base_urls(self):
+        """Test that providers resolve to expected base URLs."""
+        assert get_provider("https://models.github.ai/inference").base_url == "https://models.github.ai/inference/"
+        assert get_provider("https://api.githubcopilot.com").base_url == "https://api.githubcopilot.com/"
+        assert get_provider("https://api.openai.com/v1").base_url == "https://api.openai.com/v1/"
 
     def test_unsupported_endpoint(self, monkeypatch):
         """Test that unsupported API endpoint falls back gracefully."""
         api_endpoint = "https://unsupported.example.com"
         monkeypatch.setenv("AI_API_ENDPOINT", api_endpoint)
-        # Unknown endpoints should not raise; they try OpenAI-style catalog
-        # and return an empty dict on connection failure.
         result = list_capi_models("abc")
         assert isinstance(result, dict)
         assert result == {}
diff --git a/tests/test_capi_extended.py b/tests/test_capi_extended.py
index 297f6ae..28a3d06 100644
--- a/tests/test_capi_extended.py
+++ b/tests/test_capi_extended.py
@@ -5,11 +5,11 @@
 
 from __future__ import annotations
 
-from seclab_taskflow_agent.capi import AI_API_ENDPOINT_ENUM, supports_tool_calls
+from seclab_taskflow_agent.capi import get_provider, supports_tool_calls
 
 
 class TestSupportsToolCalls:
-    """Tests for supports_tool_calls with unknown endpoints."""
+    """Tests for supports_tool_calls with various endpoints."""
 
     def test_unknown_endpoint_known_model(self, monkeypatch):
         """Unknown endpoint returns True when model is in the catalog."""
@@ -67,31 +67,59 @@ def test_models_github_endpoint_no_tool_calling(self, monkeypatch):
         }
         assert supports_tool_calls("some-model", models) is False
 
-    def test_openai_endpoint_gpt_model(self, monkeypatch):
-        """OpenAI endpoint returns True for models containing 'gpt-'."""
+    def test_openai_endpoint_model_in_catalog(self, monkeypatch):
+        """OpenAI endpoint returns True for known chat model families."""
         monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1")
-        assert supports_tool_calls("gpt-4o", {}) is True
+        models = {"gpt-4o": {"id": "gpt-4o"}}
+        assert supports_tool_calls("gpt-4o", models) is True
 
-    def test_openai_endpoint_non_gpt_model(self, monkeypatch):
-        """OpenAI endpoint returns False for non-GPT models."""
+    def test_openai_endpoint_o_series(self, monkeypatch):
+        """OpenAI endpoint returns True for o-series reasoning models."""
         monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1")
-        assert supports_tool_calls("claude-3-opus", {}) is False
-
-
-class TestAIAPIEndpointEnum:
-    """Tests for the AI_API_ENDPOINT_ENUM StrEnum."""
-
-    def test_enum_values(self):
-        """All expected endpoint values exist."""
-        assert AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB == "models.github.ai"
-        assert AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT == "api.githubcopilot.com"
-        assert AI_API_ENDPOINT_ENUM.AI_API_OPENAI == "api.openai.com"
+        for mid in ("o1-preview", "o3-mini", "o4-mini"):
+            models = {mid: {"id": mid}}
+            assert supports_tool_calls(mid, models) is True
 
-    def test_to_url_models_github(self):
-        assert AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB.to_url() == "https://models.github.ai/inference"
-
-    def test_to_url_copilot(self):
-        assert AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT.to_url() == "https://api.githubcopilot.com"
+    def test_openai_endpoint_non_chat_model(self, monkeypatch):
+        """OpenAI endpoint returns False for embeddings/audio/image models."""
+        monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1")
+        for mid in ("text-embedding-ada-002", "whisper-1", "dall-e-3", "tts-1"):
+            models = {mid: {"id": mid}}
+            assert supports_tool_calls(mid, models) is False
 
-    def test_to_url_openai(self):
-        assert AI_API_ENDPOINT_ENUM.AI_API_OPENAI.to_url() == "https://api.openai.com/v1"
+    def test_openai_endpoint_model_not_in_catalog(self, monkeypatch):
+        """OpenAI endpoint returns False when model is not in catalog."""
+        monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1")
+        assert supports_tool_calls("missing-model", {}) is False
+
+    def test_explicit_endpoint_override(self):
+        """supports_tool_calls accepts an explicit endpoint parameter."""
+        models = {"my-model": {"id": "my-model", "capabilities": {"supports": {"tool_calls": True}}}}
+        assert supports_tool_calls("my-model", models, endpoint="https://api.githubcopilot.com") is True
+
+
+class TestGetProvider:
+    """Tests for the provider registry."""
+
+    def test_copilot_provider(self):
+        p = get_provider("https://api.githubcopilot.com")
+        assert p.name == "copilot"
+        assert p.base_url == "https://api.githubcopilot.com/"
+        assert "Copilot-Integration-Id" in p.extra_headers
+
+    def test_github_models_provider(self):
+        p = get_provider("https://models.github.ai/inference")
+        assert p.name == "github-models"
+        assert p.models_catalog == "/catalog/models"
+        assert p.default_model == "openai/gpt-4.1"
+
+    def test_openai_provider(self):
+        p = get_provider("https://api.openai.com/v1")
+        assert p.name == "openai"
+        assert not p.extra_headers
+
+    def test_custom_endpoint(self):
+        p = get_provider("https://my-custom-llm.example.com/v1")
+        assert p.name == "custom"
+        assert p.base_url == "https://my-custom-llm.example.com/v1/"
+        assert not p.extra_headers