Skip to content

Commit 6230bbc

Browse files
committed
Rename TokenValidator -> Authentication
Refactor authentication to use PyJWKClient with caching options. Move decode options into properties.
1 parent 69801ed commit 6230bbc

6 files changed

Lines changed: 236 additions & 240 deletions

File tree

manage_breast_screening/dicom/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from manage_breast_screening.core.api_schema import ErrorResponse, StatusResponse
1111
from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus
1212

13+
from .authentication import Authentication
1314
from .dicom_recorder import DicomRecorder
14-
from .token_validator import TokenValidator
1515

16-
router = Router(auth=TokenValidator())
16+
router = Router(auth=Authentication())
1717

1818
logger = logging.getLogger(__name__)
1919

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import logging
2+
import os
3+
4+
import jwt
5+
from django.conf import settings
6+
from ninja.security import HttpBearer
7+
8+
logger = logging.getLogger(__name__)
9+
10+
ALLOWED_ALGORITHMS = ["RS256"]
11+
JWT_SET_CACHE_TTL_SECONDS = 3600
12+
13+
14+
class Authentication(HttpBearer):
15+
def authenticate(self, _, token) -> dict | None:
16+
"""
17+
Authenticates the incoming request by validating the JWT token.
18+
"""
19+
if self.bypass_authentication:
20+
logger.warning("Authentication bypass is enabled.")
21+
return {"sub": "bypass_user"}
22+
23+
return self._decode(token)
24+
25+
def _decode(self, token: str) -> dict | None:
26+
"""
27+
Decodes and validates the JWT token using the provided RSA key.
28+
Checks the signature, audience, and issuer claims to ensure the token is valid and intended for this API.
29+
"""
30+
try:
31+
signing_key = self.jwks_client.get_signing_key_from_jwt(token).key
32+
payload = jwt.decode(
33+
token,
34+
signing_key,
35+
algorithms=ALLOWED_ALGORITHMS,
36+
audience=self.audience,
37+
issuer=self.issuers,
38+
)
39+
return payload
40+
except jwt.PyJWKClientError:
41+
logger.exception("Error fetching JWKS keys from Azure AD.")
42+
except jwt.ExpiredSignatureError:
43+
logger.exception("Token is expired")
44+
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError):
45+
logger.exception("Invalid claims. Please check the audience and issuer.")
46+
except jwt.InvalidTokenError:
47+
logger.exception("Token is invalid")
48+
except Exception:
49+
logger.exception("Unable to parse authentication token.")
50+
51+
@property
52+
def jwks_client(self) -> jwt.PyJWKClient:
53+
"""
54+
Creates a PyJWKClient instance for fetching and caching the JWKS keys from Azure AD.
55+
Caching is enabled to improve performance and reduce the number of network requests to Azure AD.
56+
The cache will be refreshed after the specified TTL expires.
57+
"""
58+
return jwt.PyJWKClient(
59+
self.discovery_keys_url,
60+
cache_jwk_set=True,
61+
cache_keys=True,
62+
lifespan=JWT_SET_CACHE_TTL_SECONDS,
63+
)
64+
65+
@property
66+
def discovery_keys_url(self) -> str:
67+
return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"
68+
69+
@property
70+
def audience(self) -> str | None:
71+
"""
72+
The expected audience claim in the JWT token. This should match the API_AUDIENCE environment variable.
73+
"""
74+
return os.getenv("API_AUDIENCE")
75+
76+
@property
77+
def tenant_id(self) -> str | None:
78+
"""
79+
The Azure AD tenant ID. This should be set as the TENANT_ID environment variable.
80+
"""
81+
return os.getenv("TENANT_ID", "")
82+
83+
@property
84+
def issuers(self) -> list:
85+
"""
86+
The expected issuer claim(s) in the JWT token. This should match the tenant ID and the Azure AD endpoints.
87+
"""
88+
return [
89+
f"https://sts.windows.net/{self.tenant_id}/",
90+
f"https://login.microsoftonline.com/{self.tenant_id}/v2.0/",
91+
]
92+
93+
@property
94+
def bypass_authentication(self) -> bool:
95+
return getattr(settings, "BYPASS_API_TOKEN_AUTH", False)

manage_breast_screening/dicom/tests/test_api.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
from django.core.files.uploadedfile import SimpleUploadedFile
88
from ninja.testing import TestClient
9+
from pydicom.uid import generate_uid
910

1011
from manage_breast_screening.core.api import api
1112
from manage_breast_screening.dicom.models import Study
@@ -16,18 +17,20 @@
1617
)
1718
from manage_breast_screening.participants.tests.factories import AppointmentFactory
1819

20+
from ..authentication import Authentication
1921
from ..dicom_recorder import DicomRecorder
2022
from ..models import Study
21-
from ..token_validator import TokenValidator
2223

2324
os.environ["NINJA_SKIP_REGISTRY"] = "yes"
2425

2526
client = TestClient(api)
2627

2728

2829
@pytest.fixture(autouse=True)
29-
def enable_api(monkeypatch):
30+
def setup(monkeypatch):
3031
monkeypatch.setenv("API_ENABLED", "true")
32+
monkeypatch.setenv("API_AUDIENCE", "test_audience")
33+
monkeypatch.setenv("TENANT_ID", "test_tenant_id")
3134

3235

3336
@pytest.fixture
@@ -46,27 +49,21 @@ def appointment_stub():
4649
is_in_progress=MagicMock(return_value=True),
4750
)
4851

52+
4953
@pytest.fixture
50-
def mock_token_validator():
51-
with patch.object(TokenValidator, "authenticate", return_value={"sub": "testuser"}):
54+
def mock_authentication():
55+
with patch.object(Authentication, "authenticate", return_value={"sub": "testuser"}):
5256
yield
5357

5458

5559
@pytest.mark.django_db
56-
def test_upload_success(dataset, dicom_file, monkeypatch):
57-
monkeypatch.setenv("API_ENABLED", "true")
58-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
59-
60+
def test_upload_success(dataset, dicom_file, mock_authentication, appointment_stub):
6061
appointment = AppointmentFactory(current_status=AppointmentStatusNames.IN_PROGRESS)
6162

6263
with patch(
6364
"manage_breast_screening.dicom.dicom_recorder.lookup_appointment",
6465
return_value=appointment,
6566
):
66-
67-
@pytest.mark.django_db
68-
def test_upload_success(dataset, dicom_file, mock_token_validator):
69-
with patch.object(DicomRecorder, "appointment_in_progress", return_value=True):
7067
response = client.put(
7168
f"/dicom/{appointment.pk}",
7269
FILES={"file": dicom_file},
@@ -83,7 +80,7 @@ def test_upload_success(dataset, dicom_file, mock_token_validator):
8380
assert study.source_message_id == str(appointment.pk)
8481

8582

86-
def test_upload_no_file(mock_token_validator):
83+
def test_upload_no_file(mock_authentication):
8784
response = client.put(
8885
"/dicom/abc123",
8986
FILES={"file": None},
@@ -93,9 +90,7 @@ def test_upload_no_file(mock_token_validator):
9390
assert response.status_code == 422
9491

9592

96-
def test_upload_invalid_file(monkeypatch, mock_token_validator):
97-
monkeypatch.setenv("API_ENABLED", "true")
98-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
93+
def test_upload_invalid_file(mock_authentication, appointment_stub):
9994
invalid_file = SimpleUploadedFile(
10095
"invalid.dcm", b"not a dicom file", content_type="application/dicom"
10196
)
@@ -116,7 +111,7 @@ def test_upload_invalid_file(monkeypatch, mock_token_validator):
116111
assert response.json()["detail"] == "The uploaded file is not a valid DICOM file."
117112

118113

119-
def test_upload_file_thats_too_large(mock_token_validator):
114+
def test_upload_file_thats_too_large(mock_authentication):
120115
invalid_file = MagicMock(spec=SimpleUploadedFile, size=101 * 1024 * 1024)
121116

122117
response = client.put(
@@ -131,9 +126,7 @@ def test_upload_file_thats_too_large(mock_token_validator):
131126
assert response.json()["detail"] == "The file cannot be larger than 100MB"
132127

133128

134-
def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
135-
monkeypatch.setenv("API_ENABLED", "true")
136-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
129+
def test_upload_missing_uids(dataset, mock_authentication, appointment_stub):
137130
del dataset.StudyInstanceUID
138131
del dataset.SeriesInstanceUID
139132
del dataset.SOPInstanceUID
@@ -164,10 +157,7 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
164157
)
165158

166159

167-
def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment_stub):
168-
monkeypatch.setenv("API_ENABLED", "true")
169-
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
170-
160+
def test_upload_appointment_not_in_progress(dicom_file, mock_authentication, appointment_stub):
171161
appointment_stub.is_in_progress.return_value = False
172162

173163
with patch(
@@ -184,7 +174,7 @@ def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment
184174
assert response.json()["title"] == "Internal Server Error"
185175

186176

187-
def test_upload_when_api_disabled(dicom_file, mock_token_validator, monkeypatch):
177+
def test_upload_when_api_disabled(dicom_file, mock_authentication, monkeypatch):
188178
monkeypatch.setenv("API_ENABLED", "false")
189179

190180
response = client.put(
@@ -222,8 +212,28 @@ def test_upload_invalid_auth(dicom_file):
222212
}
223213

224214

215+
def test_upload_bypass_token_validation(dicom_file):
216+
with patch.object(Authentication, "bypass_authentication", return_value=True):
217+
with patch.object(
218+
DicomRecorder,
219+
"get_or_create_records",
220+
return_value=(
221+
MagicMock(study_instance_uid=generate_uid()),
222+
MagicMock(series_instance_uid=generate_uid()),
223+
MagicMock(sop_instance_uid=generate_uid(), id=1),
224+
),
225+
):
226+
response = client.put(
227+
"/dicom/abc123",
228+
FILES={"file": dicom_file},
229+
headers={"Authorization": "Bearer anytoken"},
230+
)
231+
232+
assert response.status_code == 201
233+
234+
225235
@pytest.mark.django_db
226-
def test_report_failure(mock_token_validator):
236+
def test_report_failure(mock_authentication):
227237
action = GatewayActionFactory()
228238

229239
response = client.patch(
@@ -242,7 +252,7 @@ def test_report_failure(mock_token_validator):
242252

243253

244254
@pytest.mark.django_db
245-
def test_report_failure_action_not_found(mock_token_validator):
255+
def test_report_failure_action_not_found(mock_authentication):
246256
response = client.patch(
247257
"/dicom/00000000-0000-0000-0000-000000000000/failure",
248258
json={"error": "Missing PatientID"},
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from unittest.mock import Mock, patch
2+
3+
import jwt
4+
import pytest
5+
from django.conf import settings
6+
7+
from ..authentication import Authentication
8+
9+
10+
@patch(f"{Authentication.__module__}.logger")
11+
class TestAuthentication:
12+
@pytest.fixture(autouse=True)
13+
def setup_env(self):
14+
with patch.dict(
15+
"os.environ",
16+
{
17+
"API_AUDIENCE": "test_audience",
18+
"TENANT_ID": "test_tenant_id",
19+
"BYPASS_API_TOKEN_AUTH": "false",
20+
},
21+
):
22+
yield
23+
24+
@pytest.fixture
25+
def mock_jwks_signing_key(self):
26+
with patch.object(
27+
jwt.PyJWKClient,
28+
"get_signing_key_from_jwt",
29+
return_value=Mock(key="test_signing_key"),
30+
):
31+
yield
32+
33+
@patch.object(
34+
jwt.PyJWKClient, "get_signing_key_from_jwt", side_effect=jwt.PyJWKClientError
35+
)
36+
def test_with_no_matching_signing_key(self, mock_signing_key_error, mock_logger):
37+
authenticator = Authentication()
38+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
39+
mock_logger.exception.assert_called_with(
40+
"Error fetching JWKS keys from Azure AD."
41+
)
42+
43+
@patch(
44+
f"{Authentication.__module__}.jwt.decode", side_effect=jwt.ExpiredSignatureError
45+
)
46+
def test_with_expired_signature(
47+
self, mock_decode, mock_logger, mock_jwks_signing_key
48+
):
49+
authenticator = Authentication()
50+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
51+
mock_logger.exception.assert_called_with("Token is expired")
52+
53+
@patch(
54+
f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidAudienceError
55+
)
56+
def test_with_invalid_claims(self, _, mock_logger, mock_jwks_signing_key):
57+
authenticator = Authentication()
58+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
59+
mock_logger.exception.assert_called_with(
60+
"Invalid claims. Please check the audience and issuer."
61+
)
62+
63+
@patch(
64+
f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidIssuerError
65+
)
66+
def test_with_invalid_issuer(self, _, mock_logger, mock_jwks_signing_key):
67+
authenticator = Authentication()
68+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
69+
mock_logger.exception.assert_called_with(
70+
"Invalid claims. Please check the audience and issuer."
71+
)
72+
73+
@patch(f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidTokenError)
74+
def test_with_invalid_token(self, _, mock_logger, mock_jwks_signing_key):
75+
authenticator = Authentication()
76+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
77+
mock_logger.exception.assert_called_with("Token is invalid")
78+
79+
@patch(f"{Authentication.__module__}.jwt.decode", side_effect=Exception)
80+
def test_with_unexpected_exception(self, _, mock_logger, mock_jwks_signing_key):
81+
authenticator = Authentication()
82+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
83+
mock_logger.exception.assert_called_with(
84+
"Unable to parse authentication token."
85+
)
86+
87+
@patch(
88+
f"{Authentication.__module__}.jwt.decode", return_value={"sub": "1234567890"}
89+
)
90+
def test_with_valid_token(self, _, mock_logger, mock_jwks_signing_key):
91+
authenticator = Authentication()
92+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) == {
93+
"sub": "1234567890"
94+
}
95+
mock_logger.exception.assert_not_called()
96+
97+
def test_authentication_bypass_enabled(self, mock_logger, mock_jwks_signing_key):
98+
with patch.object(settings, "BYPASS_API_TOKEN_AUTH", return_value=True):
99+
authenticator = Authentication()
100+
assert authenticator(
101+
Mock(headers={"Authorization": "Bearer anytoken"})
102+
) == {"sub": "bypass_user"}

0 commit comments

Comments
 (0)