-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathDatabaseHelper.cs
More file actions
64 lines (54 loc) · 2.27 KB
/
DatabaseHelper.cs
File metadata and controls
64 lines (54 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
using Azure.Core;
using Azure.Identity;
using Microsoft.Data.SqlClient;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
namespace dtos_cohort_manager_e2e_tests.Helpers;
/// <summary>
/// Helpers for running SQL against the database from the end-to-end tests.
/// </summary>
public static class DatabaseHelper
{
    // Whitelist of table names that GetRecordCountAsync may query. A table name cannot be bound
    // as a SQL parameter, so it has to be embedded in the query text; restricting it to this
    // fixed set is what makes that safe. Lookups are ordinal and case-insensitive so any casing
    // is accepted without culture-sensitive upper-casing (e.g. under tr-TR, "i".ToUpper() is "İ").
    private static readonly HashSet<string> AllowedTables = new(StringComparer.OrdinalIgnoreCase)
    {
        "PARTICIPANT_MANAGEMENT",
        "PARTICIPANT_DEMOGRAPHIC",
        "BS_COHORT_DISTRIBUTION",
    };

    /// <summary>
    /// Executes a non-query SQL statement on a connection opened from
    /// <paramref name="sqlAuthConnection"/>; the connection and command are disposed afterwards.
    /// </summary>
    /// <param name="sqlAuthConnection">Source of the open connection.</param>
    /// <param name="query">SQL text to execute. Pass values through <paramref name="parameters"/>, never by concatenation.</param>
    /// <param name="parameters">Parameters bound to the command. A <c>null</c> array is treated as "no parameters".</param>
    /// <returns>The value reported by <c>SqlCommand.ExecuteNonQueryAsync</c> (rows affected).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sqlAuthConnection"/> or <paramref name="query"/> is <c>null</c>.</exception>
    public static async Task<int> ExecuteNonQueryAsync(
        SqlConnectionWithAuthentication sqlAuthConnection,
        string query,
        params SqlParameter[] parameters)
    {
        ArgumentNullException.ThrowIfNull(sqlAuthConnection);
        ArgumentNullException.ThrowIfNull(query);

        await using var connection = await sqlAuthConnection.GetOpenConnectionAsync();
        await using var command = new SqlCommand(query, connection);

        // An explicit `null` for a params array would make AddRange throw; treat it as empty.
        if (parameters is { Length: > 0 })
        {
            command.Parameters.AddRange(parameters);
        }

        return await command.ExecuteNonQueryAsync();
    }

    /// <summary>
    /// Returns <c>COUNT(*)</c> for one of the whitelisted tables.
    /// </summary>
    /// <param name="sqlConnectionWithAuthentication">Source of the open connection.</param>
    /// <param name="tableName">Table to count; matched case-insensitively against the whitelist.</param>
    /// <returns>The number of rows in the table.</returns>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="tableName"/> is not whitelisted, or the existence check does not find it.
    /// </exception>
    public static async Task<int> GetRecordCountAsync(SqlConnectionWithAuthentication sqlConnectionWithAuthentication, string tableName)
    {
        ArgumentNullException.ThrowIfNull(sqlConnectionWithAuthentication);
        ArgumentNullException.ThrowIfNull(tableName);

        // Resolve to the canonical whitelisted spelling; only that string is ever placed in SQL text.
        if (!AllowedTables.TryGetValue(tableName, out var canonicalName))
        {
            throw new ArgumentException($"Table '{tableName}' is not in the list of allowed tables.", nameof(tableName));
        }

        await using var connection = await sqlConnectionWithAuthentication.GetOpenConnectionAsync();

        if (!await TableExistsAsync(connection, canonicalName))
        {
            throw new ArgumentException($"Table '{tableName}' does not exist in the database.", nameof(tableName));
        }

        // Identifiers cannot be parameterized. canonicalName comes from AllowedTables; the
        // brackets are defence in depth.
        var query = $"SELECT COUNT(*) FROM [{canonicalName}]";
        await using var command = new SqlCommand(query, connection);
        return (int)(await command.ExecuteScalarAsync());
    }

    /// <summary>
    /// Returns <c>true</c> when INFORMATION_SCHEMA.TABLES lists at least one object named <paramref name="tableName"/>.
    /// </summary>
    private static async Task<bool> TableExistsAsync(SqlConnection connection, string tableName)
    {
        // NOTE(review): this is not schema-qualified and INFORMATION_SCHEMA.TABLES also lists views,
        // so a same-named object in any schema satisfies the check — confirm that is acceptable.
        const string query = "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = @TableName";
        await using var command = new SqlCommand(query, connection);

        // Declared explicitly as nvarchar(128) (the size of sysname) instead of letting AddWithValue infer it.
        command.Parameters.Add("@TableName", System.Data.SqlDbType.NVarChar, 128).Value = tableName;

        return (int)(await command.ExecuteScalarAsync()) > 0;
    }
}