Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions manage_breast_screening/config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,5 @@ def list_env(key):
"style-src": (CSP_SELF,),
}
}

BYPASS_API_TOKEN_AUTH = boolean_env("BYPASS_API_TOKEN_AUTH", default=False)
8 changes: 5 additions & 3 deletions manage_breast_screening/dicom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
from manage_breast_screening.core.api_schema import ErrorResponse, StatusResponse
from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus

from .authentication import Authentication
from .dicom_recorder import DicomRecorder

router = Router()
router = Router(auth=Authentication())

logger = logging.getLogger(__name__)

MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB


class SuccessResponse(ninja.Schema):
study_instance_uid: str
Expand All @@ -41,8 +44,7 @@ def upload(request, source_message_id: str, file: File[UploadedFile]):
"""
Accepts PUT with a single DICOM file in form field 'file'
"""
max_size = 100 * 1024 * 1024
if file.size > max_size:
if file.size > MAX_FILE_SIZE:
return 400, {
"title": "File too large",
"status": 400,
Expand Down
96 changes: 96 additions & 0 deletions manage_breast_screening/dicom/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import logging
import os
from functools import cached_property

import jwt
from django.conf import settings
from ninja.security import HttpBearer

logger = logging.getLogger(__name__)

ALLOWED_ALGORITHMS = ["RS256"]
JWT_SET_CACHE_TTL_SECONDS = 3600


class Authentication(HttpBearer):
def authenticate(self, _, token) -> dict | None:
"""
Authenticates the incoming request by validating the JWT token.
"""
if self.bypass_authentication:
logger.warning("Authentication bypass is enabled.")
return {"sub": "bypass_user"}

return self._decode(token)

def _decode(self, token: str) -> dict | None:
"""
Decodes and validates the JWT token using the provided RSA key.
Checks the signature, audience, and issuer claims to ensure the token is valid and intended for this API.
"""
try:
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
payload = jwt.decode(
token,
signing_key.key,
algorithms=ALLOWED_ALGORITHMS,
audience=self.audience,
issuer=self.issuers,
)
return payload
except jwt.PyJWKClientError:
logger.exception("Error fetching JWKS keys from Azure AD.")
except jwt.ExpiredSignatureError:
logger.exception("Token is expired")
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError):
logger.exception("Invalid claims. Please check the audience and issuer.")
except jwt.InvalidTokenError:
logger.exception("Token is invalid")
except Exception:
logger.exception("Unable to parse authentication token.")

@cached_property
def jwks_client(self) -> jwt.PyJWKClient:
Comment thread
carlosmartinez marked this conversation as resolved.
"""
Creates a PyJWKClient instance for fetching and caching the JWKS keys from Azure AD.
Caching is enabled to improve performance and reduce the number of network requests to Azure AD.
The cache will be refreshed after the specified TTL expires.
"""
return jwt.PyJWKClient(
self.discovery_keys_url,
cache_jwk_set=True,
cache_keys=True,
lifespan=JWT_SET_CACHE_TTL_SECONDS,
)

@cached_property
def discovery_keys_url(self) -> str:
return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"

@property
def audience(self) -> str | None:
"""
The expected audience claim in the JWT token. This should match the API_AUDIENCE environment variable.
"""
return os.getenv("API_AUDIENCE")

@property
def tenant_id(self) -> str | None:
"""
The Azure AD tenant ID. This should be set as the TENANT_ID environment variable.
"""
return os.getenv("TENANT_ID", "")

@cached_property
def issuers(self) -> list:
"""
The expected issuer claim(s) in the JWT token. This should match the tenant ID and the Azure AD endpoints.
"""
return [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe worth constantising these?

AZURE_AD_V1_ISSUER_TEMPLATE = "https://sts.windows.net/{tenant_id}/"
AZURE_AD_V2_ISSUER_TEMPLATE = "https://login.microsoftonline.com/{tenant_id}/v2.0/"

or some such

Copy link
Copy Markdown
Contributor Author

@steventux steventux Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't work without converting tenant id to a constant too. something to do with the templating.
I've cached the string/array properties which use tenant id to prevent unnecessary interpolation, this has the benefit of docstrings in the body of the property definition.

f"https://sts.windows.net/{self.tenant_id}/",
f"https://login.microsoftonline.com/{self.tenant_id}/v2.0/",
]

@property
def bypass_authentication(self) -> bool:
return getattr(settings, "BYPASS_API_TOKEN_AUTH", False)
113 changes: 69 additions & 44 deletions manage_breast_screening/dicom/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,32 @@
import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
from ninja.testing import TestClient
from pydicom.uid import generate_uid

from manage_breast_screening.core.api import api
from manage_breast_screening.dicom.models import Study
from manage_breast_screening.gateway.models import GatewayActionStatus
from manage_breast_screening.gateway.tests.factories import GatewayActionFactory
from manage_breast_screening.participants.models.appointment import (
AppointmentStatusNames,
)
from manage_breast_screening.participants.tests.factories import AppointmentFactory

from ..authentication import Authentication
from ..dicom_recorder import DicomRecorder
from ..models import Study

os.environ["NINJA_SKIP_REGISTRY"] = "yes"

client = TestClient(api)


@pytest.fixture(autouse=True)
def setup(monkeypatch):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUDIENCE", "test_audience")
monkeypatch.setenv("TENANT_ID", "test_tenant_id")


@pytest.fixture
def dicom_file(dataset) -> bytes:
with io.BytesIO() as buffer:
Expand All @@ -38,11 +49,14 @@ def appointment_stub():
)


@pytest.mark.django_db
def test_upload_success(dataset, dicom_file, monkeypatch):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
@pytest.fixture
def mock_authentication():
with patch.object(Authentication, "authenticate", return_value={"sub": "testuser"}):
yield


@pytest.mark.django_db
def test_upload_success(dataset, dicom_file, mock_authentication, appointment_stub):
appointment = AppointmentFactory(current_status=AppointmentStatusNames.IN_PROGRESS)

with patch(
Expand All @@ -52,7 +66,7 @@ def test_upload_success(dataset, dicom_file, monkeypatch):
response = client.put(
f"/dicom/{appointment.pk}",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 201
Expand All @@ -65,23 +79,17 @@ def test_upload_success(dataset, dicom_file, monkeypatch):
assert study.source_message_id == str(appointment.pk)


def test_upload_no_file(monkeypatch):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

def test_upload_no_file(mock_authentication):
response = client.put(
"/dicom/abc123",
FILES={"file": None},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 422


def test_upload_invalid_file(monkeypatch, appointment_stub):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

def test_upload_invalid_file(mock_authentication, appointment_stub):
invalid_file = SimpleUploadedFile(
"invalid.dcm", b"not a dicom file", content_type="application/dicom"
)
Expand All @@ -93,7 +101,7 @@ def test_upload_invalid_file(monkeypatch, appointment_stub):
response = client.put(
"/dicom/abc123",
FILES={"file": invalid_file},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 400
Expand All @@ -102,16 +110,13 @@ def test_upload_invalid_file(monkeypatch, appointment_stub):
assert response.json()["detail"] == "The uploaded file is not a valid DICOM file."


def test_upload_file_thats_too_large(monkeypatch):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

def test_upload_file_thats_too_large(mock_authentication):
invalid_file = MagicMock(spec=SimpleUploadedFile, size=101 * 1024 * 1024)

response = client.put(
"/dicom/abc123",
FILES={"file": invalid_file},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 400
Expand All @@ -120,10 +125,7 @@ def test_upload_file_thats_too_large(monkeypatch):
assert response.json()["detail"] == "The file cannot be larger than 100MB"


def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

def test_upload_missing_uids(dataset, mock_authentication, appointment_stub):
del dataset.StudyInstanceUID
del dataset.SeriesInstanceUID
del dataset.SOPInstanceUID
Expand All @@ -142,7 +144,7 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 400
Expand All @@ -154,10 +156,9 @@ def test_upload_missing_uids(dataset, monkeypatch, appointment_stub):
)


def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment_stub):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

def test_upload_appointment_not_in_progress(
dicom_file, mock_authentication, appointment_stub
):
appointment_stub.is_in_progress.return_value = False

with patch(
Expand All @@ -167,29 +168,26 @@ def test_upload_appointment_not_in_progress(dicom_file, monkeypatch, appointment
response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 500
assert response.json()["title"] == "Internal Server Error"


@pytest.mark.django_db
def test_upload_when_api_disabled(dicom_file, monkeypatch):
def test_upload_when_api_disabled(dicom_file, mock_authentication, monkeypatch):
monkeypatch.setenv("API_ENABLED", "false")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer " + os.getenv("API_AUTH_TOKEN", "")},
headers={"Authorization": "Bearer testtoken"},
)

assert response.status_code == 403
assert response.json()["status"] == "API is not available"


@pytest.mark.django_db
def test_upload_no_auth(dicom_file):
response = client.put(
"/dicom/abc123",
Expand All @@ -202,11 +200,41 @@ def test_upload_no_auth(dicom_file):
}


@pytest.mark.django_db
def test_report_failure(monkeypatch):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")
def test_upload_invalid_auth(dicom_file):
response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer invalidtoken"},
)

assert response.status_code == 401
assert response.json() == {
"detail": "Unauthorized",
}


def test_upload_bypass_token_validation(dicom_file):
with patch.object(Authentication, "bypass_authentication", return_value=True):
with patch.object(
DicomRecorder,
"get_or_create_records",
return_value=(
MagicMock(study_instance_uid=generate_uid()),
MagicMock(series_instance_uid=generate_uid()),
MagicMock(sop_instance_uid=generate_uid(), id=1),
),
):
response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer anytoken"},
)

assert response.status_code == 201


@pytest.mark.django_db
def test_report_failure(mock_authentication):
action = GatewayActionFactory()

response = client.patch(
Expand All @@ -225,10 +253,7 @@ def test_report_failure(monkeypatch):


@pytest.mark.django_db
def test_report_failure_action_not_found(monkeypatch):
monkeypatch.setenv("API_ENABLED", "true")
monkeypatch.setenv("API_AUTH_TOKEN", "testtoken")

def test_report_failure_action_not_found(mock_authentication):
response = client.patch(
"/dicom/00000000-0000-0000-0000-000000000000/failure",
json={"error": "Missing PatientID"},
Expand Down
Loading
Loading