Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion manage_breast_screening/auth/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import time
from unittest.mock import ANY, Mock

import pytest
from authlib.jose import JsonWebKey, jwt
from django.conf import settings
from django.contrib.auth import get_user_model
from django.http import HttpResponse
from django.test import override_settings
from django.test import Client, override_settings
from django.urls import reverse
from pytest_django.asserts import assertInHTML

Expand Down Expand Up @@ -304,3 +306,105 @@ def test_returns_500_on_error(self, client, monkeypatch):

assert response.status_code == 500
assert response.json() == {"keys": []}


@pytest.mark.django_db
class TestCis2BackChannelLogout:
@pytest.fixture
def cis2_jwk(self):
return JsonWebKey.generate_key(
"RSA", 2048, is_private=True, options={"kid": "test-key-1"}
)

@pytest.fixture
def mock_cis2_client(self, monkeypatch, cis2_jwk):
mock_client = Mock()
mock_client.load_server_metadata.return_value = {"issuer": "test-issuer"}
mock_client.fetch_jwk_set.return_value = {
"keys": [cis2_jwk.as_dict(is_private=False)]
}
monkeypatch.setattr(
"manage_breast_screening.auth.views.get_cis2_client",
lambda: mock_client,
)
return mock_client

def _make_logout_token(self, jwk, sub, *, overrides=None):
now = int(time.time())
payload = {
"iss": "test-issuer",
"aud": settings.CIS2_CLIENT_ID,
"iat": now,
"exp": now + 300,
"events": {"https://schemas.openid.net/event/backchannel-logout": {}},
"sub": sub,
"sid": "not-used",
"jti": "not-used",
}
if overrides:
payload.update(overrides)
token = jwt.encode(
{"alg": "RS256", "kid": jwk.kid},
payload,
jwk.as_dict(is_private=True),
)
return token.decode("utf-8")

def test_logs_out_user_for_valid_token(self, mock_cis2_client, cis2_jwk):
User = get_user_model()
user = User.objects.create_user(nhs_uid="user-123", email="user@example.com")
# Sign in on one client (representing the user's browser session)
user_client = Client()
user_client.force_login(user)
assert user.session_set.count() == 1

token = self._make_logout_token(cis2_jwk, sub=user.nhs_uid)

response = Client().post(
reverse("auth:cis2_back_channel_logout"),
data={"logout_token": token},
)

assert response.status_code == 200
assert user.session_set.count() == 0

def test_rejects_request_with_missing_logout_token(self):
response = Client().post(reverse("auth:cis2_back_channel_logout"), data={})

assert response.status_code == 400
assert b"Missing logout_token" in response.content

def test_rejects_expired_token(self, mock_cis2_client, cis2_jwk):
User = get_user_model()
user = User.objects.create_user(nhs_uid="user-123", email="user@example.com")
user_client = Client()
user_client.force_login(user)

now = int(time.time())
token = self._make_logout_token(
cis2_jwk,
sub=user.nhs_uid,
overrides={"iat": now - 300, "exp": now - 120},
)

response = Client().post(
reverse("auth:cis2_back_channel_logout"),
data={"logout_token": token},
)

assert response.status_code == 400
assert b"Invalid logout token" in response.content
assert user.session_set.count() == 1

def test_returns_ok_when_user_does_not_exist_locally(
self, mock_cis2_client, cis2_jwk
):
token = self._make_logout_token(cis2_jwk, sub="unknown-user")

response = Client().post(
reverse("auth:cis2_back_channel_logout"),
data={"logout_token": token},
)

assert response.status_code == 200
assert response.json() == {"status": "ok"}
25 changes: 24 additions & 1 deletion manage_breast_screening/auth/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from authlib.integrations.base_client.errors import MismatchingStateError, OAuthError
from authlib.jose import JsonWebKey
from django.conf import settings
from django.contrib import messages
from django.contrib.auth import authenticate, get_user_model
Expand Down Expand Up @@ -169,7 +170,7 @@ def cis2_back_channel_logout(request):
# Get the CIS2 client and prepare key loader for token verification
client = get_cis2_client()
metadata = client.load_server_metadata()
key_loader = client.create_load_key()
key_loader = _create_cis2_key_loader(client)
try:
claims = decode_logout_token(metadata["issuer"], key_loader, logout_token)
except InvalidLogoutToken:
Expand All @@ -193,6 +194,28 @@ def cis2_back_channel_logout(request):
return JsonResponse({"status": "ok"})


def _create_cis2_key_loader(client):
"""Build a key loader for verifying CIS2-signed tokens.

Force-refreshes the cached JWKS on a kid miss so newly rotated CIS2 signing keys
are picked up without a process restart.
"""

def load_key(header, _payload):
jwk_set = JsonWebKey.import_key_set(client.fetch_jwk_set())
try:
return jwk_set.find_by_kid(
header.get("kid"), use="sig", alg=header.get("alg")
)
except ValueError:
jwk_set = JsonWebKey.import_key_set(client.fetch_jwk_set(force=True))
return jwk_set.find_by_kid(
header.get("kid"), use="sig", alg=header.get("alg")
)

return load_key


def _validate_id_assurance_level(level: int | str | None) -> str | None:
if level is not None:
level = int(level)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"azure-storage-queue (>=12.13.0,<13.0.0)",
"pyyaml (>=6.0.2,<7.0.0)",
"rules (>=3.5,<4.0)",
"authlib>=1.6.11,<2.0.0",
"authlib>=1.7.0,<2.0.0",
"django-qsessions (>=2.0.0,<3.0.0)",
"business-python (>=2.1.0,<3.0.0)",
"django-extensions (>=4.1,<5.0)",
Expand Down
21 changes: 17 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading