Skip to content

Commit 322b672

Browse files
committed
- Securely clean up unmanaged memory buffers before deallocation.
- Add `EnsureInitialized` and enhanced connection status validation methods to `SshSession` for operational consistency. - Introduce host key hash retrieval with `HostKeyHashType` enum for SSH connections. - Re-enable and include host key retrieval test in `NullOpsDevs.LibSsh.Test`. - Securely wipe credentials after use in `SshPasswordCredential`.
1 parent 8e2ec85 commit 322b672

5 files changed

Lines changed: 75 additions & 3 deletions

File tree

NullOpsDevs.LibSsh.Test/Program.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ private static async Task RunTestCategory(string categoryName, Func<Task> testFu
8282

8383
private static async Task RunAuthenticationTests()
8484
{
85+
await RunTest("Host key retrival", TestHostKeyRetrival);
8586
await RunTest("Password Authentication", TestPasswordAuth);
8687
await RunTest("Public Key Authentication (no passphrase)", TestPublicKeyAuth);
8788
await RunTest("Public Key Authentication (with passphrase)", TestPublicKeyAuthWithPassphrase);
8889
await RunTest("Public Key from Memory", TestPublicKeyFromMemory);
89-
// await RunTest("SSH Agent Authentication", TestSshAgentAuth);
90+
await RunTest("SSH Agent Authentication", TestSshAgentAuth);
9091
}
9192

9293
private static Task<bool> TestPasswordAuth()
@@ -160,6 +161,13 @@ private static Task<bool> TestSshAgentAuth()
160161
return Task.FromResult(true);
161162
}
162163
}
164+
165+
private static Task<bool> TestHostKeyRetrival()
166+
{
167+
using var session = TestHelper.CreateAndConnect();
168+
var hostKey = session.GetHostKeyHash(HostKeyHashType.SHA256);
169+
return Task.FromResult(hostKey.Length == 32);
170+
}
163171

164172
#endregion
165173

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System.Diagnostics.CodeAnalysis;
2+
using JetBrains.Annotations;
3+
using NullOpsDevs.LibSsh.Generated;
4+
5+
namespace NullOpsDevs.LibSsh.Core;
6+
7+
[PublicAPI]
8+
[SuppressMessage("ReSharper", "InconsistentNaming")]
9+
public enum HostKeyHashType
10+
{
11+
MD5 = LibSshNative.LIBSSH2_HOSTKEY_HASH_MD5,
12+
13+
SHA1 = LibSshNative.LIBSSH2_HOSTKEY_HASH_SHA1,
14+
15+
SHA256 = LibSshNative.LIBSSH2_HOSTKEY_HASH_SHA256
16+
}

NullOpsDevs.LibSsh/Core/SshSession.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ private void EnsureInStatus(SshConnectionStatus status)
5757
throw new SshException($"SshConnection must be in status '{status:G}' to perform that operation.", SshError.DevWrongUse);
5858
}
5959

60+
private void EnsureInStatuses(params SshConnectionStatus[] statuses)
61+
{
62+
foreach (var status in statuses)
63+
{
64+
if (ConnectionStatus == status)
65+
return;
66+
}
67+
68+
throw new SshException($"SshConnection must be in one of the statuses '{string.Join(", ", statuses.Select(s => $"{s:G}"))}' to perform that operation.", SshError.DevWrongUse);
69+
}
70+
6071
/// <summary>
6172
/// Connects to an SSH server at the specified host and port.
6273
/// </summary>
@@ -110,6 +121,26 @@ public unsafe void Connect(string host, int port)
110121
}
111122
}
112123

124+
public unsafe byte[] GetHostKeyHash(HostKeyHashType keyHashType)
125+
{
126+
EnsureInitialized();
127+
EnsureInStatuses(SshConnectionStatus.Connected, SshConnectionStatus.LoggedIn);
128+
129+
var keySize = keyHashType switch
130+
{
131+
HostKeyHashType.MD5 => 16,
132+
HostKeyHashType.SHA1 => 20,
133+
HostKeyHashType.SHA256 => 32,
134+
_ => throw new ArgumentOutOfRangeException(nameof(keyHashType), keyHashType, null)
135+
};
136+
137+
var hash = libssh2_hostkey_hash(session, (int) keyHashType);
138+
var keyHash = new byte[keySize];
139+
Marshal.Copy(new IntPtr(hash), keyHash, 0, keySize);
140+
141+
return keyHash;
142+
}
143+
113144
/// <summary>
114145
/// Asynchronously connects to an SSH server at the specified host and port.
115146
/// </summary>
@@ -140,6 +171,9 @@ public Task ConnectAsync(string host, int port, CancellationToken cancellationTo
140171
/// </remarks>
141172
public unsafe SshCommandResult ExecuteCommand(string command, CommandExecutionOptions? options = null, CancellationToken cancellationToken = default)
142173
{
174+
EnsureInitialized();
175+
EnsureInStatus(SshConnectionStatus.LoggedIn);
176+
143177
options ??= CommandExecutionOptions.Default;
144178

145179
LibSsh2.Log($"Opening channel for command execution: '{command}'");
@@ -287,6 +321,9 @@ public Task<SshCommandResult> ExecuteCommandAsync(string command, CommandExecuti
287321
/// </remarks>
288322
public unsafe bool ReadFile(string path, Stream destination, int bufferSize = 32768, CancellationToken cancellationToken = default)
289323
{
324+
EnsureInitialized();
325+
EnsureInStatus(SshConnectionStatus.LoggedIn);
326+
290327
LibSsh2.Log($"Starting SCP download of file: '{path}'");
291328
using var remotePathBuffer = NativeBuffer.Allocate(path);
292329
using var statBuffer = NativeBuffer.Allocate(512);
@@ -363,6 +400,9 @@ public Task<bool> ReadFileAsync(string path, Stream destination, int bufferSize
363400
/// </remarks>
364401
public unsafe bool WriteFile(string path, Stream source, int mode = 420, int bufferSize = 32768, CancellationToken cancellationToken = default)
365402
{
403+
EnsureInitialized();
404+
EnsureInStatus(SshConnectionStatus.LoggedIn);
405+
366406
LibSsh2.Log($"Starting SCP upload to file: '{path}'");
367407

368408
if (!source.CanRead)

NullOpsDevs.LibSsh/Credentials/SshPasswordCredential.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Runtime.InteropServices;
1+
using System.Runtime.CompilerServices;
2+
using System.Runtime.InteropServices;
23
using NullOpsDevs.LibSsh.Generated;
34

45
namespace NullOpsDevs.LibSsh.Credentials;
@@ -24,6 +25,9 @@ public override unsafe bool Authenticate(_LIBSSH2_SESSION* session)
2425
(sbyte*) usernameBuffer, (uint)username.Length,
2526
(sbyte*) passwordBuffer, (uint)password.Length, null);
2627

28+
Unsafe.InitBlockUnaligned(usernameBuffer.ToPointer(), 0, (uint)username.Length);
29+
Unsafe.InitBlockUnaligned(passwordBuffer.ToPointer(), 0, (uint)username.Length);
30+
2731
Marshal.FreeHGlobal(usernameBuffer);
2832
Marshal.FreeHGlobal(passwordBuffer);
2933

NullOpsDevs.LibSsh/Interop/NativeBuffer.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ internal readonly ref struct NativeBuffer(IntPtr pointer, int length) : IDisposa
4545
public unsafe T* AsPointer<T>() where T : unmanaged => (T*)Pointer.ToPointer();
4646

4747
/// <inheritdoc />
48-
public void Dispose() => Marshal.FreeHGlobal(Pointer);
48+
public unsafe void Dispose()
49+
{
50+
Unsafe.InitBlockUnaligned(AsPointer(), 0, (uint)Length);
51+
Marshal.FreeHGlobal(Pointer);
52+
}
4953

5054
/// <summary>
5155
/// Allocates a native buffer of the specified length.

0 commit comments

Comments
 (0)