Skip to content

Commit 6d92d49

Browse files
committed
Fix TryOpenInner race cast and move regression tests to UnitTests
1 parent 84a4a78 commit 6d92d49

3 files changed

Lines changed: 144 additions & 135 deletions

File tree

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ private SqlConnection(SqlConnection connection)
266266

267267
internal static bool TryGetSystemColumnEncryptionKeyStoreProvider(string keyStoreName, out SqlColumnEncryptionKeyStoreProvider provider)
268268
{
269-
return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider);
269+
return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider);
270270
}
271271

272272
/// <summary>
@@ -1332,7 +1332,7 @@ public override void ChangeDatabase(string database)
13321332
SqlStatistics statistics = null;
13331333
RepairInnerConnection();
13341334
SqlClientEventSource.Log.TryCorrelationTraceEvent("SqlConnection.ChangeDatabase | API | Correlation | Object Id {0}, Activity Id {1}, Database {2}", ObjectID, ActivityCorrelator.Current, database);
1335-
1335+
13361336
try
13371337
{
13381338
statistics = SqlStatistics.StartTimer(Statistics);
@@ -1408,7 +1408,7 @@ public override void Close()
14081408

14091409
SqlStatistics statistics = null;
14101410
Exception e = null;
1411-
1411+
14121412
try
14131413
{
14141414
statistics = SqlStatistics.StartTimer(Statistics);
@@ -1901,7 +1901,7 @@ internal void Abort(Exception e)
19011901
}
19021902

19031903
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/OpenAsync/*' />
1904-
public override Task OpenAsync(CancellationToken cancellationToken)
1904+
public override Task OpenAsync(CancellationToken cancellationToken)
19051905
=> OpenAsync(SqlConnectionOverrides.None, cancellationToken);
19061906

19071907
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/OpenAsyncWithOverrides/*' />
@@ -2224,7 +2224,18 @@ private bool TryOpenInner(TaskCompletionSource<DbConnectionInternal> retry)
22242224
}
22252225
// does not require GC.KeepAlive(this) because of ReRegisterForFinalize below.
22262226

2227-
var tdsInnerConnection = (SqlConnectionInternal)InnerConnection;
2227+
// Capture InnerConnection once into a local to avoid a TOCTOU race: another thread
2228+
// concurrently calling Open() on the same SqlConnection instance can change
2229+
// _innerConnection to DbConnectionClosedConnecting between the TryOpenConnection()
2230+
// call above and the cast below. Without this local capture the second read of
2231+
// InnerConnection may return DbConnectionClosedConnecting, which is not assignable
2232+
// to SqlConnectionInternal and would produce an opaque InvalidCastException.
2233+
// See GitHub issue #3314.
2234+
var innerConnection = InnerConnection;
2235+
if (innerConnection is not SqlConnectionInternal tdsInnerConnection)
2236+
{
2237+
throw ADP.ConnectionAlreadyOpen(State);
2238+
}
22282239

22292240
Debug.Assert(tdsInnerConnection.Parser != null, "Where's the parser?");
22302241

src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.ConcurrentOpen.cs

Lines changed: 0 additions & 130 deletions
This file was deleted.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Data;
7+
using System.Reflection;
8+
using System.Threading.Tasks;
9+
using Xunit;
10+
11+
namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient
12+
{
13+
/// <summary>
14+
/// Regression tests for GitHub issue #3314.
15+
///
16+
/// Root cause: TryOpenInner() read InnerConnection twice - once for TryOpenConnection() and
17+
/// again for the cast to SqlConnectionInternal. Between those two reads another thread could
18+
/// change _innerConnection to DbConnectionClosedConnecting, which is not assignable to
19+
/// SqlConnectionInternal, causing an opaque InvalidCastException.
20+
///
21+
/// Fix: InnerConnection is now captured into a local variable once; if it is not a
22+
/// SqlConnectionInternal an InvalidOperationException with a descriptive message is thrown
23+
/// instead of an InvalidCastException.
24+
/// </summary>
25+
public class SqlConnectionConcurrentOpenTests
26+
{
27+
private static object GetConnectingSingleton()
28+
{
29+
Type closedConnectingType = typeof(SqlConnection).Assembly
30+
.GetType("Microsoft.Data.ProviderBase.DbConnectionClosedConnecting", throwOnError: true)!;
31+
FieldInfo singletonField = closedConnectingType
32+
.GetField("SingletonInstance", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public)!;
33+
object? singletonInstance = singletonField.GetValue(null);
34+
Assert.NotNull(singletonInstance);
35+
return singletonInstance!;
36+
}
37+
38+
private static void ForceInnerConnection(SqlConnection connection, object innerConnectionValue)
39+
{
40+
FieldInfo innerConnectionField = typeof(SqlConnection)
41+
.GetField("_innerConnection", BindingFlags.Instance | BindingFlags.NonPublic)!;
42+
innerConnectionField.SetValue(connection, innerConnectionValue);
43+
}
44+
45+
[Fact]
46+
public void InnerConnection_DbConnectionClosedConnecting_IsNotAssignableToSqlConnectionInternal()
47+
{
48+
object connectingSingleton = GetConnectingSingleton();
49+
50+
Type sqlConnectionInternalType = typeof(SqlConnection).Assembly
51+
.GetType("Microsoft.Data.SqlClient.Connection.SqlConnectionInternal", throwOnError: true)!;
52+
53+
Assert.False(
54+
sqlConnectionInternalType.IsInstanceOfType(connectingSingleton),
55+
"DbConnectionClosedConnecting must NOT be assignable to SqlConnectionInternal. " +
56+
"If it were, the race condition in #3314 would not manifest.");
57+
}
58+
59+
[Fact]
60+
public void InnerConnection_InConnectingState_ReportsConnectingState()
61+
{
62+
object connectingSingleton = GetConnectingSingleton();
63+
64+
var connection = new SqlConnection("Data Source=localhost");
65+
ForceInnerConnection(connection, connectingSingleton);
66+
67+
Assert.Equal(ConnectionState.Connecting, connection.State);
68+
}
69+
70+
[Fact]
71+
public void Open_WhenAlreadyConnecting_ThrowsInvalidOperation()
72+
{
73+
object connectingSingleton = GetConnectingSingleton();
74+
75+
var connection = new SqlConnection("Data Source=localhost");
76+
ForceInnerConnection(connection, connectingSingleton);
77+
78+
Assert.Throws<InvalidOperationException>(() => connection.Open());
79+
}
80+
81+
[Fact]
82+
public void TryOpenInner_WhenInnerConnectionRacesToConnectingState_ThrowsInvalidOperation_NotInvalidCast()
83+
{
84+
object connectingSingleton = GetConnectingSingleton();
85+
86+
Type openBusyType = typeof(SqlConnection).Assembly
87+
.GetType("Microsoft.Data.ProviderBase.DbConnectionOpenBusy", throwOnError: true)!;
88+
FieldInfo openBusySingleton = openBusyType
89+
.GetField("SingletonInstance", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public)!;
90+
object? openBusyInstance = openBusySingleton.GetValue(null);
91+
Assert.NotNull(openBusyInstance);
92+
93+
var connection = new SqlConnection("Data Source=localhost");
94+
ForceInnerConnection(connection, connectingSingleton);
95+
96+
Type dbConnectionInternalType = typeof(SqlConnection).Assembly
97+
.GetType("Microsoft.Data.ProviderBase.DbConnectionInternal", throwOnError: true)!;
98+
Type tcsType = typeof(TaskCompletionSource<>).MakeGenericType(dbConnectionInternalType);
99+
object completedRetry = Activator.CreateInstance(tcsType)!;
100+
MethodInfo setResultMethod = tcsType.GetMethod("SetResult")!;
101+
setResultMethod.Invoke(completedRetry, new[] { openBusyInstance! });
102+
103+
MethodInfo tryOpenInner = typeof(SqlConnection)
104+
.GetMethod("TryOpenInner", BindingFlags.Instance | BindingFlags.NonPublic)!;
105+
Assert.NotNull(tryOpenInner);
106+
107+
Exception ex = Assert.ThrowsAny<Exception>(() =>
108+
{
109+
try
110+
{
111+
tryOpenInner.Invoke(connection, new[] { completedRetry });
112+
}
113+
catch (TargetInvocationException tie) when (tie.InnerException != null)
114+
{
115+
throw tie.InnerException;
116+
}
117+
});
118+
119+
Assert.True(
120+
ex is InvalidOperationException,
121+
$"Expected InvalidOperationException but got {ex.GetType().Name}: {ex.Message}. " +
122+
"The fix for #3314 must throw InvalidOperationException (not InvalidCastException) " +
123+
"when _innerConnection races to a non-SqlConnectionInternal state inside TryOpenInner.");
124+
125+
Assert.Contains("connection", ex.Message, StringComparison.OrdinalIgnoreCase);
126+
}
127+
}
128+
}

0 commit comments

Comments
 (0)