Skip to content

Commit a96582a

Browse files
committed
resolve type hinting and remove try catch
1 parent 0297ee5 commit a96582a

7 files changed

Lines changed: 94 additions & 125 deletions

File tree

lambdas/mns_publisher/src/create_notification.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,36 +63,32 @@ def calculate_age_at_vaccination(birth_date: str, vaccination_date: str) -> int:
6363

6464

6565
def get_practitioner_details_from_pds(nhs_number: str) -> str | None:
66-
try:
67-
patient_details = pds_get_patient_details(nhs_number)
66+
patient_details = pds_get_patient_details(nhs_number)
6867

69-
general_practitioners = patient_details.get("generalPractitioner", [])
70-
if not general_practitioners or len(general_practitioners) == 0:
71-
logger.warning("No GP details found for patient")
72-
return None
68+
general_practitioners = patient_details.get("generalPractitioner", [])
69+
if not general_practitioners or len(general_practitioners) == 0:
70+
logger.warning("No GP details found for patient")
71+
return None
7372

74-
patient_gp = general_practitioners[0]
75-
patient_gp_identifier = patient_gp.get("identifier", {})
73+
patient_gp = general_practitioners[0]
74+
patient_gp_identifier = patient_gp.get("identifier", {})
7675

77-
gp_ods_code = patient_gp_identifier.get("value")
78-
if not gp_ods_code:
79-
logger.warning("GP ODS code not found in practitioner details")
80-
return None
76+
gp_ods_code = patient_gp_identifier.get("value")
77+
if not gp_ods_code:
78+
logger.warning("GP ODS code not found in practitioner details")
79+
return None
8180

82-
# Check if registration is current
83-
period = patient_gp_identifier.get("period", {})
84-
gp_period_end_date = period.get("end", None)
81+
# Check if registration is current
82+
period = patient_gp_identifier.get("period", {})
83+
gp_period_end_date = period.get("end", None)
8584

86-
if gp_period_end_date:
87-
# Parse end date (format: YYYY-MM-DD)
88-
end_date = datetime.strptime(gp_period_end_date, "%Y-%m-%d").date()
89-
today = datetime.now().date()
85+
if gp_period_end_date:
86+
# Parse end date (format: YYYY-MM-DD)
87+
end_date = datetime.strptime(gp_period_end_date, "%Y-%m-%d").date()
88+
today = datetime.now().date()
9089

91-
if end_date < today:
92-
logger.warning("GP registration has ended")
93-
return None
90+
if end_date < today:
91+
logger.warning("No current GP registration found for patient")
92+
return None
9493

95-
return gp_ods_code
96-
except Exception as error:
97-
logger.exception("Failed to get practitioner details from pds", error)
98-
raise
94+
return gp_ods_code
Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
from aws_lambda_typing import context, events
22

3-
from common.clients import logger
43
from process_records import process_records
54

65

76
def lambda_handler(event: events.SQSEvent, _: context.Context) -> dict[str, list]:
87
event_records = event.get("Records", [])
9-
batch_item_failures = process_records(event_records)
108

11-
if batch_item_failures:
12-
logger.warning(f"Batch completed with {len(batch_item_failures)} failures")
13-
else:
14-
logger.info(f"Successfully processed all {len(event_records)} messages")
15-
16-
return {"batchItemFailures": batch_item_failures}
9+
return process_records(event_records)

lambdas/mns_publisher/src/process_records.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
import os
33
from typing import Tuple
44

5-
from aws_lambda_typing import events
65
from aws_lambda_typing.events.sqs import SQSMessage
76

7+
from common.api_clients.mns_service import MnsService
88
from common.api_clients.mns_setup import get_mns_service
99
from common.clients import logger
1010
from create_notification import create_mns_notification
1111

1212
mns_env = os.getenv("MNS_ENV", "int")
1313

1414

15-
def process_records(records: events.SQSEvent) -> list[dict]:
15+
def process_records(records: list[SQSMessage]) -> list[dict]:
1616
"""
1717
Process multiple SQS records.
1818
Args: records: List of SQS records to process
@@ -22,14 +22,23 @@ def process_records(records: events.SQSEvent) -> list[dict]:
2222
mns_service = get_mns_service(mns_env=mns_env)
2323

2424
for record in records:
25-
failed_batch_item = process_record(record, mns_service)
26-
if failed_batch_item:
27-
batch_item_failures.append(failed_batch_item)
25+
try:
26+
failed_batch_item = process_record(record, mns_service)
27+
if failed_batch_item:
28+
batch_item_failures.append(failed_batch_item)
29+
except Exception:
30+
message_id = record.get("messageId", "unknown")
31+
batch_item_failures.append({"itemIdentifier": message_id})
2832

29-
return batch_item_failures
33+
if batch_item_failures:
34+
logger.warning(f"Batch completed with {len(batch_item_failures)} failures")
35+
else:
36+
logger.info(f"Successfully processed all {len(records)} messages")
3037

38+
return {"batchItemFailures": batch_item_failures}
3139

32-
def process_record(record: SQSMessage, mns_service) -> dict | None:
40+
41+
def process_record(record: SQSMessage, mns_service: MnsService) -> dict | None:
3342
"""
3443
Process a single SQS record.
3544
Args:
@@ -40,40 +49,25 @@ def process_record(record: SQSMessage, mns_service) -> dict | None:
4049
message_id, immunisation_id = extract_trace_ids(record)
4150
notification_id = None
4251

43-
try:
44-
# Create notification payload
45-
mns_notification_payload = create_mns_notification(record)
46-
notification_id = mns_notification_payload.get("id")
47-
action_flag = mns_notification_payload.get("filtering", {}).get("action")
48-
logger.info(
49-
"Processing message",
50-
trace_ids={
51-
"notification_id": notification_id,
52-
"message_id": message_id,
53-
"immunisation_id": immunisation_id,
54-
"action_flag": action_flag,
55-
},
56-
)
57-
58-
# Publish to MNS
59-
mns_pub_response = mns_service.publish_notification(mns_notification_payload)
60-
if mns_pub_response["status_code"] != 200:
61-
raise RuntimeError("MNS publish failed")
62-
logger.info("Successfully created MNS notification", trace_ids={"mns_notification_id": notification_id})
63-
64-
return None
65-
66-
except Exception as e:
67-
logger.exception(
68-
"Failed to process message",
69-
trace_ids={
70-
"message_id": message_id,
71-
"immunisation_id": immunisation_id,
72-
"mns_notification_id": notification_id,
73-
"error": str(e),
74-
},
75-
)
76-
return {"itemIdentifier": message_id}
52+
# Create notification payload
53+
mns_notification_payload = create_mns_notification(record)
54+
notification_id = mns_notification_payload.get("id")
55+
action_flag = mns_notification_payload.get("filtering", {}).get("action")
56+
logger.info(
57+
"Processing message",
58+
trace_ids={
59+
"notification_id": notification_id,
60+
"message_id": message_id,
61+
"immunisation_id": immunisation_id,
62+
"action_flag": action_flag,
63+
},
64+
)
65+
66+
# Publish to MNS
67+
mns_service.publish_notification(mns_notification_payload)
68+
logger.info("Successfully created MNS notification", trace_ids={"mns_notification_id": notification_id})
69+
70+
return None
7771

7872

7973
def extract_trace_ids(record: SQSMessage) -> Tuple[str, str | None]:

lambdas/mns_publisher/src/sqs_dynamo_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def extract_sqs_imms_data(sqs_record: dict) -> ImmsData:
3333
}
3434

3535

36-
def _unwrap_dynamodb_value(value) -> Any:
36+
def _unwrap_dynamodb_value(value: dict) -> Any:
3737
"""
3838
Unwrap DynamoDB type descriptor to get the actual value.
3939
DynamoDB types: S (String), N (Number), BOOL, M (Map), L (List), NULL

lambdas/mns_publisher/tests/test_create_notification.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,10 @@ def test_get_practitioner_empty_value(self, mock_logger, mock_pds_get):
143143
def test_get_practitioner_pds_exception(self, mock_logger, mock_pds_get):
144144
"""Test when PDS API raises exception."""
145145
mock_pds_get.side_effect = Exception("PDS API error")
146-
147146
with self.assertRaises(Exception) as context:
148147
get_practitioner_details_from_pds("9481152782")
149-
150-
self.assertEqual(str(context.exception), "PDS API error")
151-
mock_logger.exception.assert_called_once()
148+
self.assertEqual(str(context.exception), "PDS API error")
149+
mock_logger.exception.assert_called_once()
152150

153151
@patch("create_notification.pds_get_patient_details")
154152
@patch("create_notification.logger")

lambdas/mns_publisher/tests/test_lambda_handler.py

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def setUp(self):
8787
self.sample_notification = {
8888
"id": "notif-789",
8989
"specversion": "1.0",
90-
"type": "imms-vaccinations-2",
90+
"type": "imms-vaccinations-1",
9191
"filtering": {"action": "CREATE"},
9292
}
9393
self.mock_mns_service = Mock()
@@ -97,11 +97,11 @@ def setUp(self):
9797
def test_process_record_success(self, mock_logger, mock_create_notification):
9898
"""Test successful processing of a single record."""
9999
mock_create_notification.return_value = self.sample_notification
100-
self.mock_mns_service.publish_notification.return_value = {"status_code": 200}
100+
self.mock_mns_service.publish_notification.return_value = None
101101

102-
result = process_record(self.sample_sqs_record, self.mock_mns_service)
102+
# Should not raise exception
103+
process_record(self.sample_sqs_record, self.mock_mns_service)
103104

104-
self.assertIsNone(result)
105105
mock_create_notification.assert_called_once_with(self.sample_sqs_record)
106106
self.mock_mns_service.publish_notification.assert_called_once_with(self.sample_notification)
107107
mock_logger.exception.assert_not_called()
@@ -112,10 +112,10 @@ def test_process_record_create_notification_failure(self, mock_logger, mock_crea
112112
"""Test handling when notification creation fails."""
113113
mock_create_notification.side_effect = Exception("Creation error")
114114

115-
result = process_record(self.sample_sqs_record, self.mock_mns_service)
115+
# Should raise exception
116+
with self.assertRaises(Exception):
117+
process_record(self.sample_sqs_record, self.mock_mns_service)
116118

117-
self.assertEqual(result, {"itemIdentifier": "98ed30eb-829f-41df-8a73-57fef70cf161"})
118-
mock_logger.exception.assert_called_once()
119119
self.mock_mns_service.publish_notification.assert_not_called()
120120

121121
@patch("process_records.create_mns_notification")
@@ -125,22 +125,9 @@ def test_process_record_publish_failure(self, mock_logger, mock_create_notificat
125125
mock_create_notification.return_value = self.sample_notification
126126
self.mock_mns_service.publish_notification.side_effect = Exception("Publish error")
127127

128-
result = process_record(self.sample_sqs_record, self.mock_mns_service)
129-
130-
self.assertEqual(result, {"itemIdentifier": "98ed30eb-829f-41df-8a73-57fef70cf161"})
131-
mock_logger.exception.assert_called_once()
132-
133-
@patch("process_records.create_mns_notification")
134-
@patch("process_records.logger")
135-
def test_process_record_logs_trace_ids(self, mock_logger, mock_create_notification):
136-
"""Test that trace IDs are logged correctly."""
137-
mock_create_notification.return_value = self.sample_notification
138-
139-
process_record(self.sample_sqs_record, self.mock_mns_service)
140-
141-
# Check info log was called with trace IDs
142-
info_calls = [call for call in mock_logger.info.call_args_list if "Processing message" in str(call)]
143-
self.assertEqual(len(info_calls), 1)
128+
# Should raise exception
129+
with self.assertRaises(Exception):
130+
process_record(self.sample_sqs_record, self.mock_mns_service)
144131

145132

146133
class TestProcessRecords(unittest.TestCase):
@@ -158,33 +145,36 @@ def setUpClass(cls):
158145

159146
cls.sample_sqs_record = raw_event
160147

148+
@patch("process_records.logger")
161149
@patch("process_records.get_mns_service")
162150
@patch("process_records.process_record")
163-
def test_process_records_all_success(self, mock_process_record, mock_get_mns):
151+
def test_process_records_all_success(self, mock_process_record, mock_get_mns, mock_logger):
164152
"""Test processing multiple records with all successes."""
165153
mock_mns_service = Mock()
166154
mock_get_mns.return_value = mock_mns_service
167-
mock_process_record.return_value = None # Success
155+
mock_process_record.return_value = None # No exception
168156

169157
record_2 = self.sample_sqs_record.copy()
170158
record_2["messageId"] = "different-id"
171159
records = [self.sample_sqs_record, record_2]
172160

173161
result = process_records(records)
174162

175-
self.assertEqual(result, [])
163+
self.assertEqual(result, {"batchItemFailures": []})
176164
self.assertEqual(mock_process_record.call_count, 2)
177165
mock_get_mns.assert_called_once()
166+
mock_logger.info.assert_called_with("Successfully processed all 2 messages")
178167

168+
@patch("process_records.logger")
179169
@patch("process_records.get_mns_service")
180170
@patch("process_records.process_record")
181-
def test_process_records_partial_failure(self, mock_process_record, mock_get_mns):
171+
def test_process_records_partial_failure(self, mock_process_record, mock_get_mns, mock_logger):
182172
"""Test processing with some failures."""
183173
mock_mns_service = Mock()
184174
mock_get_mns.return_value = mock_mns_service
185175
mock_process_record.side_effect = [
186176
None, # Success
187-
{"itemIdentifier": "msg-456"}, # Failure
177+
Exception("Processing error"), # Failure
188178
]
189179

190180
record_2 = self.sample_sqs_record.copy()
@@ -193,24 +183,28 @@ def test_process_records_partial_failure(self, mock_process_record, mock_get_mns
193183

194184
result = process_records(records)
195185

196-
self.assertEqual(len(result), 1)
197-
self.assertEqual(result[0]["itemIdentifier"], "msg-456")
186+
self.assertEqual(len(result["batchItemFailures"]), 1)
187+
self.assertEqual(result["batchItemFailures"][0]["itemIdentifier"], "msg-456")
188+
mock_logger.warning.assert_called_with("Batch completed with 1 failures")
198189

190+
@patch("process_records.logger")
199191
@patch("process_records.get_mns_service")
200192
@patch("process_records.process_record")
201-
def test_process_records_empty_list(self, mock_process_record, mock_get_mns):
193+
def test_process_records_empty_list(self, mock_process_record, mock_get_mns, mock_logger):
202194
"""Test processing empty record list."""
203195
mock_mns_service = Mock()
204196
mock_get_mns.return_value = mock_mns_service
205197

206198
result = process_records([])
207199

208-
self.assertEqual(result, [])
200+
self.assertEqual(result, {"batchItemFailures": []})
209201
mock_process_record.assert_not_called()
202+
mock_logger.info.assert_called_with("Successfully processed all 0 messages")
210203

204+
@patch("process_records.logger")
211205
@patch("process_records.get_mns_service")
212206
@patch("process_records.process_record")
213-
def test_process_records_mns_service_created_once(self, mock_process_record, mock_get_mns):
207+
def test_process_records_mns_service_created_once(self, mock_process_record, mock_get_mns, mock_logger):
214208
"""Test that MNS service is created only once for batch."""
215209
mock_mns_service = Mock()
216210
mock_get_mns.return_value = mock_mns_service
@@ -239,35 +233,30 @@ def setUpClass(cls):
239233
cls.sample_sqs_record = raw_event
240234

241235
@patch("lambda_handler.process_records")
242-
@patch("lambda_handler.logger")
243-
def test_lambda_handler_all_success(self, mock_logger, mock_process_records):
236+
def test_lambda_handler_all_success(self, mock_process_records):
244237
"""Test lambda handler with all records succeeding."""
245-
mock_process_records.return_value = []
238+
mock_process_records.return_value = {"batchItemFailures": []}
246239

247240
event = {"Records": [self.sample_sqs_record]}
248241
result = lambda_handler(event, Mock())
249242

250243
self.assertEqual(result, {"batchItemFailures": []})
251244
mock_process_records.assert_called_once_with([self.sample_sqs_record])
252-
mock_logger.info.assert_called_with("Successfully processed all 1 messages")
253245

254246
@patch("lambda_handler.process_records")
255-
@patch("lambda_handler.logger")
256-
def test_lambda_handler_with_failures(self, mock_logger, mock_process_records):
247+
def test_lambda_handler_with_failures(self, mock_process_records):
257248
"""Test lambda handler with some failures."""
258-
mock_process_records.return_value = [{"itemIdentifier": "msg-123"}]
249+
mock_process_records.return_value = {"batchItemFailures": [{"itemIdentifier": "msg-123"}]}
259250

260251
event = {"Records": [self.sample_sqs_record]}
261252
result = lambda_handler(event, Mock())
262253

263254
self.assertEqual(result, {"batchItemFailures": [{"itemIdentifier": "msg-123"}]})
264-
mock_logger.warning.assert_called_with("Batch completed with 1 failures")
265255

266256
@patch("lambda_handler.process_records")
267-
@patch("lambda_handler.logger")
268-
def test_lambda_handler_empty_records(self, mock_logger, mock_process_records):
257+
def test_lambda_handler_empty_records(self, mock_process_records):
269258
"""Test lambda handler with no records."""
270-
mock_process_records.return_value = []
259+
mock_process_records.return_value = {"batchItemFailures": []}
271260

272261
event = {"Records": []}
273262
result = lambda_handler(event, Mock())

lambdas/shared/src/common/api_clients/mns_service.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def get_subscription(self) -> dict | None:
7979
headers = self._build_headers()
8080
response = request_with_retry_backoff("GET", f"{MNS_BASE_URL}/subscriptions", headers, timeout=10)
8181
logging.info(f"GET {MNS_BASE_URL}/subscriptions")
82-
logging.debug(f"Headers: {headers}")
8382

8483
if response.status_code == 200:
8584
bundle = response.json()

0 commit comments

Comments
 (0)