Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions manage_breast_screening/config/.env.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ API_AUTH_TOKEN=changeme
# Django Ninja API
API_ENABLED=true
API_DOCS_ENABLED=true
BYPASS_API_AUTHORISATION=false
BYPASS_API_AUTHENTICATION=false

# Automatic loading of PACS images from gateway
GATEWAY_IMAGES_ENABLED=False
Expand Down
3 changes: 2 additions & 1 deletion manage_breast_screening/config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,4 +383,5 @@ def list_env(key):
}
}

BYPASS_API_TOKEN_AUTH = boolean_env("BYPASS_API_TOKEN_AUTH", default=False)
BYPASS_API_AUTHENTICATION = boolean_env("BYPASS_API_AUTHENTICATION", default=False)
BYPASS_API_AUTHORISATION = boolean_env("BYPASS_API_AUTHORISATION", default=False)
14 changes: 8 additions & 6 deletions manage_breast_screening/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@

from manage_breast_screening.dicom.api import router as dicom_router

from .api_schema import StatusResponse
from .api_schema import ErrorResponse, StatusResponse


def check_availability():
def decorator(func):
@wraps(func)
def wrapper(request, *args, **kwargs):
if not os.getenv("API_ENABLED", "true").lower() == "true":
return 403, {"status": "API is not available"}
if os.getenv("API_ENABLED", "true").lower() != "true":
return 403, {
"title": "Forbidden",
"status": 403,
"detail": "API is not available",
}
return func(request, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -58,8 +62,6 @@ def authenticate(self, request, token):
dicom_router.add_decorator(check_availability())


@api.get(
"/status", response={200: StatusResponse, 403: StatusResponse}, tags=["Status"]
)
@api.get("/status", response={200: StatusResponse, 403: ErrorResponse}, tags=["Status"])
def status(request):
return 200, {"status": "API is available"}
5 changes: 4 additions & 1 deletion manage_breast_screening/core/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def test_status_endpoint_api_disabled(monkeypatch):
response = client.get("/status", headers={"Authorization": "Bearer testtoken"})

assert response.status_code == 403
assert response.json() == {"status": "API is not available"}
json = response.json()
assert json["title"] == "Forbidden"
assert json["status"] == 403
assert json["detail"] == "API is not available"


def test_status_wrong_auth(monkeypatch):
Expand Down
19 changes: 18 additions & 1 deletion manage_breast_screening/dicom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus

from .authentication import Authentication
from .authorisation import Authorisation
from .dicom_recorder import DicomRecorder

router = Router(auth=Authentication())
Expand All @@ -36,7 +37,7 @@ class FailurePayload(ninja.Schema):
response={
201: SuccessResponse,
400: ErrorResponse,
403: StatusResponse,
403: ErrorResponse,
500: ErrorResponse,
},
)
Expand All @@ -58,6 +59,22 @@ def upload(request, source_message_id: str, file: File[UploadedFile]):
"detail": "A DICOM file must be uploaded in the 'file' form field.",
}

oid = request.auth.get("oid")

if oid is None:
return 403, {
"title": "Forbidden",
"status": 403,
"detail": "Authentication failed: OID claim is missing.",
}

if Authorisation.authorise(source_message_id, oid) is False:
return 403, {
"title": "Forbidden",
"status": 403,
"detail": "You do not have permission to upload for this message ID.",
}

try:
study, series, image = DicomRecorder.get_or_create_records(
source_message_id, file
Expand Down
9 changes: 5 additions & 4 deletions manage_breast_screening/dicom/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@


class Authentication(HttpBearer):
def authenticate(self, _, token) -> dict | None:
def authenticate(self, request, token) -> dict | None:
"""
Authenticates the incoming request by validating the JWT token.
"""
if self.bypass_authentication:
logger.warning("Authentication bypass is enabled.")
return {"sub": "bypass_user"}
return {"oid": "bypass_object_id", "sub": "bypass_user"}

return self._decode(token)
request.auth = self._decode(token)
return request.auth

def _decode(self, token: str) -> dict | None:
"""
Expand Down Expand Up @@ -93,4 +94,4 @@ def issuers(self) -> list:

@property
def bypass_authentication(self) -> bool:
return getattr(settings, "BYPASS_API_TOKEN_AUTH", False)
return getattr(settings, "BYPASS_API_AUTHENTICATION", False)
34 changes: 34 additions & 0 deletions manage_breast_screening/dicom/authorisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from datetime import date

from django.conf import settings

from manage_breast_screening.gateway.models import GatewayAction, GatewayActionStatus

logger = logging.getLogger(__name__)


class Authorisation:
@staticmethod
def authorise(source_message_id: str, oid: str) -> bool:
"""
Check for the existence of a GatewayAction with the given source_message_id and oid, created today.
This provides a link between the source_message_id we send to the gateway in the appointment workflow
and the oid associated with the system assigned managed identity of the gateway stored in the Gateway model.
"""
if __class__.bypass_authorisation():
return True

return GatewayAction.objects.filter(
id=source_message_id,
gateway__oid=oid,
created_at__date=date.today(),
status__in=[
GatewayActionStatus.SENT,
GatewayActionStatus.CONFIRMED,
],
).exists()

@staticmethod
def bypass_authorisation() -> bool:
return getattr(settings, "BYPASS_API_AUTHORISATION", False)
91 changes: 64 additions & 27 deletions manage_breast_screening/dicom/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@

from manage_breast_screening.core.api import api
from manage_breast_screening.gateway.models import GatewayActionStatus
from manage_breast_screening.gateway.tests.factories import GatewayActionFactory
from manage_breast_screening.gateway.tests.factories import (
GatewayActionFactory,
GatewayFactory,
)
from manage_breast_screening.participants.models.appointment import (
AppointmentStatusNames,
)
from manage_breast_screening.participants.tests.factories import AppointmentFactory

from ..authentication import Authentication
from ..authorisation import Authorisation
from ..dicom_recorder import DicomRecorder
from ..models import Study

Expand Down Expand Up @@ -50,33 +54,61 @@ def appointment_stub():


@pytest.fixture
def mock_authentication():
with patch.object(Authentication, "authenticate", return_value={"sub": "testuser"}):
def source_message_id():
return "00000000-0000-0000-0000-000000000009"


@pytest.fixture
def gateway_oid():
return "00000000-0000-0000-0000-000000000001"


@pytest.fixture
def gateway_action(source_message_id, gateway_oid):
return GatewayActionFactory(
id=source_message_id,
status=GatewayActionStatus.SENT,
gateway=GatewayFactory(oid=gateway_oid),
)


@pytest.fixture
def mock_authentication(gateway_oid):
with patch.object(
Authentication, "_decode", return_value={"oid": gateway_oid, "sub": "testuser"}
):
yield


@pytest.fixture
def mock_authorisation(gateway_oid):
with patch.object(Authorisation, "authorise", return_value=True):
yield


@pytest.mark.django_db
def test_upload_success(dataset, dicom_file, mock_authentication, appointment_stub):
def test_upload_success(dataset, dicom_file, mock_authentication, gateway_action):
appointment = AppointmentFactory(current_status=AppointmentStatusNames.IN_PROGRESS)

with patch(
"manage_breast_screening.dicom.dicom_recorder.lookup_appointment",
return_value=appointment,
):
response = client.put(
f"/dicom/{appointment.pk}",
f"/dicom/{gateway_action.id}",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer testtoken"},
)

print(response.json())
assert response.status_code == 201
json = response.json()
study = Study.objects.last()
assert json["study_instance_uid"] == dataset.StudyInstanceUID
assert json["series_instance_uid"] == dataset.SeriesInstanceUID
assert json["sop_instance_uid"] == dataset.SOPInstanceUID
assert json["instance_id"] == str(study.images().first().id)
assert study.source_message_id == str(appointment.pk)
assert study.source_message_id == str(gateway_action.id)


def test_upload_no_file(mock_authentication):
Expand All @@ -89,7 +121,9 @@ def test_upload_no_file(mock_authentication):
assert response.status_code == 422


def test_upload_invalid_file(mock_authentication, appointment_stub):
def test_upload_invalid_file(
mock_authentication, mock_authorisation, appointment_stub, source_message_id
):
invalid_file = SimpleUploadedFile(
"invalid.dcm", b"not a dicom file", content_type="application/dicom"
)
Expand All @@ -99,7 +133,7 @@ def test_upload_invalid_file(mock_authentication, appointment_stub):
return_value=appointment_stub,
):
response = client.put(
"/dicom/abc123",
f"/dicom/{source_message_id}",
FILES={"file": invalid_file},
headers={"Authorization": "Bearer testtoken"},
)
Expand All @@ -125,7 +159,9 @@ def test_upload_file_thats_too_large(mock_authentication):
assert response.json()["detail"] == "The file cannot be larger than 100MB"


def test_upload_missing_uids(dataset, mock_authentication, appointment_stub):
def test_upload_missing_uids(
dataset, mock_authentication, mock_authorisation, appointment_stub
):
del dataset.StudyInstanceUID
del dataset.SeriesInstanceUID
del dataset.SOPInstanceUID
Expand Down Expand Up @@ -157,7 +193,7 @@ def test_upload_missing_uids(dataset, mock_authentication, appointment_stub):


def test_upload_appointment_not_in_progress(
dicom_file, mock_authentication, appointment_stub
dicom_file, mock_authentication, mock_authorisation, appointment_stub
):
appointment_stub.is_in_progress.return_value = False

Expand Down Expand Up @@ -185,7 +221,7 @@ def test_upload_when_api_disabled(dicom_file, mock_authentication, monkeypatch):
)

assert response.status_code == 403
assert response.json()["status"] == "API is not available"
assert response.json()["detail"] == "API is not available"


def test_upload_no_auth(dicom_file):
Expand Down Expand Up @@ -213,22 +249,23 @@ def test_upload_invalid_auth(dicom_file):
}


def test_upload_bypass_token_validation(dicom_file):
with patch.object(Authentication, "bypass_authentication", return_value=True):
with patch.object(
DicomRecorder,
"get_or_create_records",
return_value=(
MagicMock(study_instance_uid=generate_uid()),
MagicMock(series_instance_uid=generate_uid()),
MagicMock(sop_instance_uid=generate_uid(), id=1),
),
):
response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer anytoken"},
)
@patch.object(Authentication, "bypass_authentication", return_value=True)
@patch.object(Authorisation, "bypass_authorisation", return_value=True)
def test_upload_bypass_auth(_y, _x, dicom_file):
with patch.object(
DicomRecorder,
"get_or_create_records",
return_value=(
MagicMock(study_instance_uid=generate_uid()),
MagicMock(series_instance_uid=generate_uid()),
MagicMock(sop_instance_uid=generate_uid(), id=1),
),
):
response = client.put(
"/dicom/abc123",
FILES={"file": dicom_file},
headers={"Authorization": "Bearer anytoken"},
)

assert response.status_code == 201

Expand Down
14 changes: 12 additions & 2 deletions manage_breast_screening/dicom/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,19 @@ def test_with_valid_token(self, _, mock_logger, mock_jwks_signing_key):
}
mock_logger.exception.assert_not_called()

def test_request_auth_object_is_set(self, mock_logger, mock_jwks_signing_key):
with patch(
f"{Authentication.__module__}.jwt.decode",
return_value={"oid": "test_oid", "sub": "test_user"},
):
authenticator = Authentication()
request = Mock(headers={"Authorization": "Bearer abc123"})
assert authenticator(request) == {"oid": "test_oid", "sub": "test_user"}
assert request.auth == {"oid": "test_oid", "sub": "test_user"}

def test_authentication_bypass_enabled(self, mock_logger, mock_jwks_signing_key):
with patch.object(settings, "BYPASS_API_TOKEN_AUTH", return_value=True):
with patch.object(settings, "BYPASS_API_AUTHENTICATION", return_value=True):
authenticator = Authentication()
assert authenticator(
Mock(headers={"Authorization": "Bearer anytoken"})
) == {"sub": "bypass_user"}
) == {"oid": "bypass_object_id", "sub": "bypass_user"}
Loading
Loading