22
33import os
44from datetime import timedelta
5- from typing import cast
5+ from typing import Protocol , cast
66
77import pytest
88import requests
99from dotenv import find_dotenv , load_dotenv
1010from fhir .parameters import Parameters
1111
1212# Load environment variables from .env file in the workspace root
13- # find_dotenv searches upward from current directory for .env file
14- load_dotenv (find_dotenv (usecwd = True ))
13+ load_dotenv (find_dotenv ())
1514
1615
17- class Client :
18- """A simple HTTP client for testing purposes ."""
16+ class Client ( Protocol ) :
17+ """Protocol defining the interface for HTTP clients ."""
1918
20- def __init__ (self , base_url : str , timeout : timedelta = timedelta (seconds = 1 )):
21- self .base_url = base_url
22- self ._timeout = timeout .total_seconds ()
23-
24- cert = None
25- cert_path = os .getenv ("MTLS_CERT" )
26- key_path = os .getenv ("MTLS_KEY" )
27- if cert_path and key_path :
28- cert = (cert_path , key_path )
29- self .cert = cert
19+ base_url : str
20+ cert : tuple [str , str ] | None
3021
3122 def send_to_get_structured_record_endpoint (
3223 self , payload : str , headers : dict [str , str ] | None = None
3324 ) -> requests .Response :
3425 """
3526 Send a request to the get_structured_record endpoint with the given NHS number.
3627 """
28+ ...
29+
30+ def send_health_check (self ) -> requests .Response :
31+ """
32+ Send a health check request to the API.
33+ """
34+ ...
35+
36+
37+ class LocalClient :
38+ """HTTP client that sends requests directly to the API (no proxy auth)."""
39+
40+ def __init__ (
41+ self ,
42+ base_url : str ,
43+ cert : tuple [str , str ] | None = None ,
44+ timeout : timedelta = timedelta (seconds = 1 ),
45+ ):
46+ self .base_url = base_url
47+ self .cert = cert
48+ self ._timeout = timeout .total_seconds ()
49+
50+ def send_to_get_structured_record_endpoint (
51+ self , payload : str , headers : dict [str , str ] | None = None
52+ ) -> requests .Response :
3753 url = f"{ self .base_url } /patient/$gpc.getstructuredrecord"
3854 default_headers = {
3955 "Content-Type" : "application/fhir+json" ,
40- "Ods-from" : "A12345 " ,
56+ "Ods-from" : "CONSUMER " ,
4157 "Ssp-TraceID" : "test-trace-id" ,
4258 }
4359 if headers :
4460 default_headers .update (headers )
61+
4562 return requests .post (
4663 url = url ,
4764 data = payload ,
@@ -51,24 +68,62 @@ def send_to_get_structured_record_endpoint(
5168 )
5269
5370 def send_health_check (self ) -> requests .Response :
54- """
55- Send a health check request to the API.
56- Returns:
57- Response object from the request
58- """
5971 url = f"{ self .base_url } /health"
6072 return requests .get (url = url , timeout = self ._timeout , cert = self .cert )
6173
6274
75+ class RemoteClient :
76+ """HTTP client for remote testing via the APIM proxy."""
77+
78+ def __init__ (
79+ self ,
80+ api_url : str ,
81+ auth_headers : dict [str , str ],
82+ cert : tuple [str , str ] | None = None ,
83+ timeout : timedelta = timedelta (seconds = 5 ),
84+ ):
85+ self .base_url = api_url
86+ self .cert = cert
87+ self ._auth_headers = auth_headers
88+ self ._timeout = timeout .total_seconds ()
89+
90+ def send_to_get_structured_record_endpoint (
91+ self , payload : str , headers : dict [str , str ] | None = None
92+ ) -> requests .Response :
93+ url = f"{ self .base_url } /patient/$gpc.getstructuredrecord"
94+
95+ default_headers = self ._auth_headers | {
96+ "Content-Type" : "application/fhir+json" ,
97+ "Ods-from" : "CONSUMER" ,
98+ "Ssp-TraceID" : "test-trace-id" ,
99+ }
100+ if headers :
101+ default_headers .update (headers )
102+
103+ return requests .post (
104+ url = url ,
105+ data = payload ,
106+ headers = default_headers ,
107+ timeout = self ._timeout ,
108+ cert = self .cert ,
109+ )
110+
111+ def send_health_check (self ) -> requests .Response :
112+ url = f"{ self .base_url } /health"
113+ return requests .get (
114+ url = url , headers = self ._auth_headers , timeout = self ._timeout , cert = self .cert
115+ )
116+
117+
63118@pytest .fixture (scope = "session" )
64119def mtls_cert () -> tuple [str , str ] | None :
65- """
66- Provide mTLS certificate paths.
67- """
120+ """Returns the mTLS certificate and key paths if provided in the environment."""
68121 cert_path = os .getenv ("MTLS_CERT" )
69122 key_path = os .getenv ("MTLS_KEY" )
123+
70124 if cert_path and key_path :
71125 return (cert_path , key_path )
126+
72127 return None
73128
74129
@@ -89,18 +144,51 @@ def simple_request_payload() -> Parameters:
89144
90145
91146@pytest .fixture
92- def happy_path_headers ( ) -> dict [str , str ]:
93- return {
94- "Content-Type" : "application/fhir+json" ,
95- "Ods-from" : "A12345" ,
96- "Ssp-TraceID" : "test-trace-id" ,
97- }
147+ def get_headers ( request : pytest . FixtureRequest ) -> dict [str , str ]:
148+ """Return merged auth headers for remote tests, or empty dict for local."""
149+ env = os . getenv ( "ENV" ) or request . config . getoption ( "--env" )
150+ if env == "remote" :
151+ apikey_headers = request . getfixturevalue ( "status_endpoint_auth_headers" )
152+ token = os . getenv ( "APIGEE_ACCESS_TOKEN" )
98153
154+ if token :
155+ return {"Authorization" : f"Bearer { token } " , ** apikey_headers }
99156
100- @pytest .fixture (scope = "module" )
101- def client (base_url : str ) -> Client :
102- """Create a test client for the application."""
103- return Client (base_url = base_url )
157+ nhsd_headers = request .getfixturevalue ("nhsd_apim_auth_headers" )
158+ headers = nhsd_headers | apikey_headers
159+ return cast ("dict[str, str]" , headers )
160+
161+ return {}
162+
163+
164+ @pytest .fixture
165+ def client (
166+ request : pytest .FixtureRequest ,
167+ base_url : str ,
168+ mtls_cert : tuple [str , str ] | None ,
169+ ) -> Client :
170+ """Create the appropriate HTTP client."""
171+ env = os .getenv ("ENV" ) or request .config .getoption ("--env" )
172+
173+ if env == "local" :
174+ return LocalClient (base_url = base_url , cert = mtls_cert )
175+ elif env == "remote" :
176+ proxy_url = request .getfixturevalue ("nhsd_apim_proxy_url" )
177+
178+ apikey_headers = request .getfixturevalue ("status_endpoint_auth_headers" )
179+ token = os .getenv ("APIGEE_ACCESS_TOKEN" )
180+
181+ if token :
182+ auth_headers = {"Authorization" : f"Bearer { token } " , ** apikey_headers }
183+ else :
184+ nhsd_headers = request .getfixturevalue ("nhsd_apim_auth_headers" )
185+ auth_headers = nhsd_headers | apikey_headers
186+
187+ return RemoteClient (
188+ api_url = proxy_url , auth_headers = auth_headers , cert = mtls_cert
189+ )
190+ else :
191+ raise ValueError (f"Unknown env: { env } " )
104192
105193
106194@pytest .fixture (scope = "module" )
@@ -123,3 +211,32 @@ def _fetch_env_variable[T](
123211 if not value :
124212 raise ValueError (f"{ name } environment variable is not set." )
125213 return cast ("T" , value )
214+
215+
216+ def pytest_addoption (parser : pytest .Parser ) -> None :
217+ parser .addoption (
218+ "--env" ,
219+ action = "store" ,
220+ default = "local" ,
221+ help = "Environment to run tests against" ,
222+ )
223+
224+
225+ def pytest_collection_modifyitems (
226+ config : pytest .Config , items : list [pytest .Item ]
227+ ) -> None :
228+ env = os .getenv ("ENV" ) or config .getoption ("--env" )
229+
230+ if env == "local" :
231+ skip_remote = pytest .mark .skip (reason = "Test only runs in remote environment" )
232+ for item in items :
233+ if item .get_closest_marker ("remote_only" ):
234+ item .add_marker (skip_remote )
235+
236+ if env == "remote" :
237+ for item in items :
238+ item .add_marker (
239+ pytest .mark .nhsd_apim_authorization (
240+ access = "application" , level = "level3"
241+ )
242+ )
0 commit comments