Skip to content

Commit b71c2e0

Browse files
authored
Merge pull request #229 from GitHubSecurityLab/anticomputer/awf-proxy-support
Support AWF proxy endpoints via AWF_COPILOT_PROXY
2 parents f139ac8 + 1879d01 commit b71c2e0

2 files changed

Lines changed: 48 additions & 1 deletion

File tree

src/seclab_taskflow_agent/capi.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import dataclasses
1415
import json
1516
import logging
1617
import os
@@ -140,12 +141,35 @@ def check_tool_calls(self, _model: str, model_info: dict) -> bool:
140141
}
141142

142143
def get_provider(endpoint: str | None = None) -> APIProvider:
143-
"""Return the ``APIProvider`` for the given (or configured) endpoint URL."""
144+
"""Return the ``APIProvider`` for the given (or configured) endpoint URL.
145+
146+
When running inside an AWF (Agentic Workflow Firewall) sandbox, the
147+
``AWF_COPILOT_PROXY`` env var names the upstream provider whose behaviour
148+
(headers, model defaults, catalog format) the local proxy mirrors.
149+
The proxy URL is used as ``base_url`` while all other provider traits
150+
come from the named upstream.
151+
152+
``AWF_COPILOT_PROXY`` accepts either a bare hostname
153+
(``api.githubcopilot.com``) or a full URL
154+
(``https://api.githubcopilot.com``).
155+
"""
144156
url = endpoint or get_AI_endpoint()
145157
netloc = urlparse(url).netloc
146158
provider = _PROVIDERS.get(netloc)
147159
if provider is not None:
148160
return provider
161+
162+
# AWF proxy support: AWF_COPILOT_PROXY names the upstream provider
163+
# (e.g. "api.githubcopilot.com") whose behaviour this proxy mirrors.
164+
awf_upstream = os.getenv("AWF_COPILOT_PROXY", "").strip()
165+
if awf_upstream:
166+
# Normalize: accept both bare hostnames and full URLs.
167+
parsed = urlparse(awf_upstream)
168+
key = parsed.netloc or parsed.path
169+
upstream = _PROVIDERS.get(key)
170+
if upstream:
171+
return dataclasses.replace(upstream, base_url=url)
172+
149173
# Unknown endpoint — return a generic provider with the given base URL
150174
return APIProvider(name="custom", base_url=url, default_model="please-set-default-model-via-env")
151175

tests/test_capi_extended.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,26 @@ def test_custom_endpoint(self):
123123
assert p.name == "custom"
124124
assert p.base_url == "https://my-custom-llm.example.com/v1/"
125125
assert not p.extra_headers
126+
127+
def test_awf_proxy_bare_hostname(self, monkeypatch):
128+
monkeypatch.setenv("AWF_COPILOT_PROXY", "api.githubcopilot.com")
129+
p = get_provider("http://172.30.0.30:10002")
130+
assert p.name == "copilot"
131+
assert p.base_url == "http://172.30.0.30:10002/"
132+
assert p.default_model == "gpt-4.1"
133+
assert "Copilot-Integration-Id" in p.extra_headers
134+
135+
def test_awf_proxy_full_url(self, monkeypatch):
136+
monkeypatch.setenv("AWF_COPILOT_PROXY", "https://api.githubcopilot.com")
137+
p = get_provider("http://172.30.0.30:10002")
138+
assert p.name == "copilot"
139+
assert p.base_url == "http://172.30.0.30:10002/"
140+
141+
def test_awf_proxy_unknown_upstream(self, monkeypatch):
142+
monkeypatch.setenv("AWF_COPILOT_PROXY", "not-a-real-provider.com")
143+
p = get_provider("http://172.30.0.30:10002")
144+
assert p.name == "custom"
145+
146+
def test_awf_proxy_not_set(self):
147+
p = get_provider("http://172.30.0.30:10002")
148+
assert p.name == "custom"

0 commit comments

Comments
 (0)