Skip to content

Commit 7c46fbb

Browse files
committed
Rename TokenValidator -> Authentication
Refactor authentication to use PyJWKClient with caching options. Move decode options into properties.
1 parent cf4b480 commit 7c46fbb

6 files changed

Lines changed: 235 additions & 226 deletions

File tree

manage_breast_screening/dicom/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from manage_breast_screening.core.api_schema import ErrorResponse, StatusResponse
1111
from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus
1212

13+
from .authentication import Authentication
1314
from .dicom_recorder import DicomRecorder
14-
from .token_validator import TokenValidator
1515

16-
router = Router(auth=TokenValidator())
16+
router = Router(auth=Authentication())
1717

1818
logger = logging.getLogger(__name__)
1919

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import logging
2+
import os
3+
4+
import jwt
5+
from django.conf import settings
6+
from ninja.security import HttpBearer
7+
8+
logger = logging.getLogger(__name__)
9+
10+
ALLOWED_ALGORITHMS = ["RS256"]
11+
JWT_SET_CACHE_TTL_SECONDS = 3600
12+
13+
14+
class Authentication(HttpBearer):
15+
def authenticate(self, _, token) -> dict | None:
16+
"""
17+
Authenticates the incoming request by validating the JWT token.
18+
"""
19+
if self.bypass_authentication:
20+
logger.warning("Authentication bypass is enabled.")
21+
return {"sub": "bypass_user"}
22+
23+
return self._decode(token)
24+
25+
def _decode(self, token: str) -> dict | None:
26+
"""
27+
Decodes and validates the JWT token using the provided RSA key.
28+
Checks the signature, audience, and issuer claims to ensure the token is valid and intended for this API.
29+
"""
30+
try:
31+
signing_key = self.jwks_client.get_signing_key_from_jwt(token).key
32+
payload = jwt.decode(
33+
token,
34+
signing_key,
35+
algorithms=ALLOWED_ALGORITHMS,
36+
audience=self.audience,
37+
issuer=self.issuers,
38+
)
39+
return payload
40+
except jwt.PyJWKClientError:
41+
logger.exception("Error fetching JWKS keys from Azure AD.")
42+
except jwt.ExpiredSignatureError:
43+
logger.exception("Token is expired")
44+
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError):
45+
logger.exception("Invalid claims. Please check the audience and issuer.")
46+
except jwt.InvalidTokenError:
47+
logger.exception("Token is invalid")
48+
except Exception:
49+
logger.exception("Unable to parse authentication token.")
50+
51+
@property
52+
def jwks_client(self) -> jwt.PyJWKClient:
53+
"""
54+
Creates a PyJWKClient instance for fetching and caching the JWKS keys from Azure AD.
55+
Caching is enabled to improve performance and reduce the number of network requests to Azure AD.
56+
The cache will be refreshed after the specified TTL expires.
57+
"""
58+
return jwt.PyJWKClient(
59+
self.discovery_keys_url,
60+
cache_jwk_set=True,
61+
cache_keys=True,
62+
lifespan=JWT_SET_CACHE_TTL_SECONDS,
63+
)
64+
65+
@property
66+
def discovery_keys_url(self) -> str:
67+
return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"
68+
69+
@property
70+
def audience(self) -> str | None:
71+
"""
72+
The expected audience claim in the JWT token. This should match the API_AUDIENCE environment variable.
73+
"""
74+
return os.getenv("API_AUDIENCE")
75+
76+
@property
77+
def tenant_id(self) -> str | None:
78+
"""
79+
The Azure AD tenant ID. This should be set as the TENANT_ID environment variable.
80+
"""
81+
return os.getenv("TENANT_ID", "")
82+
83+
@property
84+
def issuers(self) -> list:
85+
"""
86+
The expected issuer claim(s) in the JWT token. This should match the tenant ID and the Azure AD endpoints.
87+
"""
88+
return [
89+
f"https://sts.windows.net/{self.tenant_id}/",
90+
f"https://login.microsoftonline.com/{self.tenant_id}/v2.0/",
91+
]
92+
93+
@property
94+
def bypass_authentication(self) -> bool:
95+
return getattr(settings, "BYPASS_API_TOKEN_AUTH", False)

manage_breast_screening/dicom/tests/test_api.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@
66
import pytest
77
from django.core.files.uploadedfile import SimpleUploadedFile
88
from ninja.testing import TestClient
9+
from pydicom.uid import generate_uid
910

1011
from manage_breast_screening.core.api import api
1112
from manage_breast_screening.gateway.models import GatewayActionStatus
1213
from manage_breast_screening.gateway.tests.factories import GatewayActionFactory
1314

15+
from ..authentication import Authentication
1416
from ..dicom_recorder import DicomRecorder
1517
from ..models import Study
16-
from ..token_validator import TokenValidator
1718

1819
os.environ["NINJA_SKIP_REGISTRY"] = "yes"
1920

2021
client = TestClient(api)
2122

2223

2324
@pytest.fixture(autouse=True)
24-
def enable_api(monkeypatch):
25+
def setup(monkeypatch):
2526
monkeypatch.setenv("API_ENABLED", "true")
27+
monkeypatch.setenv("API_AUDIENCE", "test_audience")
28+
monkeypatch.setenv("TENANT_ID", "test_tenant_id")
2629

2730

2831
@pytest.fixture
@@ -36,13 +39,13 @@ def dicom_file(dataset) -> bytes:
3639

3740

3841
@pytest.fixture
39-
def mock_token_validator():
40-
with patch.object(TokenValidator, "authenticate", return_value={"sub": "testuser"}):
42+
def mock_authentication():
43+
with patch.object(Authentication, "authenticate", return_value={"sub": "testuser"}):
4144
yield
4245

4346

4447
@pytest.mark.django_db
45-
def test_upload_success(dataset, dicom_file, mock_token_validator):
48+
def test_upload_success(dataset, dicom_file, mock_authentication):
4649
with patch.object(DicomRecorder, "appointment_in_progress", return_value=True):
4750
response = client.put(
4851
"/dicom/abc123",
@@ -60,7 +63,7 @@ def test_upload_success(dataset, dicom_file, mock_token_validator):
6063
assert study.source_message_id == "abc123"
6164

6265

63-
def test_upload_no_file(mock_token_validator):
66+
def test_upload_no_file(mock_authentication):
6467
response = client.put(
6568
"/dicom/abc123",
6669
FILES={"file": None},
@@ -70,7 +73,7 @@ def test_upload_no_file(mock_token_validator):
7073
assert response.status_code == 422
7174

7275

73-
def test_upload_invalid_file(mock_token_validator):
76+
def test_upload_invalid_file(mock_authentication):
7477
invalid_file = SimpleUploadedFile(
7578
"invalid.dcm", b"not a dicom file", content_type="application/dicom"
7679
)
@@ -88,7 +91,7 @@ def test_upload_invalid_file(mock_token_validator):
8891
assert response.json()["detail"] == "The uploaded file is not a valid DICOM file."
8992

9093

91-
def test_upload_file_thats_too_large(mock_token_validator):
94+
def test_upload_file_thats_too_large(mock_authentication):
9295
invalid_file = MagicMock(spec=SimpleUploadedFile, size=101 * 1024 * 1024)
9396

9497
response = client.put(
@@ -103,7 +106,7 @@ def test_upload_file_thats_too_large(mock_token_validator):
103106
assert response.json()["detail"] == "The file cannot be larger than 100MB"
104107

105108

106-
def test_upload_missing_uids(dataset, mock_token_validator):
109+
def test_upload_missing_uids(dataset, mock_authentication):
107110
del dataset.StudyInstanceUID
108111
del dataset.SeriesInstanceUID
109112
del dataset.SOPInstanceUID
@@ -131,7 +134,7 @@ def test_upload_missing_uids(dataset, mock_token_validator):
131134
)
132135

133136

134-
def test_upload_appointment_not_in_progress(dicom_file, mock_token_validator):
137+
def test_upload_appointment_not_in_progress(dicom_file, mock_authentication):
135138
with patch.object(DicomRecorder, "appointment_in_progress", return_value=False):
136139
response = client.put(
137140
"/dicom/abc123",
@@ -143,7 +146,7 @@ def test_upload_appointment_not_in_progress(dicom_file, mock_token_validator):
143146
assert response.json()["title"] == "Internal Server Error"
144147

145148

146-
def test_upload_when_api_disabled(dicom_file, mock_token_validator, monkeypatch):
149+
def test_upload_when_api_disabled(dicom_file, mock_authentication, monkeypatch):
147150
monkeypatch.setenv("API_ENABLED", "false")
148151

149152
response = client.put(
@@ -181,8 +184,28 @@ def test_upload_invalid_auth(dicom_file):
181184
}
182185

183186

187+
def test_upload_bypass_token_validation(dicom_file):
188+
with patch.object(Authentication, "bypass_authentication", return_value=True):
189+
with patch.object(
190+
DicomRecorder,
191+
"get_or_create_records",
192+
return_value=(
193+
MagicMock(study_instance_uid=generate_uid()),
194+
MagicMock(series_instance_uid=generate_uid()),
195+
MagicMock(sop_instance_uid=generate_uid(), id=1),
196+
),
197+
):
198+
response = client.put(
199+
"/dicom/abc123",
200+
FILES={"file": dicom_file},
201+
headers={"Authorization": "Bearer anytoken"},
202+
)
203+
204+
assert response.status_code == 201
205+
206+
184207
@pytest.mark.django_db
185-
def test_report_failure(mock_token_validator):
208+
def test_report_failure(mock_authentication):
186209
action = GatewayActionFactory()
187210

188211
response = client.patch(
@@ -201,7 +224,7 @@ def test_report_failure(mock_token_validator):
201224

202225

203226
@pytest.mark.django_db
204-
def test_report_failure_action_not_found(mock_token_validator):
227+
def test_report_failure_action_not_found(mock_authentication):
205228
response = client.patch(
206229
"/dicom/00000000-0000-0000-0000-000000000000/failure",
207230
json={"error": "Missing PatientID"},
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from unittest.mock import Mock, patch
2+
3+
import jwt
4+
import pytest
5+
from django.conf import settings
6+
7+
from ..authentication import Authentication
8+
9+
10+
@patch(f"{Authentication.__module__}.logger")
11+
class TestAuthentication:
12+
@pytest.fixture(autouse=True)
13+
def setup_env(self):
14+
with patch.dict(
15+
"os.environ",
16+
{
17+
"API_AUDIENCE": "test_audience",
18+
"TENANT_ID": "test_tenant_id",
19+
"BYPASS_API_TOKEN_AUTH": "false",
20+
},
21+
):
22+
yield
23+
24+
@pytest.fixture
25+
def mock_jwks_signing_key(self):
26+
with patch.object(
27+
jwt.PyJWKClient,
28+
"get_signing_key_from_jwt",
29+
return_value=Mock(key="test_signing_key"),
30+
):
31+
yield
32+
33+
@patch.object(
34+
jwt.PyJWKClient, "get_signing_key_from_jwt", side_effect=jwt.PyJWKClientError
35+
)
36+
def test_with_no_matching_signing_key(self, mock_signing_key_error, mock_logger):
37+
authenticator = Authentication()
38+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
39+
mock_logger.exception.assert_called_with(
40+
"Error fetching JWKS keys from Azure AD."
41+
)
42+
43+
@patch(
44+
f"{Authentication.__module__}.jwt.decode", side_effect=jwt.ExpiredSignatureError
45+
)
46+
def test_with_expired_signature(
47+
self, mock_decode, mock_logger, mock_jwks_signing_key
48+
):
49+
authenticator = Authentication()
50+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
51+
mock_logger.exception.assert_called_with("Token is expired")
52+
53+
@patch(
54+
f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidAudienceError
55+
)
56+
def test_with_invalid_claims(self, _, mock_logger, mock_jwks_signing_key):
57+
authenticator = Authentication()
58+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
59+
mock_logger.exception.assert_called_with(
60+
"Invalid claims. Please check the audience and issuer."
61+
)
62+
63+
@patch(
64+
f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidIssuerError
65+
)
66+
def test_with_invalid_issuer(self, _, mock_logger, mock_jwks_signing_key):
67+
authenticator = Authentication()
68+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
69+
mock_logger.exception.assert_called_with(
70+
"Invalid claims. Please check the audience and issuer."
71+
)
72+
73+
@patch(f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidTokenError)
74+
def test_with_invalid_token(self, _, mock_logger, mock_jwks_signing_key):
75+
authenticator = Authentication()
76+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
77+
mock_logger.exception.assert_called_with("Token is invalid")
78+
79+
@patch(f"{Authentication.__module__}.jwt.decode", side_effect=Exception)
80+
def test_with_unexpected_exception(self, _, mock_logger, mock_jwks_signing_key):
81+
authenticator = Authentication()
82+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None
83+
mock_logger.exception.assert_called_with(
84+
"Unable to parse authentication token."
85+
)
86+
87+
@patch(
88+
f"{Authentication.__module__}.jwt.decode", return_value={"sub": "1234567890"}
89+
)
90+
def test_with_valid_token(self, _, mock_logger, mock_jwks_signing_key):
91+
authenticator = Authentication()
92+
assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) == {
93+
"sub": "1234567890"
94+
}
95+
mock_logger.exception.assert_not_called()
96+
97+
def test_authentication_bypass_enabled(self, mock_logger, mock_jwks_signing_key):
98+
with patch.object(settings, "BYPASS_API_TOKEN_AUTH", return_value=True):
99+
authenticator = Authentication()
100+
assert authenticator(
101+
Mock(headers={"Authorization": "Bearer anytoken"})
102+
) == {"sub": "bypass_user"}

0 commit comments

Comments
 (0)