Skip to content

Commit 5cfcf51

Browse files
authored
Merge pull request #1355 from NHSDigital/feat/validate-managed-identity-token-for-dicom-api
Decode managed identity token for dicom API routes
2 parents abc17c0 + 5c12134 commit 5cfcf51

5 files changed

Lines changed: 274 additions & 47 deletions

File tree

manage_breast_screening/config/settings/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,5 @@ def list_env(key):
382382
"style-src": (CSP_SELF,),
383383
}
384384
}
385+
386+
BYPASS_API_TOKEN_AUTH = boolean_env("BYPASS_API_TOKEN_AUTH", default=False)

manage_breast_screening/dicom/api.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
from manage_breast_screening.core.api_schema import ErrorResponse, StatusResponse
1111
from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus
1212

13+
from .authentication import Authentication
1314
from .dicom_recorder import DicomRecorder
1415

15-
router = Router()
16+
router = Router(auth=Authentication())
1617

1718
logger = logging.getLogger(__name__)
1819

20+
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
21+
1922

2023
class SuccessResponse(ninja.Schema):
2124
study_instance_uid: str
@@ -41,8 +44,7 @@ def upload(request, source_message_id: str, file: File[UploadedFile]):
4144
"""
4245
Accepts PUT with a single DICOM file in form field 'file'
4346
"""
44-
max_size = 100 * 1024 * 1024
45-
if file.size > max_size:
47+
if file.size > MAX_FILE_SIZE:
4648
return 400, {
4749
"title": "File too large",
4850
"status": 400,
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import logging
2+
import os
3+
from functools import cached_property
4+
5+
import jwt
6+
from django.conf import settings
7+
from ninja.security import HttpBearer
8+
9+
logger = logging.getLogger(__name__)
10+
11+
ALLOWED_ALGORITHMS = ["RS256"]
12+
JWT_SET_CACHE_TTL_SECONDS = 3600
13+
14+
15+
class Authentication(HttpBearer):
16+
def authenticate(self, _, token) -> dict | None:
17+
"""
18+
Authenticates the incoming request by validating the JWT token.
19+
"""
20+
if self.bypass_authentication:
21+
logger.warning("Authentication bypass is enabled.")
22+
return {"sub": "bypass_user"}
23+
24+
return self._decode(token)
25+
26+
def _decode(self, token: str) -> dict | None:
27+
"""
28+
Decodes and validates the JWT token using the provided RSA key.
29+
Checks the signature, audience, and issuer claims to ensure the token is valid and intended for this API.
30+
"""
31+
try:
32+
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
33+
payload = jwt.decode(
34+
token,
35+
signing_key.key,
36+
algorithms=ALLOWED_ALGORITHMS,
37+
audience=self.audience,
38+
issuer=self.issuers,
39+
)
40+
return payload
41+
except jwt.PyJWKClientError:
42+
logger.exception("Error fetching JWKS keys from Azure AD.")
43+
except jwt.ExpiredSignatureError:
44+
logger.exception("Token is expired")
45+
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError):
46+
logger.exception("Invalid claims. Please check the audience and issuer.")
47+
except jwt.InvalidTokenError:
48+
logger.exception("Token is invalid")
49+
except Exception:
50+
logger.exception("Unable to parse authentication token.")
51+
52+
@cached_property
53+
def jwks_client(self) -> jwt.PyJWKClient:
54+
"""
55+
Creates a PyJWKClient instance for fetching and caching the JWKS keys from Azure AD.
56+
Caching is enabled to improve performance and reduce the number of network requests to Azure AD.
57+
The cache will be refreshed after the specified TTL expires.
58+
"""
59+
return jwt.PyJWKClient(
60+
self.discovery_keys_url,
61+
cache_jwk_set=True,
62+
cache_keys=True,
63+
lifespan=JWT_SET_CACHE_TTL_SECONDS,
64+
)
65+
66+
@cached_property
67+
def discovery_keys_url(self) -> str:
68+
return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"
69+
70+
@property
71+
def audience(self) -> str | None:
72+
"""
73+
The expected audience claim in the JWT token. This should match the API_AUDIENCE environment variable.
74+
"""
75+
return os.getenv("API_AUDIENCE")
76+
77+
@property
78+
def tenant_id(self) -> str | None:
79+
"""
80+
The Azure AD tenant ID. This should be set as the TENANT_ID environment variable.
81+
"""
82+
return os.getenv("TENANT_ID", "")
83+
84+
@cached_property
85+
def issuers(self) -> list:
86+
"""
87+
The expected issuer claim(s) in the JWT token. This should match the tenant ID and the Azure AD endpoints.
88+
"""
89+
return [
90+
f"https://sts.windows.net/{self.tenant_id}/",
91+
f"https://login.microsoftonline.com/{self.tenant_id}/v2.0/",
92+
]
93+
94+
@property
95+
def bypass_authentication(self) -> bool:
96+
return getattr(settings, "BYPASS_API_TOKEN_AUTH", False)

manage_breast_screening/dicom/tests/test_api.py

Lines changed: 69 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,32 @@
66
import pytest
77
from django.core.files.uploadedfile import SimpleUploadedFile
88
from ninja.testing import TestClient
9+
from pydicom.uid import generate_uid
910

1011
from manage_breast_screening.core.api import api
11-
from manage_breast_screening.dicom.models import Study
1212
from manage_breast_screening.gateway.models import GatewayActionStatus
1313
from manage_breast_screening.gateway.tests.factories import GatewayActionFactory
1414
from manage_breast_screening.participants.models.appointment import (
1515
AppointmentStatusNames,
1616
)
1717
from manage_breast_screening.participants.tests.factories import AppointmentFactory
1818

19+
from ..authentication import Authentication
20+
from ..dicom_recorder import DicomRecorder
21+
from ..models import Study
22+
1923
os.environ["NINJA_SKIP_REGISTRY"] = "yes"
2024

2125
client = TestClient(api)
2226

2327

28+
@pytest.fixture(autouse=True)
29+
def setup(monkeypatch):
30+
monkeypatch.setenv("API_ENABLED", "true")
31+
monkeypatch.setenv("API_AUDIENCE", "test_audience")
32+
monkeypatch.setenv("TENANT_ID", "test_tenant_id")
33+
34+
2435
@pytest.fixture
2536
def dicom_file(dataset) -> bytes:
2637
with io.BytesIO() as buffer:
@@ -38,11 +49,14 @@ def appointment_stub():
3849
)
3950

4051

41-
@pytest.mark.django_db
42-
def test_upload_success(dataset, dicom_file, monkeypatch):
43-
monkeypatch.setenv("API_ENABLED", "true")
44-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
52+
@pytest.fixture
53+
def mock_authentication():
54+
with patch.object(Authentication, "authenticate", return_value={"sub": "testuser"}):
55+
yield
56+
4557

58+
@pytest.mark.django_db
59+
def test_upload_success(dataset, dicom_file, mock_authentication, appointment_stub):
4660
appointment = AppointmentFactory(current_status=AppointmentStatusNames.IN_PROGRESS)
4761

4862
with patch(
@@ -52,7 +66,7 @@ def test_upload_success(dataset, dicom_file, monkeypatch):
5266
response = client.put(
5367
f"/dicom/{appointment.pk}",
5468
FILES={"file": dicom_file},
55-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
69+
headers={"Authorization": "Bearer testtoken"},
5670
)
5771

5872
assert response.status_code == 201
@@ -65,23 +79,17 @@ def test_upload_success(dataset, dicom_file, monkeypatch):
6579
assert study.source_message_id == str(appointment.pk)
6680

6781

68-
def test_upload_no_file(monkeypatch):
69-
monkeypatch.setenv("API_ENABLED", "true")
70-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
71-
82+
def test_upload_no_file(mock_authentication):
7283
response = client.put(
7384
"/dicom/abc123",
7485
FILES={"file": None},
75-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
86+
headers={"Authorization": "Bearer testtoken"},
7687
)
7788

7889
assert response.status_code == 422
7990

8091

81-
def test_upload_invalid_file(monkeypatch, appointment_stub):
82-
monkeypatch.setenv("API_ENABLED", "true")
83-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
84-
92+
def test_upload_invalid_file(mock_authentication, appointment_stub):
8593
invalid_file = SimpleUploadedFile(
8694
"invalid.dcm", b"not a dicom file", content_type="application/dicom"
8795
)
@@ -93,7 +101,7 @@ def test_upload_invalid_file(monkeypatch, appointment_stub):
93101
response = client.put(
94102
"/dicom/abc123",
95103
FILES={"file": invalid_file},
96-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
104+
headers={"Authorization": "Bearer testtoken"},
97105
)
98106

99107
assert response.status_code == 400
@@ -102,16 +110,13 @@ def test_upload_invalid_file(monkeypatch, appointment_stub):
102110
assert response.json()["detail"] == "The uploaded file is not a valid DICOM file."
103111

104112

105-
def test_upload_file_thats_too_large(monkeypatch):
106-
monkeypatch.setenv("API_ENABLED", "true")
107-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
108-
113+
def test_upload_file_thats_too_large(mock_authentication):
109114
invalid_file = MagicMock(spec=SimpleUploadedFile, size=101 * 1024 * 1024)
110115

111116
response = client.put(
112117
"/dicom/abc123",
113118
FILES={"file": invalid_file},
114-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
119+
headers={"Authorization": "Bearer testtoken"},
115120
)
116121

117122
assert response.status_code == 400
@@ -120,10 +125,7 @@ def test_upload_file_thats_too_large(monkeypatch):
120125
assert response.json()["detail"] == "The file cannot be larger than 100MB"
121126

122127

123-
def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
124-
monkeypatch.setenv("API_ENABLED", "true")
125-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
126-
128+
def test_upload_missing_uids(dataset, mock_authentication, appointment_stub):
127129
del dataset.StudyInstanceUID
128130
del dataset.SeriesInstanceUID
129131
del dataset.SOPInstanceUID
@@ -142,7 +144,7 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
142144
response = client.put(
143145
"/dicom/abc123",
144146
FILES={"file": dicom_file},
145-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
147+
headers={"Authorization": "Bearer testtoken"},
146148
)
147149

148150
assert response.status_code == 400
@@ -154,10 +156,9 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
154156
)
155157

156158

157-
def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment_stub):
158-
monkeypatch.setenv("API_ENABLED", "true")
159-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
160-
159+
def test_upload_appointment_not_in_progress(
160+
dicom_file, mock_authentication, appointment_stub
161+
):
161162
appointment_stub.is_in_progress.return_value = False
162163

163164
with patch(
@@ -167,29 +168,26 @@ def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment
167168
response = client.put(
168169
"/dicom/abc123",
169170
FILES={"file": dicom_file},
170-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
171+
headers={"Authorization": "Bearer testtoken"},
171172
)
172173

173174
assert response.status_code == 500
174175
assert response.json()["title"] == "Internal Server Error"
175176

176177

177-
@pytest.mark.django_db
178-
def test_upload_when_api_disabled(dicom_file, monkeypatch):
178+
def test_upload_when_api_disabled(dicom_file, mock_authentication, monkeypatch):
179179
monkeypatch.setenv("API_ENABLED", "false")
180-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
181180

182181
response = client.put(
183182
"/dicom/abc123",
184183
FILES={"file": dicom_file},
185-
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
184+
headers={"Authorization": "Bearer testtoken"},
186185
)
187186

188187
assert response.status_code == 403
189188
assert response.json()["status"] == "API is not available"
190189

191190

192-
@pytest.mark.django_db
193191
def test_upload_no_auth(dicom_file):
194192
response = client.put(
195193
"/dicom/abc123",
@@ -202,11 +200,41 @@ def test_upload_no_auth(dicom_file):
202200
}
203201

204202

205-
@pytest.mark.django_db
206-
def test_report_failure(monkeypatch):
207-
monkeypatch.setenv("API_ENABLED", "true")
208-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
203+
def test_upload_invalid_auth(dicom_file):
204+
response = client.put(
205+
"/dicom/abc123",
206+
FILES={"file": dicom_file},
207+
headers={"Authorization": "Bearer invalidtoken"},
208+
)
209209

210+
assert response.status_code == 401
211+
assert response.json() == {
212+
"detail": "Unauthorized",
213+
}
214+
215+
216+
def test_upload_bypass_token_validation(dicom_file):
217+
with patch.object(Authentication, "bypass_authentication", return_value=True):
218+
with patch.object(
219+
DicomRecorder,
220+
"get_or_create_records",
221+
return_value=(
222+
MagicMock(study_instance_uid=generate_uid()),
223+
MagicMock(series_instance_uid=generate_uid()),
224+
MagicMock(sop_instance_uid=generate_uid(), id=1),
225+
),
226+
):
227+
response = client.put(
228+
"/dicom/abc123",
229+
FILES={"file": dicom_file},
230+
headers={"Authorization": "Bearer anytoken"},
231+
)
232+
233+
assert response.status_code == 201
234+
235+
236+
@pytest.mark.django_db
237+
def test_report_failure(mock_authentication):
210238
action = GatewayActionFactory()
211239

212240
response = client.patch(
@@ -225,10 +253,7 @@ def test_report_failure(monkeypatch):
225253

226254

227255
@pytest.mark.django_db
228-
def test_report_failure_action_not_found(monkeypatch):
229-
monkeypatch.setenv("API_ENABLED", "true")
230-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
231-
256+
def test_report_failure_action_not_found(mock_authentication):
232257
response = client.patch(
233258
"/dicom/00000000-0000-0000-0000-000000000000/failure",
234259
json={"error": "Missing PatientID"},

0 commit comments

Comments
 (0)