Skip to content

Commit 01041fd

Browse files
committed
Refactor capi.py: replace scattered match/case with provider pattern
Consolidate endpoint-specific logic (catalog path, response parsing, tool-call detection, headers) into an APIProvider dataclass with a hostname-keyed registry. Adding a new endpoint is now a single registry entry instead of changes across three match/case blocks.

- Remove AI_API_ENDPOINT_ENUM StrEnum and strenum dependency
- Only send Copilot-Integration-Id header to Copilot endpoints
- Accept optional endpoint parameter in public functions
- Drop fragile "gpt-" substring heuristic for OpenAI tool-call check
- Update tests to use new provider API
1 parent 474b44f commit 01041fd

4 files changed

Lines changed: 231 additions & 158 deletions

File tree

src/seclab_taskflow_agent/agent.py

Lines changed: 6 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
import os
77
from collections.abc import Callable
88
from typing import Any
9-
from urllib.parse import urlparse
109

1110
from agents import (
1211
Agent,
@@ -26,7 +25,7 @@
2625
from dotenv import find_dotenv, load_dotenv
2726
from openai import AsyncOpenAI
2827

29-
from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token
28+
from .capi import get_AI_endpoint, get_AI_token, get_provider
3029

3130
__all__ = [
3231
"DEFAULT_MODEL",
@@ -39,17 +38,8 @@
3938
load_dotenv(find_dotenv(usecwd=True))
4039

4140
api_endpoint = get_AI_endpoint()
42-
match urlparse(api_endpoint).netloc:
43-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
44-
default_model = "gpt-4.1"
45-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
46-
default_model = "openai/gpt-4.1"
47-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
48-
default_model = "gpt-4.1"
49-
case _:
50-
default_model = "please-set-default-model-via-env"
51-
52-
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=default_model)
41+
_default_provider = get_provider(api_endpoint)
42+
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=_default_provider.default_model)
5343

5444

5545
class TaskRunHooks(RunHooks):
@@ -186,10 +176,12 @@ def __init__(
186176
else:
187177
resolved_token = get_AI_token()
188178

179+
# Only send provider-specific headers to matching endpoints
180+
provider = get_provider(resolved_endpoint)
189181
client = AsyncOpenAI(
190182
base_url=resolved_endpoint,
191183
api_key=resolved_token,
192-
default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
184+
default_headers=provider.extra_headers or None,
193185
)
194186
set_tracing_disabled(True)
195187
self.run_hooks = run_hooks or TaskRunHooks()

src/seclab_taskflow_agent/capi.py

Lines changed: 165 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -1,53 +1,152 @@
11
# SPDX-FileCopyrightText: GitHub, Inc.
22
# SPDX-License-Identifier: MIT
33

4-
"""AI API endpoint and token management (CAPI integration)."""
4+
"""AI API endpoint and token management.
5+
6+
Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+
custom endpoints). All provider-specific behaviour is captured in a single
8+
``APIProvider`` dataclass so that adding a new provider only requires one
9+
registry entry instead of changes scattered across multiple match/case blocks.
10+
"""
11+
12+
from __future__ import annotations
513

614
import json
715
import logging
816
import os
17+
from dataclasses import dataclass, field
18+
from types import MappingProxyType
19+
from typing import Any
920
from urllib.parse import urlparse
1021

1122
import httpx
12-
from strenum import StrEnum
1323

1424
__all__ = [
15-
"AI_API_ENDPOINT_ENUM",
1625
"COPILOT_INTEGRATION_ID",
26+
"APIProvider",
1727
"get_AI_endpoint",
1828
"get_AI_token",
29+
"get_provider",
1930
"list_capi_models",
2031
"list_tool_call_models",
2132
"supports_tool_calls",
2233
]
2334

35+
COPILOT_INTEGRATION_ID = "vscode-chat"
2436

25-
# Enumeration of currently supported API endpoints.
26-
class AI_API_ENDPOINT_ENUM(StrEnum):
27-
AI_API_MODELS_GITHUB = "models.github.ai"
28-
AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29-
AI_API_OPENAI = "api.openai.com"
3037

31-
def to_url(self) -> str:
32-
"""Convert the endpoint to its full URL."""
33-
match self:
34-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
35-
return f"https://{self}"
36-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
37-
return f"https://{self}/inference"
38-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
39-
return f"https://{self}/v1"
40-
case _:
41-
raise ValueError(f"Unsupported endpoint: {self}")
38+
# ---------------------------------------------------------------------------
39+
# Provider abstraction
40+
# ---------------------------------------------------------------------------
4241

42+
@dataclass(frozen=True)
43+
class APIProvider:
44+
"""Encapsulates all endpoint-specific behaviour in one place."""
4345

44-
COPILOT_INTEGRATION_ID = "vscode-chat"
46+
name: str
47+
base_url: str
48+
models_catalog: str = "models"
49+
default_model: str = "gpt-4.1"
50+
extra_headers: dict[str, str] = field(default_factory=dict)
51+
52+
def __post_init__(self) -> None:
53+
# Freeze mutable headers so singleton providers can't be mutated
54+
object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers))
55+
56+
# -- response parsing -----------------------------------------------------
57+
58+
def parse_models_list(self, body: Any) -> list[dict]:
59+
"""Extract the models list from a catalog response body."""
60+
if isinstance(body, list):
61+
return body
62+
if isinstance(body, dict):
63+
return body.get("data", [])
64+
return []
65+
66+
# -- tool-call capability check -------------------------------------------
67+
68+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
69+
"""Return True if *model* supports tool calls according to its catalog entry."""
70+
# Default: optimistically assume support when present in catalog
71+
return bool(model_info)
4572

4673

47-
# you can also set https://api.githubcopilot.com if you prefer
48-
# but beware that your taskflows need to reference the correct model id
49-
# since different APIs use their own id schema, use -l with your desired
50-
# endpoint to retrieve the correct id names to use for your taskflow
74+
class _CopilotProvider(APIProvider):
75+
"""GitHub Copilot API (api.githubcopilot.com)."""
76+
77+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
78+
return (
79+
model_info
80+
.get("capabilities", {})
81+
.get("supports", {})
82+
.get("tool_calls", False)
83+
)
84+
85+
86+
class _GitHubModelsProvider(APIProvider):
87+
"""GitHub Models API (models.github.ai)."""
88+
89+
def parse_models_list(self, body: Any) -> list[dict]:
90+
# Models API returns a bare list, not {"data": [...]}
91+
if isinstance(body, list):
92+
return body
93+
return super().parse_models_list(body)
94+
95+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
96+
return "tool-calling" in model_info.get("capabilities", [])
97+
98+
99+
class _OpenAIProvider(APIProvider):
100+
"""OpenAI API (api.openai.com).
101+
102+
The OpenAI /v1/models catalog does not expose capability metadata, so
103+
we maintain a prefix allowlist of known chat-completion model families.
104+
"""
105+
106+
_CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")
107+
108+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
109+
model_id = model_info.get("id", "").lower()
110+
return any(model_id.startswith(p) for p in self._CHAT_PREFIXES)
111+
# ---------------------------------------------------------------------------
112+
# Provider registry — add new providers here
113+
# ---------------------------------------------------------------------------
114+
115+
_PROVIDERS: dict[str, APIProvider] = {
116+
"api.githubcopilot.com": _CopilotProvider(
117+
name="copilot",
118+
base_url="https://api.githubcopilot.com",
119+
default_model="gpt-4.1",
120+
extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
121+
),
122+
"models.github.ai": _GitHubModelsProvider(
123+
name="github-models",
124+
base_url="https://models.github.ai/inference",
125+
models_catalog="catalog/models",
126+
default_model="openai/gpt-4.1",
127+
),
128+
"api.openai.com": _OpenAIProvider(
129+
name="openai",
130+
base_url="https://api.openai.com/v1",
131+
default_model="gpt-4.1",
132+
),
133+
}
134+
135+
def get_provider(endpoint: str | None = None) -> APIProvider:
136+
"""Return the ``APIProvider`` for the given (or configured) endpoint URL."""
137+
url = endpoint or get_AI_endpoint()
138+
netloc = urlparse(url).netloc
139+
provider = _PROVIDERS.get(netloc)
140+
if provider is not None:
141+
return provider
142+
# Unknown endpoint — return a generic provider with the given base URL
143+
return APIProvider(name="custom", base_url=url)
144+
145+
146+
# ---------------------------------------------------------------------------
147+
# Endpoint / token helpers
148+
# ---------------------------------------------------------------------------
149+
51150
def get_AI_endpoint() -> str:
52151
"""Return the configured AI API endpoint URL."""
53152
return os.getenv("AI_API_ENDPOINT", default="https://models.github.ai/inference")
@@ -64,82 +163,54 @@ def get_AI_token() -> str:
64163
raise RuntimeError("AI_API_TOKEN environment variable is not set.")
65164

66165

67-
# assume we are >= python 3.9 for our type hints
68-
def list_capi_models(token: str) -> dict[str, dict]:
69-
"""Retrieve a dictionary of available CAPI models"""
70-
models = {}
166+
# ---------------------------------------------------------------------------
167+
# Model catalog
168+
# ---------------------------------------------------------------------------
169+
170+
def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
171+
"""Retrieve available models from the configured API endpoint.
172+
173+
Args:
174+
token: Bearer token for authentication.
175+
endpoint: Optional endpoint URL override (defaults to env config).
176+
"""
177+
provider = get_provider(endpoint)
178+
base = provider.base_url
179+
models: dict[str, dict] = {}
71180
try:
72-
api_endpoint = get_AI_endpoint()
73-
netloc = urlparse(api_endpoint).netloc
74-
match netloc:
75-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
76-
models_catalog = "models"
77-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
78-
models_catalog = "catalog/models"
79-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
80-
models_catalog = "models"
81-
case _:
82-
# Unknown endpoint — try the OpenAI-style models catalog
83-
models_catalog = "models"
181+
headers = {
182+
"Accept": "application/json",
183+
"Authorization": f"Bearer {token}",
184+
**provider.extra_headers,
185+
}
84186
r = httpx.get(
85-
httpx.URL(api_endpoint).join(models_catalog),
86-
headers={
87-
"Accept": "application/json",
88-
"Authorization": f"Bearer {token}",
89-
"Copilot-Integration-Id": COPILOT_INTEGRATION_ID,
90-
},
187+
httpx.URL(base).join(provider.models_catalog),
188+
headers=headers,
91189
)
92190
r.raise_for_status()
93-
# CAPI vs Models API
94-
match netloc:
95-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
96-
models_list = r.json().get("data", [])
97-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
98-
models_list = r.json()
99-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
100-
models_list = r.json().get("data", [])
101-
case _:
102-
# Unknown endpoint — try common response shapes
103-
body = r.json()
104-
if isinstance(body, dict):
105-
models_list = body.get("data", [])
106-
elif isinstance(body, list):
107-
models_list = body
108-
else:
109-
models_list = []
110-
for model in models_list:
191+
for model in provider.parse_models_list(r.json()):
111192
models[model.get("id")] = dict(model)
112-
except httpx.RequestError:
113-
logging.exception("Request error")
114-
except json.JSONDecodeError:
115-
logging.exception("JSON error")
116-
except httpx.HTTPStatusError:
117-
logging.exception("HTTP error")
193+
except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
194+
logging.exception("Failed to list models from %s", base)
118195
return models
119196

120197

121-
def supports_tool_calls(model: str, models: dict[str, dict]) -> bool:
122-
"""Check whether the given model supports tool calls."""
123-
api_endpoint = get_AI_endpoint()
124-
netloc = urlparse(api_endpoint).netloc
125-
match netloc:
126-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
127-
return models.get(model, {}).get("capabilities", {}).get("supports", {}).get("tool_calls", False)
128-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
129-
return "tool-calling" in models.get(model, {}).get("capabilities", [])
130-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
131-
return "gpt-" in model.lower()
132-
case _:
133-
# Unknown endpoint — optimistically assume tool-call support
134-
# if the model is present in the catalog.
135-
return model in models
136-
137-
138-
def list_tool_call_models(token: str) -> dict[str, dict]:
198+
def supports_tool_calls(
199+
model: str,
200+
models: dict[str, dict],
201+
endpoint: str | None = None,
202+
) -> bool:
203+
"""Check whether *model* supports tool calls."""
204+
provider = get_provider(endpoint)
205+
return provider.check_tool_calls(model, models.get(model, {}))
206+
207+
208+
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
139209
"""Return only models that support tool calls."""
140-
models = list_capi_models(token)
141-
tool_models: dict[str, dict] = {}
142-
for model in models:
143-
if supports_tool_calls(model, models) is True:
144-
tool_models[model] = models[model]
145-
return tool_models
210+
models = list_capi_models(token, endpoint)
211+
provider = get_provider(endpoint)
212+
return {
213+
mid: info
214+
for mid, info in models.items()
215+
if provider.check_tool_calls(mid, info)
216+
}

0 commit comments

Comments
 (0)