Skip to content

Commit e226845

Browse files
committed
Adding Access Token Headers
1 parent 2e1ab5a commit e226845

11 files changed

Lines changed: 192 additions & 28 deletions

File tree

application/CohortManager/src/Functions/Shared/Common/Authentication/AuthConfig.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ namespace Common;
44

55
public class AuthConfig
66
{
7-
[Required]
7+
[Required, Url]
88
public required string AuthMetaDataUrl { get; init; }
99
[Required]
1010
public required string AuthClientId { get; init; }
11+
[Required, Url]
12+
public required string UserInfoUrl { get; init; }
1113

1214
}
1315

application/CohortManager/src/Functions/Shared/Common/Authentication/AuthHelper.cs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@ namespace Common;
55

66
public static class AuthHelper
77
{
8-
public static bool TryGetTokenFromHeaders(FunctionContext context, out string token)
8+
public static bool TryGetIdTokenFromHeaders(FunctionContext context, out string token)
9+
{
10+
return TryGetBearerTokenFromHeaders(context, "Authorization", out token);
11+
}
12+
13+
public static bool TryGetAccessTokenFromHeaders(FunctionContext context, out string accessToken)
14+
{
15+
return TryGetBearerTokenFromHeaders(context, "X-Access-Token", out accessToken);
16+
}
17+
18+
private static bool TryGetBearerTokenFromHeaders(FunctionContext context, string headerName, out string token)
919
{
1020
token = null!;
1121

@@ -15,18 +25,19 @@ public static bool TryGetTokenFromHeaders(FunctionContext context, out string to
1525
{
1626
return false;
1727
}
18-
var headers = JsonSerializer.Deserialize<Dictionary<string, string>>(headersStr);
19-
if(headers == null)
20-
{
21-
return false;
22-
}
23-
24-
if(!headers.TryGetValue("Authorization", out var authHeader) || !authHeader.StartsWith("Bearer "))
25-
{
26-
return false;
27-
}
28-
29-
token = authHeader.Substring("Bearer ".Length).Trim();
30-
return true;
28+
29+
var headers = JsonSerializer.Deserialize<Dictionary<string, string>>(headersStr);
30+
if(headers == null)
31+
{
32+
return false;
33+
}
34+
35+
if(!headers.TryGetValue(headerName, out var authHeader) || !authHeader.StartsWith("Bearer "))
36+
{
37+
return false;
38+
}
39+
40+
token = authHeader.Substring("Bearer ".Length).Trim();
41+
return true;
3142
}
3243
}

application/CohortManager/src/Functions/Shared/Common/Authentication/CIS2AuthMiddleware.cs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ namespace Common;
22

33
using System.Net;
44
using Microsoft.Azure.Functions.Worker;
5+
using Microsoft.Azure.Functions.Worker.Http;
56
using Microsoft.Azure.Functions.Worker.Middleware;
67
using Microsoft.Extensions.Logging;
78

@@ -11,39 +12,53 @@ public class Cis2AuthMiddleware : IFunctionsWorkerMiddleware
1112
private readonly ILogger<Cis2AuthMiddleware> _logger;
1213
private readonly ICreateResponse _createResponse;
1314
private readonly IAuthenticationService _authService;
15+
private readonly ICis2UserService _cis2UserService;
1416

15-
public Cis2AuthMiddleware(ILogger<Cis2AuthMiddleware> logger, ICreateResponse createResponse, IAuthenticationService authService)
17+
public Cis2AuthMiddleware(ILogger<Cis2AuthMiddleware> logger, ICreateResponse createResponse, IAuthenticationService authService, ICis2UserService cis2UserService)
1618
{
1719
_logger = logger;
1820
_createResponse = createResponse;
1921
_authService = authService;
22+
_cis2UserService = cis2UserService;
2023
}
2124

2225
public async Task Invoke(FunctionContext context, FunctionExecutionDelegate next)
2326
{
2427
var req = await context.GetHttpRequestDataAsync();
28+
var accessToken = string.Empty;
29+
var tokensExist = AuthHelper.TryGetIdTokenFromHeaders(context, out var token);
30+
tokensExist = tokensExist && AuthHelper.TryGetAccessTokenFromHeaders(context, out accessToken);
2531

26-
var tokenExists = AuthHelper.TryGetTokenFromHeaders(context, out var token);
27-
28-
if(!tokenExists)
32+
if(!tokensExist)
2933
{
30-
_logger.LogWarning("Authorization header is missing or invalid");
31-
var response = await _createResponse.CreateHttpResponseWithBodyAsync(HttpStatusCode.Unauthorized, req!, "Unauthorized: Missing or invalid Authorization header.");
32-
context.GetInvocationResult().Value = response;
34+
await HandleUnauthorizedAsync(context, req!, "Authorization header is missing or invalid", "Unauthorized: Missing or invalid Authorization header.");
3335
return;
3436
}
3537

3638
var validateToken = await _authService.ValidateTokenAsync(token);
3739

3840
if(!validateToken)
3941
{
40-
_logger.LogWarning("Token validation failed");
41-
var response = await _createResponse.CreateHttpResponseWithBodyAsync(HttpStatusCode.Unauthorized, req!, "Unauthorized: Invalid token.");
42-
context.GetInvocationResult().Value = response;
42+
await HandleUnauthorizedAsync(context, req!, "Token validation failed", "Unauthorized: Invalid token.");
43+
return;
44+
}
45+
46+
var cis2User = await _cis2UserService.GetUserFromToken(accessToken);
47+
if(cis2User == null)
48+
{
49+
await HandleUnauthorizedAsync(context, req!, "Failed to retrieve user from token", "Unauthorized: Failed to retrieve user from token.");
4350
return;
4451
}
4552

53+
context.Items["Cis2User"] = cis2User;
4654
context.Items["AuthToken"] = token;
4755
await next(context);
4856
}
57+
58+
private async Task HandleUnauthorizedAsync(FunctionContext context, HttpRequestData request, string logMessage, string responseMessage)
59+
{
60+
_logger.LogWarning(logMessage);
61+
var response = await _createResponse.CreateHttpResponseWithBodyAsync(HttpStatusCode.Unauthorized, request, responseMessage);
62+
context.GetInvocationResult().Value = response;
63+
}
4964
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
namespace Common;
2+
3+
using System.Collections.Generic;
4+
using System.Text.Json.Serialization;
5+
6+
public class Cis2User
7+
{
8+
[JsonPropertyName("nhsid_useruid")]
9+
public string NhsidUseruid { get; set; }
10+
11+
[JsonPropertyName("name")]
12+
public string Name { get; set; }
13+
14+
[JsonPropertyName("nhsid_nrbac_roles")]
15+
public List<NhsidNrbacRole> NhsidNrbacRoles { get; set; }
16+
17+
[JsonPropertyName("given_name")]
18+
public string GivenName { get; set; }
19+
20+
[JsonPropertyName("family_name")]
21+
public string FamilyName { get; set; }
22+
23+
[JsonPropertyName("uid")]
24+
public string Uid { get; set; }
25+
26+
[JsonPropertyName("sub")]
27+
public string Sub { get; set; }
28+
}
29+
public class NhsidNrbacRole
30+
{
31+
[JsonPropertyName("person_orgid")]
32+
public string PersonOrgid { get; set; }
33+
34+
[JsonPropertyName("person_roleid")]
35+
public string PersonRoleid { get; set; }
36+
37+
[JsonPropertyName("org_code")]
38+
public string OrgCode { get; set; }
39+
40+
[JsonPropertyName("role_name")]
41+
public string RoleName { get; set; }
42+
43+
[JsonPropertyName("role_code")]
44+
public string RoleCode { get; set; }
45+
46+
[JsonPropertyName("workgroups")]
47+
public List<string> Workgroups { get; set; }
48+
49+
[JsonPropertyName("workgroups_codes")]
50+
public List<string> WorkgroupsCodes { get; set; }
51+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
namespace Common;
2+
3+
using System.Text.Json;
4+
using Hl7.FhirPath.Sprache;
5+
using Microsoft.Extensions.Logging;
6+
using Microsoft.Extensions.Options;
7+
8+
public class Cis2UserService : ICis2UserService
9+
{
10+
ILogger<Cis2UserService> _logger;
11+
IHttpClientFunction _httpClient;
12+
AuthConfig _authConfig;
13+
14+
public Cis2UserService(ILogger<Cis2UserService> logger, IHttpClientFunction httpClient, IOptions<AuthConfig> authConfig)
15+
{
16+
_logger = logger;
17+
_httpClient = httpClient;
18+
_authConfig = authConfig.Value;
19+
}
20+
21+
public async Task<Cis2User?> GetUserFromToken(string token)
22+
{
23+
try{
24+
_httpClient.SetBearerToken(token);
25+
var response = await _httpClient.SendGetOrThrowAsync(_authConfig.UserInfoUrl);
26+
if(response == null)
27+
{
28+
_logger.LogError("Failed to get user info from token, response is null");
29+
return null;
30+
}
31+
var cis2User = JsonSerializer.Deserialize<Cis2User>(response);
32+
33+
return cis2User;
34+
}
35+
catch(Exception ex)
36+
{
37+
_logger.LogError(ex, "Failed to get user info from token, message: {Message}", ex.Message);
38+
return null;
39+
}
40+
}
41+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
namespace Common;
2+
3+
public interface ICis2UserService
4+
{
5+
/// <summary>
6+
/// Gets the user information from the token using the configured UserInfo endpoint.
7+
/// </summary>
8+
/// <param name="token">The token to get the user information from.</param>
9+
/// <returns>A Cis2User object containing the user information, or null if the user information could not be retrieved.</returns>
10+
Task<Cis2User?> GetUserFromToken(string token);
11+
}

application/CohortManager/src/Functions/Shared/Common/Extensions/AuthenticationExtension.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public static IHostBuilder AddAuthentication(this IHostBuilder hostBuilder)
1616
hostBuilder.ConfigureServices((context, services) =>
1717
{
1818
services.AddSingleton<IAuthenticationService, JwtAuthentication>();
19+
services.AddSingleton<ICis2UserService,Cis2UserService>();
1920
});
2021
return hostBuilder;
2122
}

application/CohortManager/src/Functions/Shared/Common/HttpClientFunction.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public class HttpClientFunction : IHttpClientFunction
1111
private readonly IHttpClientFactory _factory;
1212
public static readonly TimeSpan _timeout = TimeSpan.FromSeconds(300);
1313
private const string errorMessage = "Failed to execute request to {Url}, message: {Message}";
14+
private string _bearerToken = string.Empty;
1415

1516
public HttpClientFunction(ILogger<HttpClientFunction> logger, IHttpClientFactory factory)
1617
{
@@ -25,7 +26,7 @@ public async Task<HttpResponseMessage> SendPost(string url, string data)
2526

2627
client.BaseAddress = new Uri(url);
2728
client.Timeout = _timeout;
28-
29+
SetClientBearerToken(client);
2930
try
3031
{
3132
HttpResponseMessage response = await client.PostAsync(url, jsonContent);
@@ -45,6 +46,7 @@ public async Task<HttpResponseMessage> SendPost(string url, Dictionary<string, s
4546

4647
client.BaseAddress = new Uri(url);
4748
client.Timeout = _timeout;
49+
SetClientBearerToken(client);
4850

4951
try
5052
{
@@ -64,6 +66,7 @@ public async Task<string> SendGet(string url)
6466

6567
client.BaseAddress = new Uri(url);
6668
client.Timeout = _timeout;
69+
SetClientBearerToken(client);
6770

6871
return await GetAsync(client);
6972
}
@@ -76,6 +79,7 @@ public async Task<string> SendGet(string url, Dictionary<string, string> paramet
7679

7780
client.BaseAddress = new Uri(url);
7881
client.Timeout = _timeout;
82+
SetClientBearerToken(client);
7983

8084
return await GetAsync(client);
8185
}
@@ -88,6 +92,7 @@ public async Task<HttpResponseMessage> SendGetResponse(string url, Dictionary<st
8892

8993
client.BaseAddress = new Uri(url);
9094
client.Timeout = _timeout;
95+
SetClientBearerToken(client);
9196

9297
return await client.GetAsync(url);
9398
}
@@ -99,6 +104,7 @@ public async Task<HttpResponseMessage> SendGetResponse(string url)
99104

100105
client.BaseAddress = new Uri(url);
101106
client.Timeout = _timeout;
107+
SetClientBearerToken(client);
102108

103109
return await client.GetAsync(url);
104110
}
@@ -109,6 +115,7 @@ public async Task<string> SendGetOrThrowAsync(string url)
109115

110116
client.BaseAddress = new Uri(url);
111117
client.Timeout = _timeout;
118+
SetClientBearerToken(client);
112119

113120
return await GetOrThrowAsync(client);
114121
}
@@ -155,6 +162,7 @@ public async Task<HttpResponseMessage> SendPut(string url, string data)
155162

156163
client.BaseAddress = new Uri(url);
157164
client.Timeout = _timeout;
165+
SetClientBearerToken(client);
158166

159167
try
160168
{
@@ -174,6 +182,7 @@ public async Task<bool> SendDelete(string url)
174182

175183
client.BaseAddress = new Uri(url);
176184
client.Timeout = _timeout;
185+
SetClientBearerToken(client);
177186

178187
try
179188
{
@@ -198,6 +207,19 @@ public async Task<string> GetResponseText(HttpResponseMessage response)
198207
return await response.Content.ReadAsStringAsync();
199208
}
200209

210+
public void SetBearerToken(string token)
211+
{
212+
_bearerToken = token;
213+
}
214+
215+
private void SetClientBearerToken(HttpClient client)
216+
{
217+
if (!string.IsNullOrEmpty(_bearerToken))
218+
{
219+
client.DefaultRequestHeaders.Add("Authorization", "Bearer " + _bearerToken);
220+
}
221+
}
222+
201223
/// <summary>
202224
/// Removes the query string from the URL to prevent us logging sensitive information.
203225
/// </summary>

application/CohortManager/src/Functions/Shared/Common/Interfaces/IHttpClientFunction.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,10 @@ public interface IHttpClientFunction
8585
/// <param name="response">HTTP response message.</param>
8686
/// <returns>string<returns>
8787
Task<string> GetResponseText(HttpResponseMessage response);
88+
89+
/// <summary>
90+
/// Sets the bearer token to be used in HttpClient requests.
91+
/// </summary>
92+
/// <param name="token">Bearer token to be used in HttpClient requests.</param>
93+
void SetBearerToken(string token);
8894
}

application/CohortManager/src/Web/app/lib/fetchExceptions.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@ function buildQueryString(params: FetchExceptionsParams): string {
4343

4444
async function getAuthHeaders() {
4545
const session = await auth();
46-
const bearerToken = session?.idToken ?? session?.accessToken;
46+
const bearerToken = session?.idToken;
47+
const accessToken = session?.accessToken;
4748

48-
if (!bearerToken) {
49+
if (!bearerToken || !accessToken) {
4950
return undefined;
5051
}
5152

53+
5254
return {
5355
Authorization: `Bearer ${bearerToken}`,
56+
"X-Access-Token": accessToken
5457
};
5558
}
5659

0 commit comments

Comments
 (0)