-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathauthentication.py
More file actions
97 lines (77 loc) · 3.13 KB
/
authentication.py
File metadata and controls
97 lines (77 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import base64
import json
import time
import uuid
from enum import Enum
import jwt
import requests
from common.clients import logger
from common.models.errors import UnhandledResponseError
from ..cache import Cache
class Service(Enum):
    """Downstream API that an access token is requested for.

    ``AppRestrictedAuth`` uses the member to choose which JWT secret to load,
    and the member's value as the prefix of the token cache key.
    """

    # Member order is significant for iteration; keep PDS first.
    PDS = "pds"
    IMMUNIZATION = "imms"
class AppRestrictedAuth:
    """Obtain and cache OAuth2 access tokens via a signed-JWT client assertion.

    A short-lived JWT is signed (RS512) with a private key loaded from the
    secret manager. It is posted to the environment's ``/oauth2/token``
    endpoint with ``grant_type=client_credentials`` and the ``jwt-bearer``
    client-assertion type. The returned access token is cached per service.
    """

    def __init__(
        self,
        service: Service,
        secret_manager_client,
        environment,
        cache: Cache,
        *,
        timeout: float = 10.0,
    ):
        """
        Args:
            service: Target API. Selects the JWT secret and the cache key prefix.
            secret_manager_client: Object exposing ``get_secret_value(SecretId=...)``
                that returns a mapping with a ``"SecretString"`` entry
                (presumably a boto3 Secrets Manager client - confirm with caller).
            environment: Environment name. ``"prod"`` maps to
                ``api.service.nhs.uk``; any other value is used as a host prefix.
            cache: Store exposing ``get(key)`` and ``put(key, value)``.
            timeout: Seconds ``requests`` waits on the token endpoint before
                raising a timeout error. Previously there was no limit.
        """
        self.secret_manager_client = secret_manager_client
        self.cache = cache
        self.cache_key = f"{service.value}_access_token"
        # Lifetime (seconds) of the JWT assertion; also reused as the cache lifetime.
        self.expiry = 30
        self.timeout = timeout
        self.secret_name = (
            f"imms/outbound/{environment}/jwt-secrets"
            if service == Service.PDS
            else f"imms/immunization/{environment}/jwt-secrets"
        )
        self.token_url = (
            f"https://{environment}.api.service.nhs.uk/oauth2/token"
            if environment != "prod"
            else "https://api.service.nhs.uk/oauth2/token"
        )

    def get_service_secrets(self):
        """Load the JWT signing secret for this service.

        Returns:
            The secret's JSON object. It has an extra ``"private_key"`` entry
            holding the base64-decoded text of ``"private_key_b64"``.
        """
        kwargs = {"SecretId": self.secret_name}
        response = self.secret_manager_client.get_secret_value(**kwargs)
        secret_object = json.loads(response["SecretString"])
        secret_object["private_key"] = base64.b64decode(secret_object["private_key_b64"]).decode()
        return secret_object

    def create_jwt(self, now: int):
        """Build the signed client-assertion JWT.

        Args:
            now: Current Unix time in seconds; used for ``iat``, and ``exp`` is
                ``now + self.expiry``.

        Returns:
            The encoded JWT. ``iss``/``sub`` are the secret's ``api_key``,
            ``aud`` is the token URL, ``jti`` is a random UUID4, and the
            header carries the secret's ``kid``.
        """
        logger.info("create_jwt")
        secret_object = self.get_service_secrets()
        claims = {
            "iss": secret_object["api_key"],
            "sub": secret_object["api_key"],
            "aud": self.token_url,
            "iat": now,
            "exp": now + self.expiry,
            "jti": str(uuid.uuid4()),
        }
        return jwt.encode(
            claims,
            secret_object["private_key"],
            algorithm="RS512",
            headers={"kid": secret_object["kid"]},
        )

    def get_access_token(self):
        """Return an access token, reusing a cached one while it is still valid.

        Raises:
            UnhandledResponseError: If the token endpoint answers with a
                non-200 status, a non-JSON/non-object body, or no ``access_token``.
            requests.RequestException: On network failure or timeout.
        """
        logger.info("get_access_token")
        now = int(time.time())
        logger.info(f"Current time: {now}, Expiry time: {now + self.expiry}")

        # Check if token is cached and not expired
        logger.info(f"Cache key: {self.cache_key}")
        logger.info("Checking cache for access token")
        cached = self.cache.get(self.cache_key)
        if cached and cached["expires_at"] > now:
            logger.info("Returning cached access token")
            return cached["token"]

        logger.info("No valid cached token found, creating new token")
        token = self._request_access_token(now)
        # Cached for self.expiry seconds (the JWT assertion lifetime).
        # NOTE(review): the endpoint's own token lifetime is not consulted here.
        self.cache.put(self.cache_key, {"token": token, "expires_at": now + self.expiry})
        return token

    def _request_access_token(self, now: int):
        """Exchange a freshly signed JWT assertion for an access token string."""
        _jwt = self.create_jwt(now)
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "grant_type": "client_credentials",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": _jwt,
        }
        # A timeout prevents a stalled token endpoint from blocking the caller forever.
        token_response = requests.post(self.token_url, data=data, headers=headers, timeout=self.timeout)
        if token_response.status_code != 200:
            raise UnhandledResponseError(response=token_response.text, message="Failed to get access token")

        try:
            body = token_response.json()
        except ValueError as err:  # requests' JSONDecodeError subclasses ValueError
            raise UnhandledResponseError(
                response=token_response.text, message="Failed to get access token"
            ) from err

        token = body.get("access_token") if isinstance(body, dict) else None
        if not token:
            # Never cache or return a missing token as if it were valid.
            raise UnhandledResponseError(
                response=token_response.text, message="Access token missing from token response"
            )
        return token