|
14 | 14 | import json |
15 | 15 | import logging |
16 | 16 | import os |
| 17 | +from collections.abc import Mapping |
17 | 18 | from dataclasses import dataclass, field |
18 | 19 | from types import MappingProxyType |
19 | 20 | from typing import Any |
@@ -45,13 +46,17 @@ class APIProvider: |
45 | 46 |
|
46 | 47 | name: str |
47 | 48 | base_url: str |
48 | | - models_catalog: str = "models" |
| 49 | + models_catalog: str = "/models" |
49 | 50 | default_model: str = "gpt-4.1" |
50 | | - extra_headers: dict[str, str] = field(default_factory=dict) |
| 51 | + extra_headers: Mapping[str, str] = field(default_factory=dict) |
51 | 52 |
|
52 | 53 | def __post_init__(self) -> None: |
| 54 | + # Ensure base_url ends with / so httpx URL.join() preserves the path |
| 55 | + if self.base_url and not self.base_url.endswith("/"): |
| 56 | + object.__setattr__(self, "base_url", self.base_url + "/") |
53 | 57 | # Freeze mutable headers so singleton providers can't be mutated |
54 | | - object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers)) |
| 58 | + if isinstance(self.extra_headers, dict): |
| 59 | + object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers)) |
55 | 60 |
|
56 | 61 | # -- response parsing ----------------------------------------------------- |
57 | 62 |
|
@@ -123,12 +128,13 @@ def check_tool_calls(self, _model: str, model_info: dict) -> bool: |
123 | 128 | "models.github.ai": _GitHubModelsProvider( |
124 | 129 | name="github-models", |
125 | 130 | base_url="https://models.github.ai/inference", |
126 | | - models_catalog="catalog/models", |
| 131 | + models_catalog="/catalog/models", |
127 | 132 | default_model="openai/gpt-4.1", |
128 | 133 | ), |
129 | 134 | "api.openai.com": _OpenAIProvider( |
130 | 135 | name="openai", |
131 | 136 | base_url="https://api.openai.com/v1", |
| 137 | + models_catalog="/v1/models", |
132 | 138 | default_model="gpt-4.1", |
133 | 139 | ), |
134 | 140 | } |
|
0 commit comments