Skip to content

Commit 9305138

Browse files
committed
fail fast with no nhs_no and lambda int test
1 parent 023942a commit 9305138

6 files changed

Lines changed: 103 additions & 58 deletions

File tree

lambdas/mns_publisher/src/create_notification.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def create_mns_notification(sqs_event: SQSMessage) -> MnsNotificationPayload:
2828

2929
imms_map = new_image.get("Imms", {}).get("M", {})
3030
nhs_number = _unwrap_dynamodb_value(imms_map.get("NHS_NUMBER", {}))
31+
if not nhs_number:
32+
logger.error("Missing required field: Nhs Number")
33+
raise ValueError("NHS number is required to create MNS notification")
34+
3135
person_dob = _unwrap_dynamodb_value(imms_map.get("PERSON_DOB", {}))
3236
date_and_time = _unwrap_dynamodb_value(imms_map.get("DATE_AND_TIME", {}))
3337
site_code = _unwrap_dynamodb_value(imms_map.get("SITE_CODE", {}))
@@ -115,14 +119,11 @@ def _unwrap_dynamodb_value(value: dict) -> Any:
115119
if not isinstance(value, dict):
116120
return value
117121

118-
# DynamoDB type descriptors
119122
if "NULL" in value:
120123
return None
121124

122-
# Check other DynamoDB types
123125
for key in DYNAMO_DB_TYPE_DESCRIPTORS:
124126
if key in value:
125127
return value[key]
126128

127-
# Not a DynamoDB type, return as-is
128129
return value

lambdas/mns_publisher/src/process_records.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def process_record(record: SQSMessage, mns_service: MnsService) -> None:
5050

5151
mns_notification_payload = create_mns_notification(record)
5252
notification_id = mns_notification_payload.get("id")
53+
5354
action_flag = mns_notification_payload.get("filtering", {}).get("action")
5455
logger.info(
5556
"Processing message",
File renamed without changes.

lambdas/mns_publisher/tests/test_create_notification.py

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def setUp(self):
7373
@patch("create_notification.get_practitioner_details_from_pds")
7474
@patch("create_notification.get_service_url")
7575
@patch("create_notification.uuid.uuid4")
76-
def test_create_mns_notification_success_with_real_payload(self, mock_uuid, mock_get_service_url, mock_get_gp):
76+
def test_create_mns_notification_complete_payload(self, mock_uuid, mock_get_service_url, mock_get_gp):
7777
mock_uuid.return_value = MagicMock(hex="236a1d4a-5d69-4fa9-9c7f-e72bf505aa5b")
7878
mock_get_service_url.return_value = self.expected_immunisation_url
7979
mock_get_gp.return_value = self.expected_gp_ods_code
@@ -84,47 +84,33 @@ def test_create_mns_notification_success_with_real_payload(self, mock_uuid, mock
8484
self.assertEqual(result["type"], IMMUNISATION_TYPE)
8585
self.assertEqual(result["source"], self.expected_immunisation_url)
8686
self.assertEqual(result["subject"], "9481152782")
87-
self.assertIn("id", result)
88-
self.assertIn("time", result)
89-
self.assertIn("dataref", result)
90-
self.assertIn("filtering", result)
91-
92-
@patch("create_notification.get_practitioner_details_from_pds")
93-
@patch("create_notification.get_service_url")
94-
def test_create_mns_notification_dataref_format_real_payload(self, mock_get_service_url, mock_get_gp):
95-
mock_get_service_url.return_value = self.expected_immunisation_url
96-
mock_get_gp.return_value = self.expected_gp_ods_code
97-
98-
result = create_mns_notification(self.sample_sqs_event)
9987

10088
expected_dataref = f"{self.expected_immunisation_url}/Immunization/d058014c-b0fd-4471-8db9-3316175eb825"
10189
self.assertEqual(result["dataref"], expected_dataref)
10290

103-
@patch("create_notification.get_practitioner_details_from_pds")
104-
@patch("create_notification.get_service_url")
105-
def test_create_mns_notification_filtering_fields_real_payload(self, mock_get_service_url, mock_get_gp):
106-
mock_get_service_url.return_value = self.expected_immunisation_url
107-
mock_get_gp.return_value = self.expected_gp_ods_code
108-
109-
result = create_mns_notification(self.sample_sqs_event)
110-
11191
filtering = result["filtering"]
11292
self.assertEqual(filtering["generalpractitioner"], self.expected_gp_ods_code)
11393
self.assertEqual(filtering["sourceorganisation"], "B0C4P")
11494
self.assertEqual(filtering["sourceapplication"], "TPP")
11595
self.assertEqual(filtering["immunisationtype"], "hib")
11696
self.assertEqual(filtering["action"], "CREATE")
117-
self.assertIsInstance(filtering["subjectage"], str)
97+
self.assertEqual(filtering["subjectage"], "21")
98+
99+
self.assertIn("id", result)
100+
self.assertIsInstance(result["id"], str)
118101

119102
@patch("create_notification.get_practitioner_details_from_pds")
120103
@patch("create_notification.get_service_url")
121-
def test_create_mns_notification_age_calculation_real_payload(self, mock_get_service_url, mock_get_gp):
122-
mock_get_service_url.return_value = self.expected_immunisation_url
123-
mock_get_gp.return_value = self.expected_gp_ods_code
104+
def test_create_mns_notification_missing_nhs_number(self, mock_get_service_url, mock_get_gp):
105+
sqs_event_data = copy.deepcopy(self.sample_sqs_event)
124106

125-
result = create_mns_notification(self.sample_sqs_event)
107+
body = json.loads(sqs_event_data["body"])
108+
body["dynamodb"]["NewImage"]["Imms"]["M"]["NHS_NUMBER"]["S"] = ""
109+
sqs_event_data["body"] = json.dumps(body)
126110

127-
self.assertEqual(result["filtering"]["subjectage"], "21")
111+
with self.assertRaises(ValueError) as context:
112+
create_mns_notification(sqs_event_data)
113+
self.assertIn("NHS number is required", str(context.exception))
128114

129115
@patch("create_notification.get_practitioner_details_from_pds")
130116
@patch("create_notification.get_service_url")
@@ -136,17 +122,6 @@ def test_create_mns_notification_calls_get_practitioner_real_payload(self, mock_
136122

137123
mock_get_gp.assert_called_once_with("9481152782")
138124

139-
@patch("create_notification.get_practitioner_details_from_pds")
140-
@patch("create_notification.get_service_url")
141-
def test_create_mns_notification_uuid_generated(self, mock_get_service_url, mock_get_gp):
142-
mock_get_service_url.return_value = self.expected_immunisation_url
143-
mock_get_gp.return_value = self.expected_gp_ods_code
144-
145-
result1 = create_mns_notification(self.sample_sqs_event)
146-
result2 = create_mns_notification(self.sample_sqs_event)
147-
148-
self.assertNotEqual(result1["id"], result2["id"])
149-
150125
@patch("create_notification.get_practitioner_details_from_pds")
151126
@patch("create_notification.get_service_url")
152127
def test_create_mns_notification_invalid_json_body(self, mock_get_service_url, mock_get_gp):
@@ -200,7 +175,7 @@ def test_create_mns_notification_missing_imms_data_field(self, mock_get_service_
200175
"body": json.dumps({"dynamodb": {"NewImage": {"ImmsID": {"S": "test-id"}}}}),
201176
}
202177

203-
with self.assertRaises((KeyError, TypeError)):
178+
with self.assertRaises((KeyError, TypeError, ValueError)):
204179
create_mns_notification(incomplete_event)
205180

206181
@patch("create_notification.get_practitioner_details_from_pds")
@@ -385,7 +360,3 @@ def test_unwrap_list_type(self):
385360
value = {"L": [{"S": "item1"}, {"S": "item2"}]}
386361
result = _unwrap_dynamodb_value(value)
387362
self.assertEqual(result, [{"S": "item1"}, {"S": "item2"}])
388-
389-
390-
if __name__ == "__main__":
391-
unittest.main()

lambdas/mns_publisher/tests/test_lambda_handler.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import json
22
import unittest
3-
from pathlib import Path
43
from unittest.mock import Mock, patch
54

5+
import boto3
6+
import responses
7+
from moto import mock_aws
8+
69
from lambda_handler import lambda_handler
710
from process_records import extract_trace_ids, process_record, process_records
8-
from test_utils import load_sample_sqs_event
11+
from test_utils import generate_private_key_b64, load_sample_sqs_event
912

1013

1114
class TestExtractTraceIds(unittest.TestCase):
@@ -119,14 +122,7 @@ class TestProcessRecords(unittest.TestCase):
119122
@classmethod
120123
def setUpClass(cls):
121124
"""Load the sample SQS event once for all tests."""
122-
sample_event_path = Path(__file__).parent.parent / "tests/sqs_event.json"
123-
with open(sample_event_path, "r") as f:
124-
raw_event = json.load(f)
125-
126-
if isinstance(raw_event.get("body"), dict):
127-
raw_event["body"] = json.dumps(raw_event["body"])
128-
129-
cls.sample_sqs_record = raw_event
125+
cls.sample_sqs_record = load_sample_sqs_event()
130126

131127
@patch("process_records.logger")
132128
@patch("process_records.get_mns_service")
@@ -168,7 +164,7 @@ def test_process_records_partial_failure(self, mock_process_record, mock_get_mns
168164

169165
self.assertEqual(len(result["batchItemFailures"]), 1)
170166
self.assertEqual(result["batchItemFailures"][0]["itemIdentifier"], "msg-456")
171-
mock_logger.exception.assert_called_once()
167+
mock_logger.warning.assert_called_with("Batch completed with 1 failures")
172168

173169
@patch("process_records.logger")
174170
@patch("process_records.get_mns_service")
@@ -241,5 +237,66 @@ def test_lambda_handler_empty_records(self, mock_process_records):
241237
mock_process_records.assert_called_once_with([])
242238

243239

240+
@mock_aws
241+
class TestLambdaHandlerIntegration(unittest.TestCase):
242+
"""
243+
Integration tests
244+
"""
245+
246+
def setUp(self):
247+
"""Set up mocked AWS services and test data."""
248+
self.sample_sqs_record = load_sample_sqs_event()
249+
self.secrets_client = boto3.client("secretsmanager", region_name="eu-west-2")
250+
self.secrets_client.create_secret(
251+
Name="imms/pds/int/jwt-secrets",
252+
SecretString=json.dumps(
253+
{"api_key": "fake-pds-api-key", "kid": "fake-kid-123", "private_key_b64": generate_private_key_b64()}
254+
),
255+
)
256+
257+
@responses.activate
258+
@patch("common.api_clients.authentication.AppRestrictedAuth.get_access_token")
259+
@patch("process_records.logger")
260+
def test_successful_notification_creation_with_gp(self, mock_logger, mock_get_token):
261+
# Mock OAuth token response issued from Apigee
262+
mock_oauth_response = Mock()
263+
mock_oauth_response.status_code = 200
264+
mock_oauth_response.json.return_value = {"access_token": "fake-token"}
265+
mock_get_token.return_value = mock_oauth_response
266+
267+
# Intercepts actual request call to PDS and returns mocked responses
268+
responses.add(
269+
responses.GET,
270+
"https://int.api.service.nhs.uk/personal-demographics/FHIR/R4/Patient/9481152782",
271+
json={"generalPractitioner": [{"identifier": {"value": "Y12345", "period": {"start": "2024-01-01"}}}]},
272+
status=200,
273+
)
274+
275+
mns_response = responses.add(
276+
responses.POST,
277+
"https://int.api.service.nhs.uk/multicast-notification-service/events",
278+
json={"id": "236a1d4a-5d69-4fa9-9c7f-e72bf505aa5b"},
279+
status=200,
280+
)
281+
282+
sqs_event = {"Records": [self.sample_sqs_record]}
283+
result = lambda_handler(sqs_event, Mock())
284+
285+
self.assertEqual(result, {"batchItemFailures": []})
286+
287+
self.assertEqual(mns_response.call_count, 1)
288+
self.assertEqual(mns_response.calls[0].response.status_code, 200)
289+
mns_payload = json.loads(mns_response.calls[0].request.body)
290+
self.assertEqual(mns_payload["subject"], "9481152782")
291+
self.assertEqual(mns_payload["filtering"]["generalpractitioner"], "Y12345")
292+
self.assertEqual(mns_payload["filtering"]["sourceorganisation"], "B0C4P")
293+
self.assertEqual(mns_payload["filtering"]["sourceapplication"], "TPP")
294+
self.assertEqual(mns_payload["filtering"]["immunisationtype"], "hib")
295+
self.assertEqual(mns_payload["filtering"]["action"], "CREATE")
296+
self.assertEqual(mns_payload["filtering"]["subjectage"], "21")
297+
298+
mock_logger.info.assert_any_call("Successfully processed all 1 messages")
299+
300+
244301
if __name__ == "__main__":
245302
unittest.main()

lambdas/mns_publisher/tests/test_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
1+
import base64
12
import json
23
from pathlib import Path
34

5+
from cryptography.hazmat.primitives import serialization
6+
from cryptography.hazmat.primitives.asymmetric import rsa
7+
8+
9+
def generate_private_key_b64() -> str:
10+
# Generate a real RSA private key (PKCS8) and base64 encode the PEM
11+
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
12+
pem_bytes = private_key.private_bytes(
13+
encoding=serialization.Encoding.PEM,
14+
format=serialization.PrivateFormat.PKCS8,
15+
encryption_algorithm=serialization.NoEncryption(),
16+
)
17+
return base64.b64encode(pem_bytes).decode("utf-8")
18+
419

520
def load_sample_sqs_event() -> dict:
621
"""
722
Loads the sample SQS event and normalises body to a JSON string (as SQS delivers it).
823
Expects: lambdas/mns_publisher/tests/sqs_event.json
924
"""
10-
sample_event_path = Path(__file__).parent / "sqs_event.json"
25+
sample_event_path = Path(__file__).parent / "sample_data" / "sqs_event.json"
1126
with open(sample_event_path, "r") as f:
1227
raw_event = json.load(f)
1328

0 commit comments

Comments
 (0)