diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 8bb43cc..fa73992 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -11,6 +11,7 @@ from __future__ import annotations +import dataclasses import json import logging import os @@ -140,12 +141,35 @@ def check_tool_calls(self, _model: str, model_info: dict) -> bool: } def get_provider(endpoint: str | None = None) -> APIProvider: - """Return the ``APIProvider`` for the given (or configured) endpoint URL.""" + """Return the ``APIProvider`` for the given (or configured) endpoint URL. + + When running inside an AWF (Agentic Workflow Firewall) sandbox, the + ``AWF_COPILOT_PROXY`` env var names the upstream provider whose behaviour + (headers, model defaults, catalog format) the local proxy mirrors. + The proxy URL is used as ``base_url`` while all other provider traits + come from the named upstream. + + ``AWF_COPILOT_PROXY`` accepts either a bare hostname + (``api.githubcopilot.com``) or a full URL + (``https://api.githubcopilot.com``). + """ url = endpoint or get_AI_endpoint() netloc = urlparse(url).netloc provider = _PROVIDERS.get(netloc) if provider is not None: return provider + + # AWF proxy support: AWF_COPILOT_PROXY names the upstream provider + # (e.g. "api.githubcopilot.com") whose behaviour this proxy mirrors. + awf_upstream = os.getenv("AWF_COPILOT_PROXY", "").strip() + if awf_upstream: + # Normalize: accept both bare hostnames and full URLs. + parsed = urlparse(awf_upstream) + key = parsed.netloc or parsed.path + upstream = _PROVIDERS.get(key) + if upstream: + return dataclasses.replace(upstream, base_url=url) + # Unknown endpoint — return a generic provider with the given base URL return APIProvider(name="custom", base_url=url, default_model="please-set-default-model-via-env") diff --git a/tests/test_capi_extended.py b/tests/test_capi_extended.py index 28a3d06..e3a1188 100644 --- a/tests/test_capi_extended.py +++ b/tests/test_capi_extended.py @@ -123,3 +123,26 @@ def test_custom_endpoint(self): assert p.name == "custom" assert p.base_url == "https://my-custom-llm.example.com/v1/" assert not p.extra_headers + + def test_awf_proxy_bare_hostname(self, monkeypatch): + monkeypatch.setenv("AWF_COPILOT_PROXY", "api.githubcopilot.com") + p = get_provider("http://172.30.0.30:10002") + assert p.name == "copilot" + assert p.base_url == "http://172.30.0.30:10002/" + assert p.default_model == "gpt-4.1" + assert "Copilot-Integration-Id" in p.extra_headers + + def test_awf_proxy_full_url(self, monkeypatch): + monkeypatch.setenv("AWF_COPILOT_PROXY", "https://api.githubcopilot.com") + p = get_provider("http://172.30.0.30:10002") + assert p.name == "copilot" + assert p.base_url == "http://172.30.0.30:10002/" + + def test_awf_proxy_unknown_upstream(self, monkeypatch): + monkeypatch.setenv("AWF_COPILOT_PROXY", "not-a-real-provider.com") + p = get_provider("http://172.30.0.30:10002") + assert p.name == "custom" + + def test_awf_proxy_not_set(self): + p = get_provider("http://172.30.0.30:10002") + assert p.name == "custom"