Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lambdas/id_sync/src/pds_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tempfile

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.pds_service import PdsService
from common.cache import Cache
from common.clients import get_secrets_manager_client, logger
Expand All @@ -20,7 +20,6 @@ def pds_get_patient_details(nhs_number: str) -> dict:
try:
cache = Cache(directory=safe_tmp_dir)
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=get_secrets_manager_client(),
environment=pds_env,
cache=cache,
Expand Down
3 changes: 1 addition & 2 deletions lambdas/mns_publisher/src/process_records.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
from typing import Tuple

from aws_lambda_typing.events.sqs import SQSMessage

Expand Down Expand Up @@ -68,7 +67,7 @@ def process_record(record: SQSMessage, mns_service: MnsService | MockMnsService)
logger.info("Successfully created MNS notification", extra={"mns_notification_id": notification_id})


def extract_trace_ids(record: SQSMessage) -> Tuple[str, str | None]:
def extract_trace_ids(record: SQSMessage) -> tuple[str, str | None]:
"""
Extract identifiers for tracing from SQS record.
Returns: Tuple of (message_id, immunisation_id)
Expand Down
2 changes: 1 addition & 1 deletion lambdas/mns_publisher/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def load_sample_sqs_event() -> dict:
Expects: lambdas/mns_publisher/tests/sqs_event.json
"""
sample_event_path = Path(__file__).parent / "sample_data" / "sqs_event.json"
with open(sample_event_path, "r") as f:
with open(sample_event_path) as f:
raw_event = json.load(f)

if isinstance(raw_event.get("body"), dict):
Expand Down
17 changes: 4 additions & 13 deletions lambdas/shared/src/common/api_clients/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,25 @@
import json
import time
import uuid
from enum import Enum

import jwt
import requests

from common.api_clients.constants import API_CACHE_KEY
from common.clients import logger
from common.models.errors import UnhandledResponseError

from ..cache import Cache


class Service(Enum):
PDS = "pds"
IMMUNIZATION = "imms"


class AppRestrictedAuth:
def __init__(self, service: Service, secret_manager_client, environment, cache: Cache):
def __init__(self, secret_manager_client, environment, cache: Cache):
self.secret_manager_client = secret_manager_client
self.cache = cache
self.cache_key = f"{service.value}_access_token"
self.cache_key = API_CACHE_KEY

self.expiry = 30
self.secret_name = (
f"imms/pds/{environment}/jwt-secrets"
if service == Service.PDS
else f"imms/immunization/{environment}/jwt-secrets"
)
self.secret_name = f"imms/outbound/{environment}/jwt-secrets"

self.token_url = (
f"https://{environment}.api.service.nhs.uk/oauth2/token"
Expand Down
1 change: 1 addition & 0 deletions lambdas/shared/src/common/api_clients/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Constants used by API clients"""

DEV_ENVIRONMENT = "dev"
API_CACHE_KEY = "api_client_access_token"


class Constants:
Expand Down
3 changes: 1 addition & 2 deletions lambdas/shared/src/common/api_clients/get_pds_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import tempfile

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.errors import PdsSyncException
from common.api_clients.pds_service import PdsService
from common.cache import Cache
Expand All @@ -20,7 +20,6 @@ def pds_get_patient_details(nhs_number: str) -> dict:
try:
cache = Cache(directory=safe_tmp_dir)
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=get_secrets_manager_client(),
environment=PDS_ENV,
cache=cache,
Expand Down
3 changes: 1 addition & 2 deletions lambdas/shared/src/common/api_clients/mns_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import boto3
from botocore.config import Config

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.constants import DEV_ENVIRONMENT
from common.api_clients.mns_service import MnsService
from common.api_clients.mock_mns_service import MockMnsService
Expand All @@ -23,7 +23,6 @@ def get_mns_service(mns_env: str = "int"):
cache = Cache(directory="/tmp")
logging.info("Creating authenticator...")
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=boto3.client("secretsmanager", config=boto_config),
environment=mns_env,
cache=cache,
Expand Down
6 changes: 2 additions & 4 deletions lambdas/shared/src/common/get_service_url.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional

from common.constants import DEFAULT_BASE_PATH, PR_ENV_PREFIX


def get_service_url(service_env: Optional[str], service_base_path: Optional[str]) -> str:
def get_service_url(service_env: str | None, service_base_path: str | None) -> str:
"""Sets the service URL based on service parameters derived from env vars. PR environments use internal-dev while
we also default to this environment. The only other exceptions are preprod which maps to the Apigee int environment
and prod which does not have a subdomain."""
Expand All @@ -22,5 +20,5 @@ def get_service_url(service_env: Optional[str], service_base_path: Optional[str]
return f"https://{subdomain}api.service.nhs.uk/{service_base_path}"


def is_pr_env(service_env: Optional[str]) -> bool:
def is_pr_env(service_env: str | None) -> bool:
return service_env is not None and service_env.startswith(PR_ENV_PREFIX)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import responses
from responses import matchers

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.models.errors import UnhandledResponseError


Expand All @@ -33,7 +33,7 @@ def setUp(self):
self.cache.get.return_value = None

env = "an-env"
self.authenticator = AppRestrictedAuth(Service.PDS, self.secret_manager_client, env, self.cache)
self.authenticator = AppRestrictedAuth(self.secret_manager_client, env, self.cache)
self.url = f"https://{env}.api.service.nhs.uk/oauth2/token"

@responses.activate
Expand Down Expand Up @@ -89,12 +89,12 @@ def test_env_mapping(self):
"""it should target int environment for none-prod environment, otherwise int"""
# For env=none-prod
env = "some-env"
auth = AppRestrictedAuth(Service.PDS, None, env, None)
auth = AppRestrictedAuth(None, env, None)
self.assertTrue(auth.token_url.startswith(f"https://{env}."))

# For env=prod
env = "prod"
auth = AppRestrictedAuth(Service.PDS, None, env, None)
auth = AppRestrictedAuth(None, env, None)
self.assertTrue(env not in auth.token_url)

def test_returned_cached_token(self):
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_update_cache(self):
self.authenticator.get_access_token()

# Then
self.cache.put.assert_called_once_with(f"{Service.PDS.value}_access_token", cached_token)
self.cache.put.assert_called_once_with("api_client_access_token", cached_token)

@responses.activate
def test_expired_token_in_cache(self):
Expand Down
Loading