11# SPDX-FileCopyrightText: GitHub, Inc.
22# SPDX-License-Identifier: MIT
33
4- """AI API endpoint and token management (CAPI integration)."""
4+ """AI API endpoint and token management.
5+
6+ Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+ custom endpoints). All provider-specific behaviour is captured in a single
8+ ``APIProvider`` dataclass so that adding a new provider only requires one
9+ registry entry instead of changes scattered across multiple match/case blocks.
10+ """
11+
12+ from __future__ import annotations
513
614import json
715import logging
816import os
17+ from dataclasses import dataclass , field
18+ from types import MappingProxyType
19+ from typing import Any
920from urllib .parse import urlparse
1021
1122import httpx
12- from strenum import StrEnum
1323
1424__all__ = [
15- "AI_API_ENDPOINT_ENUM" ,
1625 "COPILOT_INTEGRATION_ID" ,
26+ "APIProvider" ,
1727 "get_AI_endpoint" ,
1828 "get_AI_token" ,
29+ "get_provider" ,
1930 "list_capi_models" ,
2031 "list_tool_call_models" ,
2132 "supports_tool_calls" ,
2233]
2334
35+ COPILOT_INTEGRATION_ID = "vscode-chat"
2436
25- # Enumeration of currently supported API endpoints.
26- class AI_API_ENDPOINT_ENUM (StrEnum ):
27- AI_API_MODELS_GITHUB = "models.github.ai"
28- AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29- AI_API_OPENAI = "api.openai.com"
3037
31- def to_url (self ) -> str :
32- """Convert the endpoint to its full URL."""
33- match self :
34- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
35- return f"https://{ self } "
36- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
37- return f"https://{ self } /inference"
38- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
39- return f"https://{ self } /v1"
40- case _:
41- raise ValueError (f"Unsupported endpoint: { self } " )
38+ # ---------------------------------------------------------------------------
39+ # Provider abstraction
40+ # ---------------------------------------------------------------------------
4241
@dataclass(frozen=True)
class APIProvider:
    """Encapsulates all endpoint-specific behaviour in one place.

    Attributes:
        name: Short provider identifier (e.g. ``"copilot"``).
        base_url: Root URL that catalog/inference paths are joined onto.
        models_catalog: Relative path of the model-catalog endpoint.
        default_model: Model id used when the caller does not choose one.
        extra_headers: Provider-specific HTTP headers to send with requests.
    """

    name: str
    base_url: str
    models_catalog: str = "models"
    default_model: str = "gpt-4.1"
    extra_headers: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Freeze the (mutable) headers dict so shared singleton providers
        # can't be mutated by callers. After this the attribute holds a
        # read-only MappingProxyType rather than a plain dict.
        object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers))

    # -- response parsing -----------------------------------------------------

    def parse_models_list(self, body: Any) -> list[dict]:
        """Extract the models list from a catalog response body.

        Accepts either a bare JSON list or an OpenAI-style ``{"data": [...]}``
        envelope. Non-dict entries are dropped so downstream code can safely
        call dict methods on every returned item.
        """
        if isinstance(body, list):
            entries = body
        elif isinstance(body, dict):
            data = body.get("data", [])
            entries = data if isinstance(data, list) else []
        else:
            entries = []
        # Defensive: malformed catalogs sometimes mix in strings/nulls.
        return [entry for entry in entries if isinstance(entry, dict)]

    # -- tool-call capability check -------------------------------------------

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if *model* supports tool calls according to its catalog entry."""
        # Default: optimistically assume support when present in catalog.
        return bool(model_info)
4573
4674
47- # you can also set https://api.githubcopilot.com if you prefer
48- # but beware that your taskflows need to reference the correct model id
49- # since different APIs use their own id schema, use -l with your desired
50- # endpoint to retrieve the correct id names to use for your taskflow
class _CopilotProvider(APIProvider):
    """GitHub Copilot API (api.githubcopilot.com)."""

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if the catalog entry advertises tool-call support.

        Copilot nests the flag under ``capabilities.supports.tool_calls``.
        """
        # bool() coercion: the JSON value could be any truthy type, but the
        # declared return type (and all callers) expect a real bool.
        return bool(
            model_info
            .get("capabilities", {})
            .get("supports", {})
            .get("tool_calls", False)
        )
85+
86+
class _GitHubModelsProvider(APIProvider):
    """GitHub Models API (models.github.ai).

    The Models catalog returns a bare JSON list; the base
    ``APIProvider.parse_models_list`` already handles that shape, so no
    parsing override is needed here.
    """

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True when "tool-calling" appears in the entry's capabilities list."""
        return "tool-calling" in model_info.get("capabilities", [])
98+
99+
class _OpenAIProvider(APIProvider):
    """OpenAI API (api.openai.com).

    The OpenAI /v1/models catalog does not expose capability metadata, so
    we maintain a prefix allowlist of known chat-completion model families.
    """

    # Plain class attribute (no annotation), so it is NOT a dataclass field.
    _CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        """Return True if the model id matches a known chat-capable family."""
        # Prefer the catalog entry's id, but fall back to the requested model
        # name so the check still works for models absent from the catalog
        # (matching the pre-provider behavior of keying on the model name).
        model_id = str(model_info.get("id") or model).lower()
        # str.startswith accepts a tuple of prefixes.
        return model_id.startswith(self._CHAT_PREFIXES)
112+ # ---------------------------------------------------------------------------
113+ # Provider registry — add new providers here
114+ # ---------------------------------------------------------------------------
115+
# Registry of known API hosts, keyed by the URL's network location as
# returned by urlparse(). get_provider() falls back to a generic
# APIProvider for hosts not listed here.
_PROVIDERS: dict[str, APIProvider] = {
    "api.githubcopilot.com": _CopilotProvider(
        name="copilot",
        base_url="https://api.githubcopilot.com",
        default_model="gpt-4.1",
        # Copilot rejects requests without an integration-id header.
        extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
    ),
    "models.github.ai": _GitHubModelsProvider(
        name="github-models",
        base_url="https://models.github.ai/inference",
        # The Models API serves its catalog from a non-standard path.
        models_catalog="catalog/models",
        # Model ids are publisher-namespaced ("openai/..."), unlike the others.
        default_model="openai/gpt-4.1",
    ),
    "api.openai.com": _OpenAIProvider(
        name="openai",
        base_url="https://api.openai.com/v1",
        default_model="gpt-4.1",
    ),
}
135+
def get_provider(endpoint: str | None = None) -> APIProvider:
    """Return the ``APIProvider`` for the given (or configured) endpoint URL.

    Args:
        endpoint: Full endpoint URL; defaults to the value from
            :func:`get_AI_endpoint` when omitted.

    Returns:
        The registered provider whose host matches the URL, or a generic
        ``APIProvider`` wrapping the URL for unknown hosts.
    """
    url = endpoint or get_AI_endpoint()
    # Use .hostname rather than .netloc: hostname is lowercased and strips
    # any port/userinfo, so "https://API.OPENAI.COM:443/v1" still matches.
    host = urlparse(url).hostname or ""
    provider = _PROVIDERS.get(host)
    if provider is not None:
        return provider
    # Unknown endpoint — return a generic provider with the given base URL.
    return APIProvider(
        name="custom",
        base_url=url,
        default_model="please-set-default-model-via-env",
    )
145+
146+
147+ # ---------------------------------------------------------------------------
148+ # Endpoint / token helpers
149+ # ---------------------------------------------------------------------------
150+
def get_AI_endpoint() -> str:
    """Return the configured AI API endpoint URL.

    Reads the ``AI_API_ENDPOINT`` environment variable, defaulting to the
    GitHub Models inference endpoint when it is unset.
    """
    default_url = "https://models.github.ai/inference"
    return os.environ.get("AI_API_ENDPOINT", default_url)
@@ -64,82 +164,54 @@ def get_AI_token() -> str:
64164 raise RuntimeError ("AI_API_TOKEN environment variable is not set." )
65165
66166
67- # assume we are >= python 3.9 for our type hints
68- def list_capi_models (token : str ) -> dict [str , dict ]:
69- """Retrieve a dictionary of available CAPI models"""
70- models = {}
167+ # ---------------------------------------------------------------------------
168+ # Model catalog
169+ # ---------------------------------------------------------------------------
170+
def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Retrieve available models from the configured API endpoint.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).

    Returns:
        Mapping of model id to its raw catalog entry. Returns an empty dict
        on any network/HTTP/JSON failure (errors are logged, not raised).
    """
    provider = get_provider(endpoint)
    base = provider.base_url
    models: dict[str, dict] = {}
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
            **provider.extra_headers,
        }
        r = httpx.get(
            httpx.URL(base).join(provider.models_catalog),
            headers=headers,
        )
        r.raise_for_status()
        for model in provider.parse_models_list(r.json()):
            # Defensive: tolerate malformed catalog entries. Non-dict items
            # would raise AttributeError; entries without an "id" would all
            # collapse onto a single None key.
            if not isinstance(model, dict):
                continue
            model_id = model.get("id")
            if model_id is None:
                continue
            models[model_id] = dict(model)
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
        logging.exception("Failed to list models from %s", base)
    return models
119197
120198
121- def supports_tool_calls (model : str , models : dict [str , dict ]) -> bool :
122- """Check whether the given model supports tool calls."""
123- api_endpoint = get_AI_endpoint ()
124- netloc = urlparse (api_endpoint ).netloc
125- match netloc :
126- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
127- return models .get (model , {}).get ("capabilities" , {}).get ("supports" , {}).get ("tool_calls" , False )
128- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
129- return "tool-calling" in models .get (model , {}).get ("capabilities" , [])
130- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
131- return "gpt-" in model .lower ()
132- case _:
133- # Unknown endpoint — optimistically assume tool-call support
134- # if the model is present in the catalog.
135- return model in models
136-
137-
138- def list_tool_call_models (token : str ) -> dict [str , dict ]:
def supports_tool_calls(
    model: str,
    models: dict[str, dict],
    endpoint: str | None = None,
) -> bool:
    """Check whether *model* supports tool calls.

    Args:
        model: Model id to look up.
        models: Catalog mapping as returned by :func:`list_capi_models`.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    catalog_entry = models.get(model, {})
    return get_provider(endpoint).check_tool_calls(model, catalog_entry)
207+
208+
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Return only models that support tool calls.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    # Resolve the provider once rather than per model.
    provider = get_provider(endpoint)
    catalog = list_capi_models(token, endpoint)
    tool_models: dict[str, dict] = {}
    for model_id, info in catalog.items():
        if provider.check_tool_calls(model_id, info):
            tool_models[model_id] = info
    return tool_models
0 commit comments