Skip to content

Commit 67cbb18

Browse files
committed
Fix back-channel logout broken by authlib 1.7.0
authlib 1.7.0 removed `DjangoOAuth2App.create_load_key()` as part of its internal migration to joserfc. Re-implement the deleted method as a local helper `_create_cis2_key_loader`, preserving the JWKS force-refresh on kid miss to handle CIS2 key rotation without a process restart. This commit also adds unit tests for the back-channel logout view.
1 parent 1e43ea4 commit 67cbb18

2 files changed

Lines changed: 129 additions & 2 deletions

File tree

manage_breast_screening/auth/tests/test_views.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import time
12
from unittest.mock import ANY, Mock
23

34
import pytest
5+
from authlib.jose import JsonWebKey, jwt
46
from django.conf import settings
57
from django.contrib.auth import get_user_model
68
from django.http import HttpResponse
7-
from django.test import override_settings
9+
from django.test import Client, override_settings
810
from django.urls import reverse
911
from pytest_django.asserts import assertInHTML
1012

@@ -304,3 +306,105 @@ def test_returns_500_on_error(self, client, monkeypatch):
304306

305307
assert response.status_code == 500
306308
assert response.json() == {"keys": []}
309+
310+
311+
@pytest.mark.django_db
312+
class TestCis2BackChannelLogout:
313+
@pytest.fixture
314+
def cis2_jwk(self):
315+
return JsonWebKey.generate_key(
316+
"RSA", 2048, is_private=True, options={"kid": "test-key-1"}
317+
)
318+
319+
@pytest.fixture
320+
def mock_cis2_client(self, monkeypatch, cis2_jwk):
321+
mock_client = Mock()
322+
mock_client.load_server_metadata.return_value = {"issuer": "test-issuer"}
323+
mock_client.fetch_jwk_set.return_value = {
324+
"keys": [cis2_jwk.as_dict(is_private=False)]
325+
}
326+
monkeypatch.setattr(
327+
"manage_breast_screening.auth.views.get_cis2_client",
328+
lambda: mock_client,
329+
)
330+
return mock_client
331+
332+
def _make_logout_token(self, jwk, sub, *, overrides=None):
333+
now = int(time.time())
334+
payload = {
335+
"iss": "test-issuer",
336+
"aud": settings.CIS2_CLIENT_ID,
337+
"iat": now,
338+
"exp": now + 300,
339+
"events": {"http://schemas.openid.net/event/backchannel-logout": {}},
340+
"sub": sub,
341+
"sid": "not-used",
342+
"jti": "not-used",
343+
}
344+
if overrides:
345+
payload.update(overrides)
346+
token = jwt.encode(
347+
{"alg": "RS256", "kid": jwk.kid},
348+
payload,
349+
jwk.as_dict(is_private=True),
350+
)
351+
return token.decode("utf-8")
352+
353+
def test_logs_out_user_for_valid_token(self, mock_cis2_client, cis2_jwk):
354+
User = get_user_model()
355+
user = User.objects.create_user(nhs_uid="user-123", email="user@example.com")
356+
# Sign in on one client (representing the user's browser session)
357+
user_client = Client()
358+
user_client.force_login(user)
359+
assert user.session_set.count() == 1
360+
361+
token = self._make_logout_token(cis2_jwk, sub=user.nhs_uid)
362+
363+
response = Client().post(
364+
reverse("auth:cis2_back_channel_logout"),
365+
data={"logout_token": token},
366+
)
367+
368+
assert response.status_code == 200
369+
assert user.session_set.count() == 0
370+
371+
def test_rejects_request_with_missing_logout_token(self):
372+
response = Client().post(reverse("auth:cis2_back_channel_logout"), data={})
373+
374+
assert response.status_code == 400
375+
assert b"Missing logout_token" in response.content
376+
377+
def test_rejects_expired_token(self, mock_cis2_client, cis2_jwk):
378+
User = get_user_model()
379+
user = User.objects.create_user(nhs_uid="user-123", email="user@example.com")
380+
user_client = Client()
381+
user_client.force_login(user)
382+
383+
now = int(time.time())
384+
token = self._make_logout_token(
385+
cis2_jwk,
386+
sub=user.nhs_uid,
387+
overrides={"iat": now - 300, "exp": now - 120},
388+
)
389+
390+
response = Client().post(
391+
reverse("auth:cis2_back_channel_logout"),
392+
data={"logout_token": token},
393+
)
394+
395+
assert response.status_code == 400
396+
assert b"Invalid logout token" in response.content
397+
assert user.session_set.count() == 1
398+
399+
def test_returns_ok_when_user_does_not_exist_locally(
400+
self, mock_cis2_client, cis2_jwk
401+
):
402+
token = self._make_logout_token(cis2_jwk, sub="unknown-user")
403+
404+
response = Client().post(
405+
reverse("auth:cis2_back_channel_logout"),
406+
data={"logout_token": token},
407+
)
408+
409+
assert response.status_code == 200
410+
assert response.json() == {"status": "ok"}

manage_breast_screening/auth/views.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
from authlib.integrations.base_client.errors import MismatchingStateError, OAuthError
4+
from authlib.jose import JsonWebKey
45
from django.conf import settings
56
from django.contrib import messages
67
from django.contrib.auth import authenticate, get_user_model
@@ -169,7 +170,7 @@ def cis2_back_channel_logout(request):
169170
# Get the CIS2 client and prepare key loader for token verification
170171
client = get_cis2_client()
171172
metadata = client.load_server_metadata()
172-
key_loader = client.create_load_key()
173+
key_loader = _create_cis2_key_loader(client)
173174
try:
174175
claims = decode_logout_token(metadata["issuer"], key_loader, logout_token)
175176
except InvalidLogoutToken:
@@ -193,6 +194,28 @@ def cis2_back_channel_logout(request):
193194
return JsonResponse({"status": "ok"})
194195

195196

197+
def _create_cis2_key_loader(client):
198+
"""Build a key loader for verifying CIS2-signed tokens.
199+
200+
Force-refreshes the cached JWKS on a kid miss so newly rotated CIS2 signing keys
201+
are picked up without a process restart.
202+
"""
203+
204+
def load_key(header, _payload):
205+
jwk_set = JsonWebKey.import_key_set(client.fetch_jwk_set())
206+
try:
207+
return jwk_set.find_by_kid(
208+
header.get("kid"), use="sig", alg=header.get("alg")
209+
)
210+
except ValueError:
211+
jwk_set = JsonWebKey.import_key_set(client.fetch_jwk_set(force=True))
212+
return jwk_set.find_by_kid(
213+
header.get("kid"), use="sig", alg=header.get("alg")
214+
)
215+
216+
return load_key
217+
218+
196219
def _validate_id_assurance_level(level: int | str | None) -> str | None:
197220
if level is not None:
198221
level = int(level)

0 commit comments

Comments
 (0)