-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathauthentication.py
More file actions
97 lines (84 loc) · 3.36 KB
/
authentication.py
File metadata and controls
97 lines (84 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import logging
import os
from functools import cached_property
import jwt
from django.conf import settings
from ninja.security import HttpBearer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Signature algorithms passed to jwt.decode(); tokens signed with anything else are rejected.
ALLOWED_ALGORITHMS = ["RS256"]
# Passed as the `lifespan` of the PyJWKClient JWK-set cache, in seconds (3600 = one hour).
JWT_SET_CACHE_TTL_SECONDS = 3600
class Authentication(HttpBearer):
    """
    Bearer-token authentication that validates JWTs against Azure AD signing keys.

    Configuration is read from the environment (API_AUDIENCE, TENANT_ID) and from
    the Django setting BYPASS_API_AUTHENTICATION.
    """

    # Claims that must be present in every accepted token. Without this, PyJWT
    # only checks "exp" when the claim exists, and skips the audience check when
    # no audience is configured and the token carries no "aud" claim.
    REQUIRED_CLAIMS = ("exp", "aud", "iss")

    def authenticate(self, request, token) -> dict | None:
        """
        Authenticates the incoming request by validating the JWT token.

        Returns a fixed placeholder identity when the bypass setting is enabled.
        Otherwise stores the decoded payload on ``request.auth`` and returns it;
        the value is None when the token could not be validated.
        """
        if self.bypass_authentication:
            logger.warning("Authentication bypass is enabled.")
            return {"oid": "bypass_object_id", "sub": "bypass_user"}

        request.auth = self._decode(token)
        return request.auth

    def _decode(self, token: str) -> dict | None:
        """
        Decodes and validates the JWT token using the signing key resolved from the JWKS endpoint.

        Checks the signature, audience, and issuer claims, and requires the
        claims listed in REQUIRED_CLAIMS to be present. Every failure is logged
        and results in None instead of a raised exception.
        """
        try:
            signing_key = self.jwks_client.get_signing_key_from_jwt(token)
            return jwt.decode(
                token,
                signing_key.key,
                algorithms=ALLOWED_ALGORITHMS,
                audience=self.audience,
                issuer=self.issuers,
                # A missing required claim raises MissingRequiredClaimError,
                # which is handled by the InvalidTokenError branch below.
                options={"require": list(self.REQUIRED_CLAIMS)},
            )
        except jwt.PyJWKClientError:
            logger.exception("Error fetching JWKS keys from Azure AD.")
        except jwt.ExpiredSignatureError:
            logger.exception("Token is expired")
        except (jwt.InvalidAudienceError, jwt.InvalidIssuerError):
            logger.exception("Invalid claims. Please check the audience and issuer.")
        except jwt.InvalidTokenError:
            logger.exception("Token is invalid")
        except Exception:
            # Last-resort boundary: any unexpected error rejects the token.
            logger.exception("Unable to parse authentication token.")
        return None

    @cached_property
    def jwks_client(self) -> jwt.PyJWKClient:
        """
        Creates a PyJWKClient instance for fetching and caching the JWKS keys from Azure AD.

        The client is built once per Authentication instance. JWK-set and key
        caching are enabled, with JWT_SET_CACHE_TTL_SECONDS as the cache lifespan.
        """
        return jwt.PyJWKClient(
            self.discovery_keys_url,
            cache_jwk_set=True,
            cache_keys=True,
            lifespan=JWT_SET_CACHE_TTL_SECONDS,
        )

    @cached_property
    def discovery_keys_url(self) -> str:
        """
        The JWKS discovery URL for the configured tenant.

        Computed once per instance, so a later change to TENANT_ID is not picked up.
        """
        return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"

    @property
    def audience(self) -> str | None:
        """
        The expected audience claim in the JWT token, read from the API_AUDIENCE environment variable.

        None when the variable is not set.
        """
        return os.getenv("API_AUDIENCE")

    @property
    def tenant_id(self) -> str:
        """
        The Azure AD tenant ID, read from the TENANT_ID environment variable.

        An empty string when the variable is not set.
        """
        return os.getenv("TENANT_ID", "")

    @cached_property
    def issuers(self) -> list[str]:
        """
        The accepted issuer claim values for the configured tenant.

        Computed once per instance from tenant_id.
        """
        return [
            f"https://sts.windows.net/{self.tenant_id}/",
            f"https://login.microsoftonline.com/{self.tenant_id}/v2.0/",
        ]

    @property
    def bypass_authentication(self) -> bool:
        """
        Whether the Django setting BYPASS_API_AUTHENTICATION is enabled (False when absent).
        """
        return getattr(settings, "BYPASS_API_AUTHENTICATION", False)