Skip to content

Commit 940c9e3

Browse files
authored
VED-000: Refactor Api_clients Authentication (#1278)
1 parent 9765c1f commit 940c9e3

22 files changed

Lines changed: 277 additions & 479 deletions

lambdas/id_sync/src/pds_details.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,6 @@
22
Operations related to PDS (Patient Demographic Service)
33
"""
44

5-
import tempfile
6-
7-
from common.api_clients.authentication import AppRestrictedAuth, Service
8-
from common.api_clients.pds_service import PdsService
9-
from common.cache import Cache
10-
from common.clients import get_secrets_manager_client, logger
11-
from exceptions.id_sync_exception import IdSyncException
12-
from os_vars import get_pds_env
13-
14-
pds_env = get_pds_env()
15-
safe_tmp_dir = tempfile.mkdtemp(dir="/tmp")
16-
17-
18-
# Get Patient details from external service PDS using NHS number from MNS notification
19-
def pds_get_patient_details(nhs_number: str) -> dict:
20-
try:
21-
cache = Cache(directory=safe_tmp_dir)
22-
authenticator = AppRestrictedAuth(
23-
service=Service.PDS,
24-
secret_manager_client=get_secrets_manager_client(),
25-
environment=pds_env,
26-
cache=cache,
27-
)
28-
pds_service = PdsService(authenticator, pds_env)
29-
patient = pds_service.get_patient_details(nhs_number)
30-
return patient
31-
except Exception as e:
32-
msg = "Error retrieving patient details from PDS"
33-
logger.exception(msg)
34-
raise IdSyncException(message=msg) from e
35-
365

376
def get_nhs_number_from_pds_resource(pds_resource: dict) -> str:
387
"""Simple helper to get the NHS Number from a PDS Resource. No handling as this is a mandatory field in the PDS

lambdas/mns_publisher/poetry.lock

Lines changed: 38 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lambdas/mns_publisher/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ boto3 = "~1.42.37"
1919
mypy-boto3-dynamodb = "^1.42.33"
2020
moto = "~5.1.20"
2121
cache = "^1.0.3"
22+
aws-lambda-powertools = {version = "3.24.0"}
2223

2324
[build-system]
2425
requires = ["poetry-core >= 1.5.0"]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Static constants for the MNS notification creation process
22
SPEC_VERSION = "1.0"
3-
IMMUNISATION_TYPE = "imms-vaccinations-1"
3+
IMMUNISATION_EVENT_SOURCE = "uk.nhs.vaccinations-data-flow-management"
4+
IMMUNISATION_EVENT_TYPE = "imms-vaccination-record-change-1"
45

56
DYNAMO_DB_TYPE_DESCRIPTORS = ("S", "N", "BOOL", "M", "L")

lambdas/mns_publisher/src/create_notification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from common.api_clients.get_pds_details import pds_get_patient_details
1111
from common.clients import logger
1212
from common.get_service_url import get_service_url
13-
from constants import DYNAMO_DB_TYPE_DESCRIPTORS, IMMUNISATION_TYPE, SPEC_VERSION
13+
from constants import DYNAMO_DB_TYPE_DESCRIPTORS, IMMUNISATION_EVENT_SOURCE, IMMUNISATION_EVENT_TYPE, SPEC_VERSION
1414

1515
IMMUNIZATION_ENV = os.getenv("IMMUNIZATION_ENV")
1616
IMMUNIZATION_BASE_PATH = os.getenv("IMMUNIZATION_BASE_PATH")
@@ -43,8 +43,8 @@ def create_mns_notification(sqs_event: SQSMessage) -> MnsNotificationPayload:
4343
return {
4444
"specversion": SPEC_VERSION,
4545
"id": str(uuid.uuid4()),
46-
"source": immunisation_url,
47-
"type": IMMUNISATION_TYPE,
46+
"source": IMMUNISATION_EVENT_SOURCE,
47+
"type": IMMUNISATION_EVENT_TYPE,
4848
"time": date_and_time,
4949
"subject": nhs_number,
5050
"dataref": f"{immunisation_url}/Immunization/{imms_id}",
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
Centralised observability for MNS publisher Lambda.
3+
4+
log_uncaught_exceptions=True ensures unexpected exceptions are captured as
5+
structured JSON logs at the Lambda boundary.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
12+
from aws_lambda_powertools import Logger
13+
14+
_SERVICE_NAME = "mns-immunisation-publisher."
15+
16+
logger: Logger = Logger(
17+
service=_SERVICE_NAME,
18+
level=os.environ.get("LOG_LEVEL", "INFO"),
19+
log_uncaught_exceptions=True,
20+
location=os.environ.get("POWERTOOLS_LOGGER_LOG_CALLABLE_LOCATION", "false").lower() == "true",
21+
)

lambdas/mns_publisher/src/process_records.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,63 @@
11
import json
22
import os
3-
from typing import Tuple
43

4+
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
55
from aws_lambda_typing.events.sqs import SQSMessage
66

77
from common.api_clients.mns_service import MnsService
88
from common.api_clients.mns_setup import get_mns_service
99
from common.api_clients.mock_mns_service import MockMnsService
10-
from common.clients import logger
1110
from create_notification import create_mns_notification
11+
from observability import logger
1212

1313
mns_env = os.getenv("MNS_ENV", "int")
14-
MNS_TEST_QUEUE_URL = os.getenv("MNS_TEST_QUEUE_URL")
14+
_mns_service: MnsService | MockMnsService | None = None
15+
SqsRecord = SQSRecord | SQSMessage
1516

1617

17-
def process_records(records: list[SQSMessage]) -> dict[str, list]:
18+
def _get_message_id(record: SqsRecord) -> str:
19+
if isinstance(record, SQSRecord):
20+
return record.message_id
21+
22+
return record.get("messageId", "unknown")
23+
24+
25+
def _get_body(record: SqsRecord) -> dict | str:
26+
if isinstance(record, SQSRecord):
27+
return record.body
28+
29+
return record.get("body", {})
30+
31+
32+
def _as_sqs_message(record: SqsRecord) -> SQSMessage:
33+
if isinstance(record, SQSRecord):
34+
return record.raw_event
35+
36+
return record
37+
38+
39+
def _get_runtime_mns_service() -> MnsService | MockMnsService:
40+
global _mns_service
41+
if _mns_service is None:
42+
_mns_service = get_mns_service(mns_env=mns_env)
43+
44+
return _mns_service
45+
46+
47+
def process_records(records: list[SqsRecord]) -> dict[str, list]:
1848
"""
1949
Process multiple SQS records.
2050
Args: records: List of SQS records to process
2151
Returns: List of failed item identifiers for partial batch failure
2252
"""
2353
batch_item_failures = []
24-
mns_service = get_mns_service(mns_env=mns_env)
54+
mns_service = _get_runtime_mns_service()
2555

2656
for record in records:
2757
try:
2858
process_record(record, mns_service)
2959
except Exception:
30-
message_id = record.get("messageId", "unknown")
60+
message_id = _get_message_id(record)
3161
batch_item_failures.append({"itemIdentifier": message_id})
3262
logger.exception("Failed to process record", extra={"message_id": message_id})
3363

@@ -39,7 +69,7 @@ def process_records(records: list[SQSMessage]) -> dict[str, list]:
3969
return {"batchItemFailures": batch_item_failures}
4070

4171

42-
def process_record(record: SQSMessage, mns_service: MnsService | MockMnsService) -> None:
72+
def process_record(record: SqsRecord, mns_service: MnsService | MockMnsService) -> None:
4373
"""
4474
Process a single SQS record.
4575
Args:
@@ -50,34 +80,36 @@ def process_record(record: SQSMessage, mns_service: MnsService | MockMnsService)
5080
message_id, immunisation_id = extract_trace_ids(record)
5181
notification_id = None
5282

53-
mns_notification_payload = create_mns_notification(record)
83+
mns_notification_payload = create_mns_notification(_as_sqs_message(record))
5484
notification_id = mns_notification_payload.get("id")
5585

5686
action_flag = mns_notification_payload.get("filtering", {}).get("action")
5787
logger.info(
5888
"Processing message",
59-
extra={
60-
"notification_id": notification_id,
61-
"message_id": message_id,
62-
"immunisation_id": immunisation_id,
63-
"action_flag": action_flag,
64-
},
89+
notification_id=notification_id,
90+
message_id=message_id,
91+
immunisation_id=immunisation_id,
92+
action_flag=action_flag,
6593
)
6694

6795
mns_service.publish_notification(mns_notification_payload)
68-
logger.info("Successfully created MNS notification", extra={"mns_notification_id": notification_id})
96+
97+
logger.info(
98+
"Successfully created MNS notification",
99+
mns_notification_id=notification_id,
100+
)
69101

70102

71-
def extract_trace_ids(record: SQSMessage) -> Tuple[str, str | None]:
103+
def extract_trace_ids(record: SqsRecord) -> tuple[str, str | None]:
72104
"""
73105
Extract identifiers for tracing from SQS record.
74106
Returns: Tuple of (message_id, immunisation_id)
75107
"""
76-
sqs_message_id = record.get("messageId", "unknown")
108+
sqs_message_id = _get_message_id(record)
77109
immunisation_id = None
78110

79111
try:
80-
sqs_event_body = record.get("body", {})
112+
sqs_event_body = _get_body(record)
81113
if isinstance(sqs_event_body, str):
82114
sqs_event_body = json.loads(sqs_event_body)
83115

lambdas/mns_publisher/tests/test_create_notification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import unittest
44
from unittest.mock import MagicMock, patch
55

6-
from constants import IMMUNISATION_TYPE, SPEC_VERSION
6+
from constants import IMMUNISATION_EVENT_SOURCE, IMMUNISATION_EVENT_TYPE, SPEC_VERSION
77
from create_notification import (
88
_unwrap_dynamodb_value,
99
calculate_age_at_vaccination,
@@ -81,8 +81,8 @@ def test_success_create_mns_notification_complete_payload(self, mock_uuid, mock_
8181
result = create_mns_notification(self.sample_sqs_event)
8282

8383
self.assertEqual(result["specversion"], SPEC_VERSION)
84-
self.assertEqual(result["type"], IMMUNISATION_TYPE)
85-
self.assertEqual(result["source"], self.expected_immunisation_url)
84+
self.assertEqual(result["type"], IMMUNISATION_EVENT_TYPE)
85+
self.assertEqual(result["source"], IMMUNISATION_EVENT_SOURCE)
8686
self.assertEqual(result["subject"], "9481152782")
8787

8888
expected_dataref = f"{self.expected_immunisation_url}/Immunization/d058014c-b0fd-4471-8db9-3316175eb825"

lambdas/mns_publisher/tests/test_lambda_handler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def setUpClass(cls):
125125
cls.sample_sqs_record = load_sample_sqs_event()
126126

127127
@patch("process_records.logger")
128-
@patch("process_records.get_mns_service")
128+
@patch("process_records._get_runtime_mns_service")
129129
@patch("process_records.process_record")
130130
def test_process_records_all_success(self, mock_process_record, mock_get_mns, mock_logger):
131131
"""Test processing multiple records with all successes."""
@@ -145,7 +145,7 @@ def test_process_records_all_success(self, mock_process_record, mock_get_mns, mo
145145
mock_logger.info.assert_called_with("Successfully processed all 2 messages")
146146

147147
@patch("process_records.logger")
148-
@patch("process_records.get_mns_service")
148+
@patch("process_records._get_runtime_mns_service")
149149
@patch("process_records.process_record")
150150
def test_process_records_partial_failure(self, mock_process_record, mock_get_mns, mock_logger):
151151
"""Test processing with some failures."""
@@ -167,7 +167,7 @@ def test_process_records_partial_failure(self, mock_process_record, mock_get_mns
167167
mock_logger.warning.assert_called_with("Batch completed with 1 failures")
168168

169169
@patch("process_records.logger")
170-
@patch("process_records.get_mns_service")
170+
@patch("process_records._get_runtime_mns_service")
171171
@patch("process_records.process_record")
172172
def test_process_records_empty_list(self, mock_process_record, mock_get_mns, mock_logger):
173173
"""Test processing empty record list."""
@@ -181,7 +181,7 @@ def test_process_records_empty_list(self, mock_process_record, mock_get_mns, moc
181181
mock_logger.info.assert_called_with("Successfully processed all 0 messages")
182182

183183
@patch("process_records.logger")
184-
@patch("process_records.get_mns_service")
184+
@patch("process_records._get_runtime_mns_service")
185185
@patch("process_records.process_record")
186186
def test_process_records_mns_service_created_once(self, mock_process_record, mock_get_mns, mock_logger):
187187
"""Test that MNS service is created only once for batch."""
@@ -300,7 +300,7 @@ def test_successful_notification_creation_with_gp(self, mock_logger, mock_get_to
300300

301301
@responses.activate
302302
@patch("common.api_clients.authentication.AppRestrictedAuth.get_access_token")
303-
@patch("process_records.get_mns_service")
303+
@patch("process_records._get_runtime_mns_service")
304304
@patch("process_records.logger")
305305
def test_pds_failure(self, mock_logger, mock_get_mns, mock_get_token):
306306
"""

0 commit comments

Comments
 (0)