Skip to content

Commit 1879d01

Browse files
committed
Address review feedback: dataclasses.replace, tests, input normalization
- Use dataclasses.replace() instead of manual field reconstruction to avoid silently dropping fields if APIProvider gains new ones - Normalize AWF_COPILOT_PROXY input: accept both bare hostnames (api.githubcopilot.com) and full URLs (https://api.githubcopilot.com) - Add 4 test cases covering bare hostname, full URL, unknown upstream, and unset env var scenarios
1 parent 4a4d21e commit 1879d01

2 files changed

Lines changed: 34 additions & 9 deletions

File tree

src/seclab_taskflow_agent/capi.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import dataclasses
1415
import json
1516
import logging
1617
import os
@@ -147,6 +148,10 @@ def get_provider(endpoint: str | None = None) -> APIProvider:
147148
(headers, model defaults, catalog format) the local proxy mirrors.
148149
The proxy URL is used as ``base_url`` while all other provider traits
149150
come from the named upstream.
151+
152+
``AWF_COPILOT_PROXY`` accepts either a bare hostname
153+
(``api.githubcopilot.com``) or a full URL
154+
(``https://api.githubcopilot.com``).
150155
"""
151156
url = endpoint or get_AI_endpoint()
152157
netloc = urlparse(url).netloc
@@ -156,17 +161,14 @@ def get_provider(endpoint: str | None = None) -> APIProvider:
156161

157162
# AWF proxy support: AWF_COPILOT_PROXY names the upstream provider
158163
# (e.g. "api.githubcopilot.com") whose behaviour this proxy mirrors.
159-
awf_upstream = os.getenv("AWF_COPILOT_PROXY")
164+
awf_upstream = os.getenv("AWF_COPILOT_PROXY", "").strip()
160165
if awf_upstream:
161-
upstream = _PROVIDERS.get(awf_upstream)
166+
# Normalize: accept both bare hostnames and full URLs.
167+
parsed = urlparse(awf_upstream)
168+
key = parsed.netloc or parsed.path
169+
upstream = _PROVIDERS.get(key)
162170
if upstream:
163-
return type(upstream)(
164-
name=upstream.name,
165-
base_url=url,
166-
models_catalog=upstream.models_catalog,
167-
default_model=upstream.default_model,
168-
extra_headers=dict(upstream.extra_headers),
169-
)
171+
return dataclasses.replace(upstream, base_url=url)
170172

171173
# Unknown endpoint — return a generic provider with the given base URL
172174
return APIProvider(name="custom", base_url=url, default_model="please-set-default-model-via-env")

tests/test_capi_extended.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,26 @@ def test_custom_endpoint(self):
123123
assert p.name == "custom"
124124
assert p.base_url == "https://my-custom-llm.example.com/v1/"
125125
assert not p.extra_headers
126+
127+
def test_awf_proxy_bare_hostname(self, monkeypatch):
128+
monkeypatch.setenv("AWF_COPILOT_PROXY", "api.githubcopilot.com")
129+
p = get_provider("http://172.30.0.30:10002")
130+
assert p.name == "copilot"
131+
assert p.base_url == "http://172.30.0.30:10002/"
132+
assert p.default_model == "gpt-4.1"
133+
assert "Copilot-Integration-Id" in p.extra_headers
134+
135+
def test_awf_proxy_full_url(self, monkeypatch):
136+
monkeypatch.setenv("AWF_COPILOT_PROXY", "https://api.githubcopilot.com")
137+
p = get_provider("http://172.30.0.30:10002")
138+
assert p.name == "copilot"
139+
assert p.base_url == "http://172.30.0.30:10002/"
140+
141+
def test_awf_proxy_unknown_upstream(self, monkeypatch):
142+
monkeypatch.setenv("AWF_COPILOT_PROXY", "not-a-real-provider.com")
143+
p = get_provider("http://172.30.0.30:10002")
144+
assert p.name == "custom"
145+
146+
def test_awf_proxy_not_set(self):
147+
p = get_provider("http://172.30.0.30:10002")
148+
assert p.name == "custom"

0 commit comments

Comments
 (0)