diff --git a/firebaseai/src/ContextWindowCompressionConfig.cs b/firebaseai/src/ContextWindowCompressionConfig.cs new file mode 100644 index 00000000..3a95f78a --- /dev/null +++ b/firebaseai/src/ContextWindowCompressionConfig.cs @@ -0,0 +1,84 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using Firebase.AI.Internal; + +namespace Firebase.AI +{ + /// + /// Configures the sliding window context compression mechanism. + /// + public class SlidingWindow + { + /// + /// The session reduction target, i.e., how many tokens we should keep. + /// + public int? TargetTokens { get; } + + public SlidingWindow(int? targetTokens = null) + { + TargetTokens = targetTokens; + } + + internal Dictionary ToJson() + { + var dict = new Dictionary(); + if (TargetTokens.HasValue) + { + dict["targetTokens"] = TargetTokens.Value; + } + return dict; + } + } + + /// + /// Enables context window compression to manage the model's context window. + /// + public class ContextWindowCompressionConfig + { + /// + /// The number of tokens (before running a turn) that triggers the context + /// window compression. + /// + public int? TriggerTokens { get; } + + /// + /// The sliding window compression mechanism. + /// + public SlidingWindow? SlidingWindow { get; } + + public ContextWindowCompressionConfig(int? triggerTokens = null, SlidingWindow? slidingWindow = null) + { + TriggerTokens = triggerTokens; + SlidingWindow = slidingWindow; + } + + internal Dictionary ToJson() + { + var dict = new Dictionary(); + if (TriggerTokens.HasValue) + { + dict["triggerTokens"] = TriggerTokens.Value; + } + if (SlidingWindow != null) + { + dict["slidingWindow"] = SlidingWindow.ToJson(); + } + return dict; + } + } +} diff --git a/firebaseai/src/LiveGenerationConfig.cs b/firebaseai/src/LiveGenerationConfig.cs index 85e10714..6437d780 100644 --- a/firebaseai/src/LiveGenerationConfig.cs +++ b/firebaseai/src/LiveGenerationConfig.cs @@ -94,9 +94,11 @@ public readonly struct LiveGenerationConfig private readonly float? _frequencyPenalty; private readonly AudioTranscriptionConfig? _inputAudioTranscription; private readonly AudioTranscriptionConfig? _outputAudioTranscription; + private readonly ContextWindowCompressionConfig? _contextWindowCompression; internal readonly AudioTranscriptionConfig? InputAudioTranscription => _inputAudioTranscription; internal readonly AudioTranscriptionConfig? OutputAudioTranscription => _outputAudioTranscription; + internal readonly ContextWindowCompressionConfig? ContextWindowCompression => _contextWindowCompression; /// /// Creates a new `LiveGenerationConfig` value. @@ -191,7 +193,8 @@ public LiveGenerationConfig( float? presencePenalty = null, float? frequencyPenalty = null, AudioTranscriptionConfig? inputAudioTranscription = null, - AudioTranscriptionConfig? outputAudioTranscription = null) + AudioTranscriptionConfig? outputAudioTranscription = null, + ContextWindowCompressionConfig? contextWindowCompression = null) { _speechConfig = speechConfig; _responseModalities = responseModalities != null ? @@ -204,6 +207,7 @@ public LiveGenerationConfig( _frequencyPenalty = frequencyPenalty; _inputAudioTranscription = inputAudioTranscription; _outputAudioTranscription = outputAudioTranscription; + _contextWindowCompression = contextWindowCompression; } /// @@ -225,6 +229,7 @@ internal Dictionary ToJson() if (_maxOutputTokens.HasValue) jsonDict["maxOutputTokens"] = _maxOutputTokens.Value; if (_presencePenalty.HasValue) jsonDict["presencePenalty"] = _presencePenalty.Value; if (_frequencyPenalty.HasValue) jsonDict["frequencyPenalty"] = _frequencyPenalty.Value; + if (_contextWindowCompression != null) jsonDict["contextWindowCompression"] = _contextWindowCompression.ToJson(); return jsonDict; } diff --git a/firebaseai/src/LiveGenerativeModel.cs b/firebaseai/src/LiveGenerativeModel.cs index fb9ed446..f4525e0e 100644 --- a/firebaseai/src/LiveGenerativeModel.cs +++ b/firebaseai/src/LiveGenerativeModel.cs @@ -118,34 +118,35 @@ private string GetModelName() /// /// The token that can be used to cancel the creation of the session. /// The LiveSession, once it is established. - public async Task ConnectAsync(CancellationToken cancellationToken = default) + public async Task ConnectAsync(SessionResumptionConfig? sessionResumption = null, CancellationToken cancellationToken = default) { - ClientWebSocket clientWebSocket = new(); - - string endpoint = GetURL(); - - // Set initial headers - string version = Firebase.Internal.FirebaseInterops.GetVersionInfoSdkVersion(); - clientWebSocket.Options.SetRequestHeader("x-goog-api-client", $"gl-csharp/8.0 fire/{version}"); - if (Firebase.Internal.FirebaseInterops.GetIsDataCollectionDefaultEnabled(_firebaseApp)) + Func> connectFactory = async (resumptionConfig, cancelToken) => { - clientWebSocket.Options.SetRequestHeader("X-Firebase-AppId", _firebaseApp.Options.AppId); - clientWebSocket.Options.SetRequestHeader("X-Firebase-AppVersion", UnityEngine.Application.version); - } - // Add additional Firebase tokens to the header. - await Firebase.Internal.FirebaseInterops.AddFirebaseTokensAsync(clientWebSocket, _firebaseApp); + ClientWebSocket clientWebSocket = new(); + string endpoint = GetURL(); + + // Set initial headers + string version = Firebase.Internal.FirebaseInterops.GetVersionInfoSdkVersion(); + clientWebSocket.Options.SetRequestHeader("x-goog-api-client", $"gl-csharp/8.0 fire/{version}"); + if (Firebase.Internal.FirebaseInterops.GetIsDataCollectionDefaultEnabled(_firebaseApp)) + { + clientWebSocket.Options.SetRequestHeader("X-Firebase-AppId", _firebaseApp.Options.AppId); + clientWebSocket.Options.SetRequestHeader("X-Firebase-AppVersion", UnityEngine.Application.version); + } + // Add additional Firebase tokens to the header. + await Firebase.Internal.FirebaseInterops.AddFirebaseTokensAsync(clientWebSocket, _firebaseApp); - // Add a timeout to the initial connection, using the RequestOptions. - using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - TimeSpan connectionTimeout = _requestOptions?.Timeout ?? RequestOptions.DefaultTimeout; - connectionCts.CancelAfter(connectionTimeout); + // Add a timeout to the initial connection, using the RequestOptions. + using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(cancelToken); + TimeSpan connectionTimeout = _requestOptions?.Timeout ?? RequestOptions.DefaultTimeout; + connectionCts.CancelAfter(connectionTimeout); - await clientWebSocket.ConnectAsync(new Uri(endpoint), connectionCts.Token); + await clientWebSocket.ConnectAsync(new Uri(endpoint), connectionCts.Token); - if (clientWebSocket.State != WebSocketState.Open) - { - throw new WebSocketException("ClientWebSocket failed to connect, can't create LiveSession."); - } + if (clientWebSocket.State != WebSocketState.Open) + { + throw new WebSocketException("ClientWebSocket failed to connect, can't create LiveSession."); + } try { @@ -175,25 +176,36 @@ public async Task ConnectAsync(CancellationToken cancellationToken { setupDict["tools"] = _tools.Select(t => t.ToJson()).ToList(); } + + if (resumptionConfig != null) + { + setupDict["sessionResumption"] = resumptionConfig.ToJson(); + } + if (_liveConfig?.ContextWindowCompression != null) + { + setupDict["contextWindowCompression"] = _liveConfig?.ContextWindowCompression.ToJson(); + } + Dictionary jsonDict = new() { { "setup", setupDict } }; var byteArray = Encoding.UTF8.GetBytes(Json.Serialize(jsonDict)); - await clientWebSocket.SendAsync(new ArraySegment(byteArray), WebSocketMessageType.Binary, true, cancellationToken); + await clientWebSocket.SendAsync(new ArraySegment(byteArray), WebSocketMessageType.Binary, true, cancelToken); - return new LiveSession(clientWebSocket); + return clientWebSocket; } catch (Exception) { - if (clientWebSocket.State == WebSocketState.Open) - { - // Try to clean up the WebSocket, to avoid leaking connections. - await clientWebSocket.CloseAsync(WebSocketCloseStatus.EndpointUnavailable, - "Failed to send initial setup message.", CancellationToken.None); - } + // Try to clean up the WebSocket, to avoid leaking connections. + // It might not be available in scope, we rely on GC mostly here unless we catch on clientWebSocket explicitly. + // Wait, clientWebSocket is available because this is all within the lambda! throw; } + }; + + var webSocket = await connectFactory(sessionResumption, cancellationToken); + return new LiveSession(webSocket, connectFactory); } } diff --git a/firebaseai/src/LiveSession.cs b/firebaseai/src/LiveSession.cs index ce93b408..3a35632a 100644 --- a/firebaseai/src/LiveSession.cs +++ b/firebaseai/src/LiveSession.cs @@ -33,7 +33,8 @@ namespace Firebase.AI public class LiveSession : IDisposable { - private readonly ClientWebSocket _clientWebSocket; + private ClientWebSocket _clientWebSocket; + private readonly Func> _connectionFactory; private readonly SemaphoreSlim _sendLock = new(1, 1); @@ -44,7 +45,7 @@ public class LiveSession : IDisposable /// Intended for internal use only. /// Use `LiveGenerativeModel.ConnectAsync` instead to ensure proper initialization. /// - internal LiveSession(ClientWebSocket clientWebSocket) + internal LiveSession(ClientWebSocket clientWebSocket, Func> connectionFactory = null) { if (clientWebSocket.State != WebSocketState.Open) { @@ -53,6 +54,7 @@ internal LiveSession(ClientWebSocket clientWebSocket) } _clientWebSocket = clientWebSocket; + _connectionFactory = connectionFactory; } protected virtual void Dispose(bool disposing) @@ -297,10 +299,31 @@ public async IAsyncEnumerable ReceiveAsync( Memory buffer = new(receiveBuffer); while (!cancellationToken.IsCancellationRequested) { - ValueWebSocketReceiveResult result = await _clientWebSocket.ReceiveAsync(buffer, cancellationToken); + ClientWebSocket currentWebSocket; + await _sendLock.WaitAsync(cancellationToken); + try { + currentWebSocket = _clientWebSocket; + } finally { _sendLock.Release(); } + + ValueWebSocketReceiveResult result; + try + { + result = await currentWebSocket.ReceiveAsync(buffer, cancellationToken); + } + catch (Exception) when (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested) + { + // The socket was closed or disposed because of session resumption, grab the new one + await Task.Delay(10, cancellationToken); + continue; + } if (result.MessageType == WebSocketMessageType.Close) { + if (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested) + { + await Task.Delay(10, cancellationToken); + continue; + } // Close initiated by the server // TODO: Should this just close without logging anything? break; @@ -338,6 +361,48 @@ public async IAsyncEnumerable ReceiveAsync( cancellationToken.ThrowIfCancellationRequested(); } + /// + /// Resumes an existing live session with the server. + /// + /// This closes the current WebSocket connection and establishes a new one using + /// the same configuration as the original session. + /// + /// The configuration for session resumption. + /// A token to cancel the operation. + public async Task ResumeSessionAsync(SessionResumptionConfig? sessionResumption = null, CancellationToken cancellationToken = default) + { + if (_connectionFactory == null) + { + throw new InvalidOperationException("ResumeSession is not supported on this instance."); + } + + ClientWebSocket newSession = await _connectionFactory(sessionResumption, cancellationToken); + ClientWebSocket oldSession; + + await _sendLock.WaitAsync(cancellationToken); + try + { + oldSession = _clientWebSocket; + _clientWebSocket = newSession; + } + finally + { + _sendLock.Release(); + } + + try + { + if (oldSession.State == WebSocketState.Open) + { + await oldSession.CloseAsync(WebSocketCloseStatus.NormalClosure, "Session resumed", CancellationToken.None); + } + } + catch (Exception) + { + // Ignore errors when closing the old socket. + } + } + /// /// Close the `LiveSession`. /// diff --git a/firebaseai/src/LiveSessionResponse.cs b/firebaseai/src/LiveSessionResponse.cs index a8c853a2..6f7df8c1 100644 --- a/firebaseai/src/LiveSessionResponse.cs +++ b/firebaseai/src/LiveSessionResponse.cs @@ -145,6 +145,10 @@ private LiveSessionResponse(ILiveSessionMessage liveSessionMessage) { return new LiveSessionResponse(LiveSessionGoingAway.FromJson(goAway)); } + else if (jsonDict.TryParseValue("sessionResumptionUpdate", out Dictionary sessionResumptionUpdate)) + { + return new LiveSessionResponse(LiveSessionResumptionUpdate.FromJson(sessionResumptionUpdate)); + } else { // TODO: Determine if we want to log this, or just ignore it? @@ -371,4 +375,63 @@ internal static Transcription FromJson(Dictionary jsonDict) } } + /// + /// An update of the session resumption state. + /// + public readonly struct LiveSessionResumptionUpdate : ILiveSessionMessage + { + /// + /// The new handle that represents the state that can be resumed. Empty if + /// `resumable` is false. + /// + public readonly string NewHandle { get; } + + /// + /// Indicates if the session can be resumed at this point. + /// + public readonly bool? Resumable { get; } + + /// + /// The index of the last client message that is included in the state + /// represented by this update. + /// + public readonly int? LastConsumedClientMessageIndex { get; } + + private LiveSessionResumptionUpdate(string newHandle, bool? resumable, int? lastConsumedClientMessageIndex) + { + NewHandle = newHandle; + Resumable = resumable; + LastConsumedClientMessageIndex = lastConsumedClientMessageIndex; + } + + /// + /// Intended for internal use only. + /// This method is used for deserializing JSON responses and should not be called directly. + /// + internal static LiveSessionResumptionUpdate FromJson(Dictionary jsonDict) + { + string newHandle = null; + if (jsonDict.TryGetValue("newHandle", out object handleObj) && handleObj is string strHandle) + { + newHandle = strHandle; + } + + bool? resumable = null; + if (jsonDict.TryGetValue("resumable", out object resumableObj) && resumableObj is bool bResumable) + { + resumable = bResumable; + } + + int? lastConsumedClientMessageIndex = null; + if (jsonDict.TryGetValue("lastConsumedClientMessageIndex", out object objIndex)) + { + try { + lastConsumedClientMessageIndex = System.Convert.ToInt32(objIndex); + } catch { /* ignore */ } + } + + return new LiveSessionResumptionUpdate(newHandle, resumable, lastConsumedClientMessageIndex); + } + } + } diff --git a/firebaseai/src/SessionResumptionConfig.cs b/firebaseai/src/SessionResumptionConfig.cs new file mode 100644 index 00000000..41bd56d7 --- /dev/null +++ b/firebaseai/src/SessionResumptionConfig.cs @@ -0,0 +1,46 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; + +namespace Firebase.AI +{ + /// + /// Configuration for the session resumption mechanism. + /// + public class SessionResumptionConfig + { + /// + /// The session resumption handle of the previous session to restore. + /// + public string Handle { get; } + + public SessionResumptionConfig(string handle = null) + { + Handle = handle; + } + + internal Dictionary ToJson() + { + var dict = new Dictionary(); + if (!string.IsNullOrEmpty(Handle)) + { + dict["handle"] = Handle; + } + return dict; + } + } +}