diff --git a/lambdas/id_sync/src/pds_details.py b/lambdas/id_sync/src/pds_details.py index 62ef6c247..e7ae540a8 100644 --- a/lambdas/id_sync/src/pds_details.py +++ b/lambdas/id_sync/src/pds_details.py @@ -4,7 +4,7 @@ import tempfile -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import AppRestrictedAuth from common.api_clients.pds_service import PdsService from common.cache import Cache from common.clients import get_secrets_manager_client, logger @@ -20,7 +20,6 @@ def pds_get_patient_details(nhs_number: str) -> dict: try: cache = Cache(directory=safe_tmp_dir) authenticator = AppRestrictedAuth( - service=Service.PDS, secret_manager_client=get_secrets_manager_client(), environment=pds_env, cache=cache, diff --git a/lambdas/mns_publisher/src/process_records.py b/lambdas/mns_publisher/src/process_records.py index e55924d70..19792b7f3 100644 --- a/lambdas/mns_publisher/src/process_records.py +++ b/lambdas/mns_publisher/src/process_records.py @@ -1,6 +1,5 @@ import json import os -from typing import Tuple from aws_lambda_typing.events.sqs import SQSMessage @@ -68,7 +67,7 @@ def process_record(record: SQSMessage, mns_service: MnsService | MockMnsService) logger.info("Successfully created MNS notification", extra={"mns_notification_id": notification_id}) -def extract_trace_ids(record: SQSMessage) -> Tuple[str, str | None]: +def extract_trace_ids(record: SQSMessage) -> tuple[str, str | None]: """ Extract identifiers for tracing from SQS record. Returns: Tuple of (message_id, immunisation_id) diff --git a/lambdas/mns_publisher/tests/test_utils.py b/lambdas/mns_publisher/tests/test_utils.py index 4c6a71c15..1783681ae 100644 --- a/lambdas/mns_publisher/tests/test_utils.py +++ b/lambdas/mns_publisher/tests/test_utils.py @@ -23,7 +23,7 @@ def load_sample_sqs_event() -> dict: Expects: lambdas/mns_publisher/tests/sqs_event.json """ sample_event_path = Path(__file__).parent / "sample_data" / "sqs_event.json" - with open(sample_event_path, "r") as f: + with open(sample_event_path) as f: raw_event = json.load(f) if isinstance(raw_event.get("body"), dict): diff --git a/lambdas/shared/src/common/api_clients/authentication.py b/lambdas/shared/src/common/api_clients/authentication.py index 396d41c19..892291ef0 100644 --- a/lambdas/shared/src/common/api_clients/authentication.py +++ b/lambdas/shared/src/common/api_clients/authentication.py @@ -2,34 +2,25 @@ import json import time import uuid -from enum import Enum import jwt import requests +from common.api_clients.constants import API_CACHE_KEY from common.clients import logger from common.models.errors import UnhandledResponseError from ..cache import Cache -class Service(Enum): - PDS = "pds" - IMMUNIZATION = "imms" - - class AppRestrictedAuth: - def __init__(self, service: Service, secret_manager_client, environment, cache: Cache): + def __init__(self, secret_manager_client, environment, cache: Cache): self.secret_manager_client = secret_manager_client self.cache = cache - self.cache_key = f"{service.value}_access_token" + self.cache_key = API_CACHE_KEY self.expiry = 30 - self.secret_name = ( - f"imms/pds/{environment}/jwt-secrets" - if service == Service.PDS - else f"imms/immunization/{environment}/jwt-secrets" - ) + self.secret_name = f"imms/outbound/{environment}/jwt-secrets" self.token_url = ( f"https://{environment}.api.service.nhs.uk/oauth2/token" diff --git a/lambdas/shared/src/common/api_clients/constants.py b/lambdas/shared/src/common/api_clients/constants.py index aa305f146..7f986cc7a 100644 --- a/lambdas/shared/src/common/api_clients/constants.py +++ b/lambdas/shared/src/common/api_clients/constants.py @@ -3,6 +3,7 @@ """Constants used by API clients""" DEV_ENVIRONMENT = "dev" +API_CACHE_KEY = "api_client_access_token" class Constants: diff --git a/lambdas/shared/src/common/api_clients/get_pds_details.py b/lambdas/shared/src/common/api_clients/get_pds_details.py index 63844b3cd..a464f2265 100644 --- a/lambdas/shared/src/common/api_clients/get_pds_details.py +++ b/lambdas/shared/src/common/api_clients/get_pds_details.py @@ -5,7 +5,7 @@ import os import tempfile -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import AppRestrictedAuth from common.api_clients.errors import PdsSyncException from common.api_clients.pds_service import PdsService from common.cache import Cache @@ -20,7 +20,6 @@ def pds_get_patient_details(nhs_number: str) -> dict: try: cache = Cache(directory=safe_tmp_dir) authenticator = AppRestrictedAuth( - service=Service.PDS, secret_manager_client=get_secrets_manager_client(), environment=PDS_ENV, cache=cache, diff --git a/lambdas/shared/src/common/api_clients/mns_setup.py b/lambdas/shared/src/common/api_clients/mns_setup.py index 5cecd4440..8b6b5a5a0 100644 --- a/lambdas/shared/src/common/api_clients/mns_setup.py +++ b/lambdas/shared/src/common/api_clients/mns_setup.py @@ -4,7 +4,7 @@ import boto3 from botocore.config import Config -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import AppRestrictedAuth from common.api_clients.constants import DEV_ENVIRONMENT from common.api_clients.mns_service import MnsService from common.api_clients.mock_mns_service import MockMnsService @@ -23,7 +23,6 @@ def get_mns_service(mns_env: str = "int"): cache = Cache(directory="/tmp") logging.info("Creating authenticator...") authenticator = AppRestrictedAuth( - service=Service.PDS, secret_manager_client=boto3.client("secretsmanager", config=boto_config), environment=mns_env, cache=cache, diff --git a/lambdas/shared/src/common/get_service_url.py b/lambdas/shared/src/common/get_service_url.py index 9188c0750..212340e0b 100644 --- a/lambdas/shared/src/common/get_service_url.py +++ b/lambdas/shared/src/common/get_service_url.py @@ -1,9 +1,7 @@ -from typing import Optional - from common.constants import DEFAULT_BASE_PATH, PR_ENV_PREFIX -def get_service_url(service_env: Optional[str], service_base_path: Optional[str]) -> str: +def get_service_url(service_env: str | None, service_base_path: str | None) -> str: """Sets the service URL based on service parameters derived from env vars. PR environments use internal-dev while we also default to this environment. The only other exceptions are preprod which maps to the Apigee int environment and prod which does not have a subdomain.""" @@ -22,5 +20,5 @@ def get_service_url(service_env: Optional[str], service_base_path: Optional[str] return f"https://{subdomain}api.service.nhs.uk/{service_base_path}" -def is_pr_env(service_env: Optional[str]) -> bool: +def is_pr_env(service_env: str | None) -> bool: return service_env is not None and service_env.startswith(PR_ENV_PREFIX) diff --git a/lambdas/shared/tests/test_common/api_clients/test_authentication.py b/lambdas/shared/tests/test_common/api_clients/test_authentication.py index 11fc2e1d8..98e56f401 100644 --- a/lambdas/shared/tests/test_common/api_clients/test_authentication.py +++ b/lambdas/shared/tests/test_common/api_clients/test_authentication.py @@ -7,7 +7,7 @@ import responses from responses import matchers -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import AppRestrictedAuth from common.models.errors import UnhandledResponseError @@ -33,7 +33,7 @@ def setUp(self): self.cache.get.return_value = None env = "an-env" - self.authenticator = AppRestrictedAuth(Service.PDS, self.secret_manager_client, env, self.cache) + self.authenticator = AppRestrictedAuth(self.secret_manager_client, env, self.cache) self.url = f"https://{env}.api.service.nhs.uk/oauth2/token" @responses.activate @@ -89,12 +89,12 @@ def test_env_mapping(self): """it should target int environment for none-prod environment, otherwise int""" # For env=none-prod env = "some-env" - auth = AppRestrictedAuth(Service.PDS, None, env, None) + auth = AppRestrictedAuth(None, env, None) self.assertTrue(auth.token_url.startswith(f"https://{env}.")) # For env=prod env = "prod" - auth = AppRestrictedAuth(Service.PDS, None, env, None) + auth = AppRestrictedAuth(None, env, None) self.assertTrue(env not in auth.token_url) def test_returned_cached_token(self): @@ -126,7 +126,7 @@ def test_update_cache(self): self.authenticator.get_access_token() # Then - self.cache.put.assert_called_once_with(f"{Service.PDS.value}_access_token", cached_token) + self.cache.put.assert_called_once_with("api_client_access_token", cached_token) @responses.activate def test_expired_token_in_cache(self):