diff --git a/manage_breast_screening/config/settings/base.py b/manage_breast_screening/config/settings/base.py index c0dfe6b81..3d1defd76 100644 --- a/manage_breast_screening/config/settings/base.py +++ b/manage_breast_screening/config/settings/base.py @@ -382,3 +382,5 @@ def list_env(key): "style-src": (CSP_SELF,), } } + +BYPASS_API_TOKEN_AUTH = boolean_env("BYPASS_API_TOKEN_AUTH", default=False) diff --git a/manage_breast_screening/dicom/api.py b/manage_breast_screening/dicom/api.py index 1246283a0..bbf663e42 100644 --- a/manage_breast_screening/dicom/api.py +++ b/manage_breast_screening/dicom/api.py @@ -10,12 +10,15 @@ from manage_breast_screening.core.api_schema import ErrorResponse, StatusResponse from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus +from .authentication import Authentication from .dicom_recorder import DicomRecorder -router = Router() +router = Router(auth=Authentication()) logger = logging.getLogger(__name__) +MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB + class SuccessResponse(ninja.Schema): study_instance_uid: str @@ -41,8 +44,7 @@ def upload(request, source_message_id: str, file: File[UploadedFile]): """ Accepts PUT with a single DICOM file in form field 'file' """ - max_size = 100 * 1024 * 1024 - if file.size > max_size: + if file.size > MAX_FILE_SIZE: return 400, { "title": "File too large", "status": 400, diff --git a/manage_breast_screening/dicom/authentication.py b/manage_breast_screening/dicom/authentication.py new file mode 100644 index 000000000..c23f94447 --- /dev/null +++ b/manage_breast_screening/dicom/authentication.py @@ -0,0 +1,96 @@ +import logging +import os +from functools import cached_property + +import jwt +from django.conf import settings +from ninja.security import HttpBearer + +logger = logging.getLogger(__name__) + +ALLOWED_ALGORITHMS = ["RS256"] +JWT_SET_CACHE_TTL_SECONDS = 3600 + + +class Authentication(HttpBearer): + def authenticate(self, _, token) -> dict | None: + """ + Authenticates the incoming request by validating the JWT token. + """ + if self.bypass_authentication: + logger.warning("Authentication bypass is enabled.") + return {"sub": "bypass_user"} + + return self._decode(token) + + def _decode(self, token: str) -> dict | None: + """ + Decodes and validates the JWT token using the provided RSA key. + Checks the signature, audience, and issuer claims to ensure the token is valid and intended for this API. + """ + try: + signing_key = self.jwks_client.get_signing_key_from_jwt(token) + payload = jwt.decode( + token, + signing_key.key, + algorithms=ALLOWED_ALGORITHMS, + audience=self.audience, + issuer=self.issuers, + ) + return payload + except jwt.PyJWKClientError: + logger.exception("Error fetching JWKS keys from Azure AD.") + except jwt.ExpiredSignatureError: + logger.exception("Token is expired") + except (jwt.InvalidAudienceError, jwt.InvalidIssuerError): + logger.exception("Invalid claims. Please check the audience and issuer.") + except jwt.InvalidTokenError: + logger.exception("Token is invalid") + except Exception: + logger.exception("Unable to parse authentication token.") + + @cached_property + def jwks_client(self) -> jwt.PyJWKClient: + """ + Creates a PyJWKClient instance for fetching and caching the JWKS keys from Azure AD. + Caching is enabled to improve performance and reduce the number of network requests to Azure AD. + The cache will be refreshed after the specified TTL expires. + """ + return jwt.PyJWKClient( + self.discovery_keys_url, + cache_jwk_set=True, + cache_keys=True, + lifespan=JWT_SET_CACHE_TTL_SECONDS, + ) + + @cached_property + def discovery_keys_url(self) -> str: + return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys" + + @property + def audience(self) -> str | None: + """ + The expected audience claim in the JWT token. This should match the API_AUDIENCE environment variable. + """ + return os.getenv("API_AUDIENCE") + + @property + def tenant_id(self) -> str | None: + """ + The Azure AD tenant ID. This should be set as the TENANT_ID environment variable. + """ + return os.getenv("TENANT_ID", "") + + @cached_property + def issuers(self) -> list: + """ + The expected issuer claim(s) in the JWT token. This should match the tenant ID and the Azure AD endpoints. + """ + return [ + f"https://sts.windows.net/{self.tenant_id}/", + f"https://login.microsoftonline.com/{self.tenant_id}/v2.0/", + ] + + @property + def bypass_authentication(self) -> bool: + return getattr(settings, "BYPASS_API_TOKEN_AUTH", False) diff --git a/manage_breast_screening/dicom/tests/test_api.py b/manage_breast_screening/dicom/tests/test_api.py index 3d8827685..8a7589f09 100644 --- a/manage_breast_screening/dicom/tests/test_api.py +++ b/manage_breast_screening/dicom/tests/test_api.py @@ -6,9 +6,9 @@ import pytest from django.core.files.uploadedfile import SimpleUploadedFile from ninja.testing import TestClient +from pydicom.uid import generate_uid from manage_breast_screening.core.api import api -from manage_breast_screening.dicom.models import Study from manage_breast_screening.gateway.models import GatewayActionStatus from manage_breast_screening.gateway.tests.factories import GatewayActionFactory from manage_breast_screening.participants.models.appointment import ( @@ -16,11 +16,22 @@ ) from manage_breast_screening.participants.tests.factories import AppointmentFactory +from ..authentication import Authentication +from ..dicom_recorder import DicomRecorder +from ..models import Study + os.environ["NINJA_SKIP_REGISTRY"] = "yes" client = TestClient(api) +@pytest.fixture(autouse=True) +def setup(monkeypatch): + monkeypatch.setenv("API_ENABLED", "true") + monkeypatch.setenv("API_AUDIENCE", "test_audience") + monkeypatch.setenv("TENANT_ID", "test_tenant_id") + + @pytest.fixture def dicom_file(dataset) -> bytes: with io.BytesIO() as buffer: @@ -38,11 +49,14 @@ def appointment_stub(): ) -@pytest.mark.django_db -def test_upload_success(dataset, dicom_file, monkeypatch): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") +@pytest.fixture +def mock_authentication(): + with patch.object(Authentication, "authenticate", return_value={"sub": "testuser"}): + yield + +@pytest.mark.django_db +def test_upload_success(dataset, dicom_file, mock_authentication, appointment_stub): appointment = AppointmentFactory(current_status=AppointmentStatusNames.IN_PROGRESS) with patch( @@ -52,7 +66,7 @@ def test_upload_success(dataset, dicom_file, monkeypatch): response = client.put( f"/dicom/{appointment.pk}", FILES={"file": dicom_file}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 201 @@ -65,23 +79,17 @@ def test_upload_success(dataset, dicom_file, monkeypatch): assert study.source_message_id == str(appointment.pk) -def test_upload_no_file(monkeypatch): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") - +def test_upload_no_file(mock_authentication): response = client.put( "/dicom/abc123", FILES={"file": None}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 422 -def test_upload_invalid_file(monkeypatch, appointment_stub): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") - +def test_upload_invalid_file(mock_authentication, appointment_stub): invalid_file = SimpleUploadedFile( "invalid.dcm", b"not a dicom file", content_type="application/dicom" ) @@ -93,7 +101,7 @@ def test_upload_invalid_file(monkeypatch, appointment_stub): response = client.put( "/dicom/abc123", FILES={"file": invalid_file}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 400 @@ -102,16 +110,13 @@ def test_upload_invalid_file(monkeypatch, appointment_stub): assert response.json()["detail"] == "The uploaded file is not a valid DICOM file." -def test_upload_file_thats_too_large(monkeypatch): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") - +def test_upload_file_thats_too_large(mock_authentication): invalid_file = MagicMock(spec=SimpleUploadedFile, size=101 * 1024 * 1024) response = client.put( "/dicom/abc123", FILES={"file": invalid_file}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 400 @@ -120,10 +125,7 @@ def test_upload_file_thats_too_large(monkeypatch): assert response.json()["detail"] == "The file cannot be larger than 100MB" -def test_upload_missing_uids(dataset, monkeypatch, appointment_stub): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") - +def test_upload_missing_uids(dataset, mock_authentication, appointment_stub): del dataset.StudyInstanceUID del dataset.SeriesInstanceUID del dataset.SOPInstanceUID @@ -142,7 +144,7 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub): response = client.put( "/dicom/abc123", FILES={"file": dicom_file}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 400 @@ -154,10 +156,9 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub): ) -def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment_stub): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") - +def test_upload_appointment_not_in_progress( + dicom_file, mock_authentication, appointment_stub +): appointment_stub.is_in_progress.return_value = False with patch( @@ -167,29 +168,26 @@ def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment response = client.put( "/dicom/abc123", FILES={"file": dicom_file}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 500 assert response.json()["title"] == "Internal Server Error" -@pytest.mark.django_db -def test_upload_when_api_disabled(dicom_file, monkeypatch): +def test_upload_when_api_disabled(dicom_file, mock_authentication, monkeypatch): monkeypatch.setenv("API_ENABLED", "false") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") response = client.put( "/dicom/abc123", FILES={"file": dicom_file}, - headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")}, + headers={"Authorization": "Bearer testtoken"}, ) assert response.status_code == 403 assert response.json()["status"] == "API is not available" -@pytest.mark.django_db def test_upload_no_auth(dicom_file): response = client.put( "/dicom/abc123", @@ -202,11 +200,41 @@ def test_upload_no_auth(dicom_file): } -@pytest.mark.django_db -def test_report_failure(monkeypatch): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") +def test_upload_invalid_auth(dicom_file): + response = client.put( + "/dicom/abc123", + FILES={"file": dicom_file}, + headers={"Authorization": "Bearer invalidtoken"}, + ) + assert response.status_code == 401 + assert response.json() == { + "detail": "Unauthorized", + } + + +def test_upload_bypass_token_validation(dicom_file): + with patch.object(Authentication, "bypass_authentication", return_value=True): + with patch.object( + DicomRecorder, + "get_or_create_records", + return_value=( + MagicMock(study_instance_uid=generate_uid()), + MagicMock(series_instance_uid=generate_uid()), + MagicMock(sop_instance_uid=generate_uid(), id=1), + ), + ): + response = client.put( + "/dicom/abc123", + FILES={"file": dicom_file}, + headers={"Authorization": "Bearer anytoken"}, + ) + + assert response.status_code == 201 + + +@pytest.mark.django_db +def test_report_failure(mock_authentication): action = GatewayActionFactory() response = client.patch( @@ -225,10 +253,7 @@ def test_report_failure(monkeypatch): @pytest.mark.django_db -def test_report_failure_action_not_found(monkeypatch): - monkeypatch.setenv("API_ENABLED", "true") - monkeypatch.setenv("API_AUTH_TOKEN", "testtoken") - +def test_report_failure_action_not_found(mock_authentication): response = client.patch( "/dicom/00000000-0000-0000-0000-000000000000/failure", json={"error": "Missing PatientID"}, diff --git a/manage_breast_screening/dicom/tests/test_authentication.py b/manage_breast_screening/dicom/tests/test_authentication.py new file mode 100644 index 000000000..d8e6ee770 --- /dev/null +++ b/manage_breast_screening/dicom/tests/test_authentication.py @@ -0,0 +1,102 @@ +from unittest.mock import Mock, patch + +import jwt +import pytest +from django.conf import settings + +from ..authentication import Authentication + + +@patch(f"{Authentication.__module__}.logger") +class TestAuthentication: + @pytest.fixture(autouse=True) + def setup_env(self): + with patch.dict( + "os.environ", + { + "API_AUDIENCE": "test_audience", + "TENANT_ID": "test_tenant_id", + "BYPASS_API_TOKEN_AUTH": "false", + }, + ): + yield + + @pytest.fixture + def mock_jwks_signing_key(self): + with patch.object( + jwt.PyJWKClient, + "get_signing_key_from_jwt", + return_value=Mock(key="test_signing_key"), + ): + yield + + @patch.object( + jwt.PyJWKClient, "get_signing_key_from_jwt", side_effect=jwt.PyJWKClientError + ) + def test_with_no_matching_signing_key(self, mock_signing_key_error, mock_logger): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None + mock_logger.exception.assert_called_with( + "Error fetching JWKS keys from Azure AD." + ) + + @patch( + f"{Authentication.__module__}.jwt.decode", side_effect=jwt.ExpiredSignatureError + ) + def test_with_expired_signature( + self, mock_decode, mock_logger, mock_jwks_signing_key + ): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None + mock_logger.exception.assert_called_with("Token is expired") + + @patch( + f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidAudienceError + ) + def test_with_invalid_claims(self, _, mock_logger, mock_jwks_signing_key): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None + mock_logger.exception.assert_called_with( + "Invalid claims. Please check the audience and issuer." + ) + + @patch( + f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidIssuerError + ) + def test_with_invalid_issuer(self, _, mock_logger, mock_jwks_signing_key): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None + mock_logger.exception.assert_called_with( + "Invalid claims. Please check the audience and issuer." + ) + + @patch(f"{Authentication.__module__}.jwt.decode", side_effect=jwt.InvalidTokenError) + def test_with_invalid_token(self, _, mock_logger, mock_jwks_signing_key): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None + mock_logger.exception.assert_called_with("Token is invalid") + + @patch(f"{Authentication.__module__}.jwt.decode", side_effect=Exception) + def test_with_unexpected_exception(self, _, mock_logger, mock_jwks_signing_key): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) is None + mock_logger.exception.assert_called_with( + "Unable to parse authentication token." + ) + + @patch( + f"{Authentication.__module__}.jwt.decode", return_value={"sub": "1234567890"} + ) + def test_with_valid_token(self, _, mock_logger, mock_jwks_signing_key): + authenticator = Authentication() + assert authenticator(Mock(headers={"Authorization": "Bearer abc123"})) == { + "sub": "1234567890" + } + mock_logger.exception.assert_not_called() + + def test_authentication_bypass_enabled(self, mock_logger, mock_jwks_signing_key): + with patch.object(settings, "BYPASS_API_TOKEN_AUTH", return_value=True): + authenticator = Authentication() + assert authenticator( + Mock(headers={"Authorization": "Bearer anytoken"}) + ) == {"sub": "bypass_user"}