Skip to content

Commit e90011c

Browse files
committed
data attributes and test fix
1 parent 1a93d61 commit e90011c

8 files changed

Lines changed: 164 additions & 9 deletions

File tree

application/CohortManager/src/Functions/DemographicServices/DemographicDurableFunction/DemographicDurableFunction.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
<PackageReference Include="Microsoft.Azure.Functions.Worker.Extensions.Http.AspNetCore" />
1717
<PackageReference Include="Contrib.Grpc.Core.M1" />
1818
<PackageReference Include="Microsoft.Extensions.Diagnostics.HealthChecks" />
19-
<PackageReference Include="Contrib.Grpc.Core.M1" />
2019
<PackageReference Include="Microsoft.Azure.Functions.Worker.Sdk" />
2120
</ItemGroup>
2221
<ItemGroup>
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
namespace Common;
2+
3+
[AttributeUsage(AttributeTargets.Method)]
4+
public class AuthenticationAttribute : Attribute
5+
{
6+
public Role[]? Roles { get; }
7+
public bool RequiresAuthentication => Roles != null && Roles.Length > 0;
8+
public AuthenticationAttribute(Role role)
9+
{
10+
Roles = new[] { role };
11+
}
12+
public AuthenticationAttribute(Role[] roles)
13+
{
14+
Roles = roles;
15+
}
16+
public AuthenticationAttribute()
17+
{
18+
Roles = null;
19+
}
20+
}

application/CohortManager/src/Functions/Shared/Common/Authentication/CIS2AuthMiddleware.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ public Cis2AuthMiddleware(ILogger<Cis2AuthMiddleware> logger, ICreateResponse cr
2424

2525
public async Task Invoke(FunctionContext context, FunctionExecutionDelegate next)
2626
{
27+
if(!context.RequiresAuthentication())
28+
{
29+
await next(context);
30+
return;
31+
}
32+
2733
var req = await context.GetHttpRequestDataAsync();
2834
var accessToken = string.Empty;
2935
var tokensExist = AuthHelper.TryGetIdTokenFromHeaders(context, out var token);

application/CohortManager/src/Functions/Shared/Common/Authentication/AuthConfig.cs renamed to application/CohortManager/src/Functions/Shared/Common/Authentication/Config/AuthConfig.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public class AuthConfig
1010
public required string AuthClientId { get; init; }
1111
[Required, Url]
1212
public required string UserInfoUrl { get; init; }
13+
public bool RequireAuthentication { get; init; } = true;
1314

1415
}
1516

application/CohortManager/src/Functions/Shared/Common/Authentication/RoleConfig.cs renamed to application/CohortManager/src/Functions/Shared/Common/Authentication/Config/RoleConfig.cs

File renamed without changes.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
namespace Common;
2+
3+
using System.Collections.Concurrent;
4+
using System.Reflection;
5+
using Microsoft.Azure.Functions.Worker;
6+
7+
public static class FunctionContextExtension
8+
{
9+
public static bool RequiresAuthentication(this FunctionContext context)
10+
{
11+
var authAttribute = context.GetEndpoint()?.Metadata.GetMetadata<AuthenticationAttribute>();
12+
return authAttribute != null;
13+
}
14+
15+
public static Role[] GetRequiredRoles(this FunctionContext context)
16+
{
17+
var authAttribute = context.GetEndpoint()?.Metadata.GetMetadata<AuthenticationAttribute>();
18+
return authAttribute?.Roles ?? Array.Empty<Role>();
19+
}
20+
21+
public static FunctionEndpoint? GetEndpoint(this FunctionContext context)
22+
{
23+
ArgumentNullException.ThrowIfNull(context);
24+
25+
return FunctionEndpointCache.GetOrAdd(context.FunctionDefinition.EntryPoint, CreateEndpoint);
26+
}
27+
28+
private static readonly ConcurrentDictionary<string, FunctionEndpoint?> FunctionEndpointCache = new();
29+
30+
private static FunctionEndpoint? CreateEndpoint(string entryPoint)
31+
{
32+
if (string.IsNullOrWhiteSpace(entryPoint))
33+
{
34+
return null;
35+
}
36+
37+
var separatorIndex = entryPoint.LastIndexOf('.');
38+
if (separatorIndex <= 0 || separatorIndex == entryPoint.Length - 1)
39+
{
40+
return null;
41+
}
42+
43+
var typeName = entryPoint[..separatorIndex];
44+
var methodName = entryPoint[(separatorIndex + 1)..];
45+
46+
var declaringType = AppDomain.CurrentDomain
47+
.GetAssemblies()
48+
.Select(assembly => assembly.GetType(typeName, throwOnError: false, ignoreCase: false))
49+
.FirstOrDefault(type => type != null);
50+
51+
var method = declaringType?.GetMethod(
52+
methodName,
53+
BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance);
54+
55+
return method == null
56+
? null
57+
: new FunctionEndpoint(method.GetCustomAttributes(inherit: true));
58+
}
59+
}
60+
61+
public sealed class FunctionEndpoint
62+
{
63+
public FunctionEndpoint(IEnumerable<object> metadata)
64+
{
65+
Metadata = new FunctionEndpointMetadataCollection(metadata);
66+
}
67+
68+
public FunctionEndpointMetadataCollection Metadata { get; }
69+
}
70+
71+
public sealed class FunctionEndpointMetadataCollection
72+
{
73+
private readonly IReadOnlyList<object> _metadata;
74+
75+
public FunctionEndpointMetadataCollection(IEnumerable<object> metadata)
76+
{
77+
_metadata = metadata.ToArray();
78+
}
79+
80+
public T? GetMetadata<T>() where T : class
81+
{
82+
return _metadata.OfType<T>().FirstOrDefault();
83+
}
84+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
namespace Common;
2+
3+
4+
using Microsoft.Azure.Functions.Worker;
5+
using Microsoft.Azure.Functions.Worker.Http;
6+
using Microsoft.Azure.Functions.Worker.Middleware;
7+
using Microsoft.Extensions.Logging;
8+
9+
public class PermissionsMiddleware : IFunctionsWorkerMiddleware
10+
{
11+
private readonly ICreateResponse _createResponse;
12+
private readonly IRoleManager _roleManager;
13+
private readonly ILogger<PermissionsMiddleware> _logger;
14+
15+
public PermissionsMiddleware(ICreateResponse createResponse, IRoleManager roleManager, ILogger<PermissionsMiddleware> logger)
16+
{
17+
_createResponse = createResponse;
18+
_roleManager = roleManager;
19+
_logger = logger;
20+
}
21+
22+
public async Task Invoke(FunctionContext context, FunctionExecutionDelegate next)
23+
{
24+
if (!context.RequiresAuthentication())
25+
{
26+
await next(context);
27+
return;
28+
}
29+
30+
var req = await context.GetHttpRequestDataAsync();
31+
var user = (Cis2User)context.Items["Cis2User"]!;
32+
var requiredRoles = context.GetRequiredRoles();
33+
34+
if (requiredRoles.Length == 0 || requiredRoles.Any(role => _roleManager.ValidateRole(user, role)))
35+
{
36+
await next(context);
37+
return;
38+
}
39+
40+
await HandleUnauthorizedAsync(context, req, $"User {user.Uid} does not have required roles to access this resource.", "Forbidden: You do not have permission to access this resource.");
41+
return;
42+
}
43+
44+
private async Task HandleUnauthorizedAsync(FunctionContext context, HttpRequestData request, string logMessage, string responseMessage)
45+
{
46+
var logger = context.GetLogger<PermissionsMiddleware>();
47+
logger.LogWarning("Authorization Error: {LogMessage}", logMessage);
48+
var response = await _createResponse.CreateHttpResponseWithBodyAsync(System.Net.HttpStatusCode.Forbidden, request, responseMessage);
49+
context.GetInvocationResult().Value = response;
50+
}
51+
}

application/CohortManager/src/Functions/screeningDataServices/GetValidationExceptions/GetValidationExceptions.cs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,14 @@ public class GetValidationExceptions
2424
private readonly IValidationExceptionData _validationData;
2525
private readonly IHttpParserHelper _httpParserHelper;
2626
private readonly IPaginationService<ValidationException> _paginationService;
27-
private readonly IRoleManager _roleManager;
2827

29-
public GetValidationExceptions(ILogger<GetValidationExceptions> logger, ICreateResponse createResponse, IValidationExceptionData validationData, IHttpParserHelper httpParserHelper, IPaginationService<ValidationException> paginationService, IRoleManager roleManager)
28+
public GetValidationExceptions(ILogger<GetValidationExceptions> logger, ICreateResponse createResponse, IValidationExceptionData validationData, IHttpParserHelper httpParserHelper, IPaginationService<ValidationException> paginationService)
3029
{
3130
_logger = logger;
3231
_createResponse = createResponse;
3332
_validationData = validationData;
3433
_httpParserHelper = httpParserHelper;
3534
_paginationService = paginationService;
36-
_roleManager = roleManager;
3735
}
3836

3937
/// <summary>
@@ -46,6 +44,7 @@ public GetValidationExceptions(ILogger<GetValidationExceptions> logger, ICreateR
4644
/// Returns 200 OK with data, 204 No Content if empty, 400 Bad Request for validation errors, or 500 Internal Server Error.
4745
/// </returns>
4846
[Function(nameof(GetValidationExceptions))]
47+
[Authentication(Role.CohortManagerUser)]
4948
public async Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymous, "get")] HttpRequestData req)
5049
{
5150
var exceptionId = _httpParserHelper.GetQueryParameterAsInt(req, "exceptionId");
@@ -60,11 +59,6 @@ public async Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymou
6059
var ruleId = _httpParserHelper.GetQueryParameterAsNullableInt(req, "ruleId");
6160
var dateCreated = _httpParserHelper.GetQueryParameterAsDateTime(req, "dateCreated");
6261

63-
if (!_roleManager.ValidateRole((Cis2User)req.FunctionContext.Items["Cis2User"]!, Role.CohortManagerUser))
64-
{
65-
return _createResponse.CreateHttpResponse(HttpStatusCode.Forbidden, req);
66-
}
67-
6862
try
6963
{
7064
if (exceptionId > 0)

0 commit comments

Comments
 (0)