Skip to content

Commit c2c82ab

Browse files
committed
🩹: streaming response
1 parent 46fb718 commit c2c82ab

2 files changed

Lines changed: 65 additions & 57 deletions

File tree

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/ChatClient.kt

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import io.ktor.client.plugins.logging.SIMPLE
1010
import io.ktor.client.plugins.timeout
1111
import io.ktor.client.request.header
1212
import io.ktor.client.request.post
13+
import io.ktor.client.request.preparePost
1314
import io.ktor.client.request.setBody
1415
import io.ktor.client.statement.HttpResponse
1516
import io.ktor.client.statement.bodyAsChannel
@@ -91,7 +92,7 @@ internal class ChatClient(private val options: ChatOptions) {
9192
maxTokens = options.maxTokens,
9293
)
9394

94-
val response: HttpResponse = httpClient.post(options.buildEndpoint()) {
95+
httpClient.preparePost(options.buildEndpoint()) {
9596
// SSE streams need longer/no timeout
9697
timeout {
9798
requestTimeoutMillis = Long.MAX_VALUE
@@ -106,64 +107,64 @@ internal class ChatClient(private val options: ChatOptions) {
106107
header(key, value)
107108
}
108109
setBody(requestBody)
109-
}
110+
}.execute { response ->
111+
options.onResponse?.invoke(response)
110112

111-
options.onResponse?.invoke(response)
112-
113-
if (!response.status.isSuccess()) {
114-
val errorBody = response.bodyAsChannel().readUTF8Line() ?: "Unknown error"
115-
try {
116-
val errorResponse = json.decodeFromString<OpenAIErrorResponse>(errorBody)
117-
emit(
118-
StreamEvent.Error(
119-
OpenAIException(
120-
errorMessage = errorResponse.error.message,
121-
errorType = errorResponse.error.type,
122-
errorCode = errorResponse.error.code,
113+
if (!response.status.isSuccess()) {
114+
val errorBody = response.bodyAsChannel().readUTF8Line() ?: "Unknown error"
115+
try {
116+
val errorResponse = json.decodeFromString<OpenAIErrorResponse>(errorBody)
117+
emit(
118+
StreamEvent.Error(
119+
OpenAIException(
120+
errorMessage = errorResponse.error.message,
121+
errorType = errorResponse.error.type,
122+
errorCode = errorResponse.error.code,
123+
)
123124
)
124125
)
125-
)
126-
} catch (e: Exception) {
127-
emit(StreamEvent.Error(Exception("HTTP ${response.status.value}: $errorBody")))
126+
} catch (e: Exception) {
127+
emit(StreamEvent.Error(Exception("HTTP ${response.status.value}: $errorBody")))
128+
}
129+
return@execute
128130
}
129-
return@flow
130-
}
131131

132-
val channel = response.bodyAsChannel()
133-
while (!channel.isClosedForRead) {
134-
val line = channel.readUTF8Line() ?: continue
132+
val channel = response.bodyAsChannel()
133+
while (!channel.isClosedForRead) {
134+
val line = channel.readUTF8Line() ?: continue
135135

136-
if (line.isBlank()) continue
136+
if (line.isBlank()) continue
137137

138-
if (!line.startsWith("data: ")) continue
138+
if (!line.startsWith("data: ")) continue
139139

140-
val data = line.removePrefix("data: ").trim()
140+
val data = line.removePrefix("data: ").trim()
141141

142-
if (data == "[DONE]") {
143-
emit(StreamEvent.Done)
144-
break
145-
}
146-
147-
try {
148-
val chunk = json.decodeFromString<ChatCompletionChunk>(data)
149-
val choice = chunk.choices?.firstOrNull()
150-
val delta = choice?.delta
151-
val content = delta?.content ?: ""
152-
val role = delta?.role
153-
val finishReason = choice?.finishReason
142+
if (data == "[DONE]") {
143+
emit(StreamEvent.Done)
144+
break
145+
}
154146

155-
if (content.isNotEmpty() || role != null || finishReason != null) {
156-
emit(
157-
StreamEvent.Delta(
158-
content = content,
159-
role = role,
160-
finishReason = finishReason,
161-
usage = chunk.usage,
147+
try {
148+
val chunk = json.decodeFromString<ChatCompletionChunk>(data)
149+
val choice = chunk.choices?.firstOrNull()
150+
val delta = choice?.delta
151+
val content = delta?.content ?: ""
152+
val role = delta?.role
153+
val finishReason = choice?.finishReason
154+
155+
if (content.isNotEmpty() || role != null || finishReason != null) {
156+
emit(
157+
StreamEvent.Delta(
158+
content = content,
159+
role = role,
160+
finishReason = finishReason,
161+
usage = chunk.usage,
162+
)
162163
)
163-
)
164+
}
165+
} catch (e: Exception) {
166+
// Skip malformed JSON chunks
164167
}
165-
} catch (e: Exception) {
166-
// Skip malformed JSON chunks
167168
}
168169
}
169170
} catch (e: Exception) {

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/useChat.kt

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import kotlinx.collections.immutable.ImmutableList
99
import kotlinx.collections.immutable.toImmutableList
1010
import kotlinx.coroutines.Dispatchers
1111
import kotlinx.coroutines.flow.collect
12-
import kotlinx.coroutines.flow.flowOn
1312
import kotlinx.coroutines.flow.onEach
13+
import kotlinx.coroutines.withContext
1414
import xyz.junerver.compose.hooks.MutableRef
1515
import xyz.junerver.compose.hooks._useGetState
1616
import xyz.junerver.compose.hooks.useCancelableAsync
@@ -180,16 +180,18 @@ fun useChat(optionsOf: ChatOptions.() -> Unit = {}): ChatHolder {
180180
lastUsage = event.usage
181181
}
182182

183-
// Update assistant message with accumulated content
183+
// Update assistant message with accumulated content on Main thread
184184
val updatedMessage = currentAssistantMessageRef.current?.copy(
185185
content = accumulatedContent
186186
)
187187
if (updatedMessage != null) {
188188
currentAssistantMessageRef.current = updatedMessage
189-
val msgs = getMessages().toMutableList()
190-
if (msgs.isNotEmpty()) {
191-
msgs[msgs.lastIndex] = updatedMessage
192-
setMessages(msgs.toImmutableList())
189+
withContext(Dispatchers.Main) {
190+
val msgs = getMessages().toMutableList()
191+
if (msgs.isNotEmpty()) {
192+
msgs[msgs.lastIndex] = updatedMessage
193+
setMessages(msgs.toImmutableList())
194+
}
193195
}
194196
}
195197
}
@@ -206,18 +208,23 @@ fun useChat(optionsOf: ChatOptions.() -> Unit = {}): ChatHolder {
206208
}
207209

208210
is StreamEvent.Error -> {
209-
setError(event.error)
211+
withContext(Dispatchers.Main) {
212+
setError(event.error)
213+
}
210214
optionsRef.current.onError?.invoke(event.error)
211215
}
212216
}
213217
}
214-
?.flowOn(Dispatchers.Main)
215218
?.collect()
216219
} catch (e: Exception) {
217-
setError(e)
220+
withContext(Dispatchers.Main) {
221+
setError(e)
222+
}
218223
optionsRef.current.onError?.invoke(e)
219224
} finally {
220-
setIsLoading(false)
225+
withContext(Dispatchers.Main) {
226+
setIsLoading(false)
227+
}
221228
currentAssistantMessageRef.current = null
222229
}
223230
}

0 commit comments

Comments
 (0)