11# SPDX-FileCopyrightText: GitHub, Inc.
22# SPDX-License-Identifier: MIT
33
4- """AI API endpoint and token management (CAPI integration)."""
4+ """AI API endpoint and token management.
5+
6+ Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+ custom endpoints). All provider-specific behaviour is captured in a single
8+ ``APIProvider`` dataclass so that adding a new provider only requires one
9+ registry entry instead of changes scattered across multiple match/case blocks.
10+ """
11+
12+ from __future__ import annotations
513
614import json
715import logging
816import os
17+ from dataclasses import dataclass , field
18+ from types import MappingProxyType
19+ from typing import Any
920from urllib .parse import urlparse
1021
1122import httpx
12- from strenum import StrEnum
1323
1424__all__ = [
15- "AI_API_ENDPOINT_ENUM" ,
1625 "COPILOT_INTEGRATION_ID" ,
26+ "APIProvider" ,
1727 "get_AI_endpoint" ,
1828 "get_AI_token" ,
29+ "get_provider" ,
1930 "list_capi_models" ,
2031 "list_tool_call_models" ,
2132 "supports_tool_calls" ,
2233]
2334
35+ COPILOT_INTEGRATION_ID = "vscode-chat"
2436
25- # Enumeration of currently supported API endpoints.
26- class AI_API_ENDPOINT_ENUM (StrEnum ):
27- AI_API_MODELS_GITHUB = "models.github.ai"
28- AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29- AI_API_OPENAI = "api.openai.com"
3037
31- def to_url (self ) -> str :
32- """Convert the endpoint to its full URL."""
33- match self :
34- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
35- return f"https://{ self } "
36- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
37- return f"https://{ self } /inference"
38- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
39- return f"https://{ self } /v1"
40- case _:
41- raise ValueError (f"Unsupported endpoint: { self } " )
38+ # ---------------------------------------------------------------------------
39+ # Provider abstraction
40+ # ---------------------------------------------------------------------------
4241
@dataclass(frozen=True)
class APIProvider:
    """Encapsulates all endpoint-specific behaviour in one place.

    Attributes:
        name: Short provider identifier (e.g. ``"openai"``).
        base_url: Root URL that relative API paths are joined against.
        models_catalog: Relative path of the model-catalog endpoint.
        default_model: Model id used when the caller does not pick one.
        extra_headers: Provider-specific HTTP headers (read-only after init).
    """

    name: str
    base_url: str
    models_catalog: str = "models"
    default_model: str = "gpt-4.1"
    extra_headers: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Defensively copy BEFORE wrapping: MappingProxyType is only a view,
        # so without the copy a caller that kept a reference to the original
        # dict could still mutate this "frozen" provider's headers.
        object.__setattr__(
            self, "extra_headers", MappingProxyType(dict(self.extra_headers))
        )

    # -- response parsing -----------------------------------------------------

    def parse_models_list(self, body: Any) -> list[dict]:
        """Extract the models list from a catalog response body.

        Handles both common response shapes: a bare JSON list and an
        OpenAI-style ``{"data": [...]}`` envelope. Any other shape yields
        an empty list rather than raising.
        """
        if isinstance(body, list):
            return body
        if isinstance(body, dict):
            return body.get("data", [])
        return []

    # -- tool-call capability check -------------------------------------------

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if the model supports tool calls per its catalog entry.

        Default policy: optimistically assume support whenever the model has
        a non-empty catalog entry; subclasses refine this per provider.
        """
        return bool(model_info)
4572
4673
47- # you can also set https://api.githubcopilot.com if you prefer
48- # but beware that your taskflows need to reference the correct model id
49- # since different APIs use their own id schema, use -l with your desired
50- # endpoint to retrieve the correct id names to use for your taskflow
class _CopilotProvider(APIProvider):
    """GitHub Copilot API (api.githubcopilot.com)."""

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """True when the catalog entry advertises tool-call support.

        Copilot nests the flag under ``capabilities.supports.tool_calls``.
        The raw value is coerced so the declared ``bool`` return type holds
        even if the API returns a truthy non-bool value.
        """
        supports = model_info.get("capabilities", {}).get("supports", {})
        return bool(supports.get("tool_calls", False))
84+
85+
class _GitHubModelsProvider(APIProvider):
    """GitHub Models API (models.github.ai).

    The catalog endpoint returns a bare JSON list, which the base
    ``parse_models_list`` already handles, so no override is needed here.
    """

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        # Models API advertises capabilities as a flat list of tag strings.
        return "tool-calling" in model_info.get("capabilities", [])
97+
98+
class _OpenAIProvider(APIProvider):
    """OpenAI API (api.openai.com).

    The OpenAI /v1/models catalog does not expose capability metadata, so
    we maintain a prefix allowlist of known chat-completion model families.
    """

    _CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """True when the model id matches a known chat-completion family.

        Falls back to the requested model name when the catalog entry is
        missing or lacks an ``"id"``, so lookups for known families still
        succeed even without a catalog hit.
        """
        model_id = model_info.get("id", _model).lower()
        return any(model_id.startswith(p) for p in self._CHAT_PREFIXES)
# ---------------------------------------------------------------------------
# Provider registry — add new providers here
# ---------------------------------------------------------------------------

# Keyed by the netloc of each provider's base URL so the hostname lives in
# exactly one place (the provider definition itself).
_PROVIDERS: dict[str, APIProvider] = {
    urlparse(provider.base_url).netloc: provider
    for provider in (
        _CopilotProvider(
            name="copilot",
            base_url="https://api.githubcopilot.com",
            default_model="gpt-4.1",
            extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
        ),
        _GitHubModelsProvider(
            name="github-models",
            base_url="https://models.github.ai/inference",
            models_catalog="catalog/models",
            default_model="openai/gpt-4.1",
        ),
        _OpenAIProvider(
            name="openai",
            base_url="https://api.openai.com/v1",
            default_model="gpt-4.1",
        ),
    )
}
134+
def get_provider(endpoint: str | None = None) -> APIProvider:
    """Return the ``APIProvider`` for the given (or configured) endpoint URL.

    Unknown hosts fall back to a generic ``APIProvider`` pointed at the
    supplied URL, so custom OpenAI-compatible endpoints keep working.
    """
    url = endpoint or get_AI_endpoint()
    host = urlparse(url).netloc
    try:
        return _PROVIDERS[host]
    except KeyError:
        return APIProvider(name="custom", base_url=url)
144+
145+
146+ # ---------------------------------------------------------------------------
147+ # Endpoint / token helpers
148+ # ---------------------------------------------------------------------------
149+
def get_AI_endpoint() -> str:
    """Return the configured AI API endpoint URL.

    Reads the ``AI_API_ENDPOINT`` environment variable, defaulting to the
    GitHub Models inference endpoint when unset.
    """
    default_url = "https://models.github.ai/inference"
    return os.getenv("AI_API_ENDPOINT", default=default_url)
@@ -64,82 +163,54 @@ def get_AI_token() -> str:
64163 raise RuntimeError ("AI_API_TOKEN environment variable is not set." )
65164
66165
67- # assume we are >= python 3.9 for our type hints
68- def list_capi_models (token : str ) -> dict [str , dict ]:
69- """Retrieve a dictionary of available CAPI models"""
70- models = {}
# ---------------------------------------------------------------------------
# Model catalog
# ---------------------------------------------------------------------------

def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Retrieve available models from the configured API endpoint.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).

    Returns:
        Mapping of model id -> catalog entry. Empty on any request,
        HTTP-status, or JSON-decoding failure (errors are logged, not
        raised, so callers get a best-effort catalog).
    """
    provider = get_provider(endpoint)
    base = provider.base_url
    models: dict[str, dict] = {}
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
            **provider.extra_headers,
        }
        # NOTE(review): RFC 3986 join drops the final path segment of a base
        # URL lacking a trailing slash, e.g. ".../inference" + "catalog/models"
        # -> ".../catalog/models" (what the GitHub Models API expects). Verify
        # the resulting URL when adding a provider whose base_url has a path
        # (e.g. OpenAI's "/v1").
        r = httpx.get(
            httpx.URL(base).join(provider.models_catalog),
            headers=headers,
        )
        r.raise_for_status()
        for model in provider.parse_models_list(r.json()):
            model_id = model.get("id")
            if model_id is None:
                # Skip malformed catalog entries rather than keying on None.
                continue
            models[model_id] = dict(model)
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
        logging.exception("Failed to list models from %s", base)
    return models
119196
120197
def supports_tool_calls(
    model: str,
    models: dict[str, dict],
    endpoint: str | None = None,
) -> bool:
    """Check whether *model* supports tool calls.

    Args:
        model: Model id to look up.
        models: Catalog mapping as returned by ``list_capi_models``.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    catalog_entry = models.get(model, {})
    return get_provider(endpoint).check_tool_calls(model, catalog_entry)
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Return only models that support tool calls."""
    # Resolve the provider once rather than per model.
    provider = get_provider(endpoint)
    catalog = list_capi_models(token, endpoint)
    return {
        model_id: entry
        for model_id, entry in catalog.items()
        if provider.check_tool_calls(model_id, entry)
    }