Skip to content

Commit 57e73c9

Browse files
committed
⚡️: refactor model provider
1 parent bd785c6 commit 57e73c9

11 files changed

Lines changed: 1070 additions & 215 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ google-services.json
4141
/.kotlin/
4242
/app/src/commonMain/kotlin/xyz/junerver/composehooks/secret.kt
4343
docs/blog/**
44+
/.claude/
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package xyz.junerver.compose.ai.usechat
2+
3+
import kotlinx.serialization.SerialName
4+
import kotlinx.serialization.Serializable
5+
6+
/*
7+
Description: Anthropic API request/response models
8+
Author: Junerver
9+
Date: 2024
10+
Email: junerver@gmail.com
11+
Version: v2.0
12+
*/
13+
14+
// region Request Models
15+
16+
/**
17+
* Request body for Anthropic messages API.
18+
*
19+
* @see <a href="https://docs.anthropic.com/en/api/messages">Anthropic Messages API</a>
20+
*/
21+
@Serializable
22+
internal data class AnthropicRequest(
23+
val model: String,
24+
val messages: List<AnthropicMessage>,
25+
@SerialName("max_tokens")
26+
val maxTokens: Int = 4096,
27+
val stream: Boolean = true,
28+
val system: String? = null,
29+
val temperature: Float? = null,
30+
)
31+
32+
/**
33+
* Message format for Anthropic API.
34+
*/
35+
@Serializable
36+
internal data class AnthropicMessage(
37+
val role: String,
38+
val content: String,
39+
)
40+
41+
// endregion
42+
43+
// region Response Models (Non-streaming)
44+
45+
/**
46+
* Response body for Anthropic messages API.
47+
*/
48+
@Serializable
49+
internal data class AnthropicResponse(
50+
val id: String,
51+
val type: String,
52+
val role: String,
53+
val content: List<AnthropicContentBlock>,
54+
val model: String,
55+
@SerialName("stop_reason")
56+
val stopReason: String? = null,
57+
@SerialName("stop_sequence")
58+
val stopSequence: String? = null,
59+
val usage: AnthropicUsage,
60+
)
61+
62+
/**
63+
* Content block in Anthropic response.
64+
*/
65+
@Serializable
66+
internal data class AnthropicContentBlock(
67+
val type: String,
68+
val text: String? = null,
69+
)
70+
71+
/**
72+
* Token usage in Anthropic response.
73+
*/
74+
@Serializable
75+
internal data class AnthropicUsage(
76+
@SerialName("input_tokens")
77+
val inputTokens: Int = 0,
78+
@SerialName("output_tokens")
79+
val outputTokens: Int = 0,
80+
)
81+
82+
// endregion
83+
84+
// region Streaming Response Models
85+
86+
/**
87+
* Streaming event from Anthropic API.
88+
*
89+
* Anthropic uses Server-Sent Events with different event types:
90+
* - `message_start`: Initial message metadata
91+
* - `content_block_start`: Start of a content block
92+
* - `content_block_delta`: Text delta in a content block
93+
* - `content_block_stop`: End of a content block
94+
* - `message_delta`: Final message metadata (stop_reason, usage)
95+
* - `message_stop`: End of message
96+
*
97+
* @see <a href="https://docs.anthropic.com/en/api/messages-streaming">Anthropic Streaming</a>
98+
*/
99+
@Serializable
100+
internal data class AnthropicStreamEvent(
101+
val type: String,
102+
val index: Int? = null,
103+
val delta: AnthropicDelta? = null,
104+
val usage: AnthropicStreamUsage? = null,
105+
@SerialName("content_block")
106+
val contentBlock: AnthropicContentBlock? = null,
107+
val message: AnthropicStreamMessage? = null,
108+
)
109+
110+
/**
111+
* Delta content in streaming response.
112+
*/
113+
@Serializable
114+
internal data class AnthropicDelta(
115+
val type: String? = null,
116+
val text: String? = null,
117+
@SerialName("stop_reason")
118+
val stopReason: String? = null,
119+
@SerialName("stop_sequence")
120+
val stopSequence: String? = null,
121+
)
122+
123+
/**
124+
* Usage in streaming response.
125+
*/
126+
@Serializable
127+
internal data class AnthropicStreamUsage(
128+
@SerialName("input_tokens")
129+
val inputTokens: Int? = null,
130+
@SerialName("output_tokens")
131+
val outputTokens: Int? = null,
132+
)
133+
134+
/**
135+
* Message metadata in message_start event.
136+
*/
137+
@Serializable
138+
internal data class AnthropicStreamMessage(
139+
val id: String? = null,
140+
val type: String? = null,
141+
val role: String? = null,
142+
val model: String? = null,
143+
val usage: AnthropicStreamUsage? = null,
144+
)
145+
146+
// endregion
147+
148+
// region Error Response
149+
150+
/**
151+
* Error response from Anthropic API.
152+
*/
153+
@Serializable
154+
internal data class AnthropicErrorResponse(
155+
val type: String,
156+
val error: AnthropicError,
157+
)
158+
159+
@Serializable
160+
internal data class AnthropicError(
161+
val type: String,
162+
val message: String,
163+
)
164+
165+
// endregion

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/ChatClient.kt

Lines changed: 26 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ import kotlinx.coroutines.flow.flow
2727
import kotlinx.serialization.json.Json
2828

2929
/*
30-
Description: Ktor-based HTTP client for OpenAI chat completions
30+
Description: Ktor-based HTTP client for multi-provider chat completions
3131
Author: Junerver
3232
Date: 2024
3333
Email: junerver@gmail.com
34-
Version: v1.0
34+
Version: v2.0
3535
*/
3636

3737
/**
3838
* Represents a streaming event from the chat API.
3939
*/
40-
internal sealed class StreamEvent {
40+
sealed class StreamEvent {
4141
data class Delta(
4242
val content: String,
4343
val role: String? = null,
@@ -51,10 +51,11 @@ internal sealed class StreamEvent {
5151
}
5252

5353
/**
54-
* HTTP client for interacting with OpenAI-compatible chat APIs.
54+
* HTTP client for interacting with chat APIs.
55+
*
56+
* Supports multiple providers through the [ChatProvider] abstraction.
5557
*/
5658
internal class ChatClient(private val options: ChatOptions) {
57-
5859
private val json = Json {
5960
ignoreUnknownKeys = true
6061
isLenient = true
@@ -84,13 +85,7 @@ internal class ChatClient(private val options: ChatOptions) {
8485
*/
8586
suspend fun streamChat(messages: List<Message>): Flow<StreamEvent> = flow {
8687
try {
87-
val requestBody = ChatCompletionRequest(
88-
model = options.model,
89-
messages = messages.toRequestMessages(),
90-
stream = true,
91-
temperature = options.temperature,
92-
maxTokens = options.maxTokens,
93-
)
88+
val requestBody = options.buildRequestBody(messages, stream = true)
9489

9590
httpClient.preparePost(options.buildEndpoint()) {
9691
// SSE streams need longer/no timeout
@@ -99,7 +94,10 @@ internal class ChatClient(private val options: ChatOptions) {
9994
socketTimeoutMillis = Long.MAX_VALUE
10095
}
10196
contentType(ContentType.Application.Json.withCharset(Charsets.UTF_8))
102-
header(HttpHeaders.Authorization, "Bearer ${options.apiKey}")
97+
// Use provider-specific auth headers
98+
options.buildAuthHeaders().forEach { (key, value) ->
99+
header(key, value)
100+
}
103101
header(HttpHeaders.Accept, "text/event-stream")
104102
header(HttpHeaders.CacheControl, "no-cache")
105103
header(HttpHeaders.Connection, "keep-alive")
@@ -120,8 +118,8 @@ internal class ChatClient(private val options: ChatOptions) {
120118
errorMessage = errorResponse.error.message,
121119
errorType = errorResponse.error.type,
122120
errorCode = errorResponse.error.code,
123-
)
124-
)
121+
),
122+
),
125123
)
126124
} catch (e: Exception) {
127125
emit(StreamEvent.Error(Exception("HTTP ${response.status.value}: $errorBody")))
@@ -133,37 +131,11 @@ internal class ChatClient(private val options: ChatOptions) {
133131
while (!channel.isClosedForRead) {
134132
val line = channel.readUTF8Line() ?: continue
135133

136-
if (line.isBlank()) continue
137-
138-
if (!line.startsWith("data: ")) continue
139-
140-
val data = line.removePrefix("data: ").trim()
141-
142-
if (data == "[DONE]") {
143-
emit(StreamEvent.Done)
144-
break
145-
}
146-
147-
try {
148-
val chunk = json.decodeFromString<ChatCompletionChunk>(data)
149-
val choice = chunk.choices?.firstOrNull()
150-
val delta = choice?.delta
151-
val content = delta?.content ?: ""
152-
val role = delta?.role
153-
val finishReason = choice?.finishReason
154-
155-
if (content.isNotEmpty() || role != null || finishReason != null) {
156-
emit(
157-
StreamEvent.Delta(
158-
content = content,
159-
role = role,
160-
finishReason = finishReason,
161-
usage = chunk.usage,
162-
)
163-
)
164-
}
165-
} catch (e: Exception) {
166-
// Skip malformed JSON chunks
134+
// Use provider-specific stream parsing
135+
val event = options.provider.parseStreamLine(line)
136+
if (event != null) {
137+
emit(event)
138+
if (event is StreamEvent.Done) break
167139
}
168140
}
169141
}
@@ -179,17 +151,14 @@ internal class ChatClient(private val options: ChatOptions) {
179151
* @return The complete assistant message
180152
*/
181153
suspend fun chat(messages: List<Message>): Message {
182-
val requestBody = ChatCompletionRequest(
183-
model = options.model,
184-
messages = messages.toRequestMessages(),
185-
stream = false,
186-
temperature = options.temperature,
187-
maxTokens = options.maxTokens,
188-
)
154+
val requestBody = options.buildRequestBody(messages, stream = false)
189155

190156
val response: HttpResponse = httpClient.post(options.buildEndpoint()) {
191157
contentType(ContentType.Application.Json.withCharset(Charsets.UTF_8))
192-
header(HttpHeaders.Authorization, "Bearer ${options.apiKey}")
158+
// Use provider-specific auth headers
159+
options.buildAuthHeaders().forEach { (key, value) ->
160+
header(key, value)
161+
}
193162
options.headers.forEach { (key, value) ->
194163
header(key, value)
195164
}
@@ -215,11 +184,9 @@ internal class ChatClient(private val options: ChatOptions) {
215184
}
216185

217186
val responseBody = response.bodyAsChannel().readUTF8Line() ?: throw Exception("Empty response")
218-
val completionResponse = json.decodeFromString<ChatCompletionResponse>(responseBody)
219-
val choice = completionResponse.choices.firstOrNull()
220-
?: throw Exception("No choices in response")
221-
222-
return Message.assistant(content = choice.message.content ?: "")
187+
// Use provider-specific response parsing
188+
val result = options.provider.parseResponse(responseBody)
189+
return result.message
223190
}
224191

225192
/**

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/ChatOptions.kt

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import xyz.junerver.compose.hooks.Options
1111
Author: Junerver
1212
Date: 2024
1313
Email: junerver@gmail.com
14-
Version: v1.0
14+
Version: v2.0
1515
*/
1616

1717
/**
@@ -25,9 +25,8 @@ typealias OnStreamCallback = (delta: String) -> Unit
2525
/**
2626
* Configuration options for the useChat hook.
2727
*
28-
* @property baseUrl The base URL of the OpenAI-compatible API endpoint
29-
* @property apiKey The API key for authentication
30-
* @property model The model to use for chat completions
28+
* @property provider The chat provider to use (includes apiKey, baseUrl, model)
29+
* @property model Override the provider's default model (null = use provider default)
3130
* @property systemPrompt Optional system prompt to prepend to conversations
3231
* @property initialMessages Initial messages to populate the chat
3332
* @property temperature Sampling temperature (0-2), higher values make output more random
@@ -42,9 +41,8 @@ typealias OnStreamCallback = (delta: String) -> Unit
4241
*/
4342
@Stable
4443
data class ChatOptions internal constructor(
45-
var baseUrl: String = "https://api.openai.com/v1",
46-
var apiKey: String = "",
47-
var model: String = "gpt-3.5-turbo",
44+
var provider: ChatProvider = Providers.OpenAI(apiKey = ""),
45+
var model: String? = null,
4846
var systemPrompt: String? = null,
4947
var initialMessages: List<Message> = emptyList(),
5048
var temperature: Float? = null,
@@ -60,11 +58,34 @@ data class ChatOptions internal constructor(
6058
) {
6159
companion object : Options<ChatOptions>(::ChatOptions)
6260

61+
/**
62+
* The effective model (override or provider default).
63+
*/
64+
val effectiveModel: String
65+
get() = model ?: provider.defaultModel
66+
6367
/**
6468
* Builds the full API endpoint URL for chat completions.
6569
*/
6670
internal fun buildEndpoint(): String {
67-
val base = baseUrl.trimEnd('/')
68-
return "$base/chat/completions"
71+
val base = provider.baseUrl.trimEnd('/')
72+
return "$base${provider.chatEndpoint}"
6973
}
74+
75+
/**
76+
* Builds authentication headers using the provider.
77+
*/
78+
internal fun buildAuthHeaders(): Map<String, String> = provider.buildAuthHeaders()
79+
80+
/**
81+
* Builds the request body using the provider.
82+
*/
83+
internal fun buildRequestBody(messages: List<Message>, stream: Boolean): String = provider.buildRequestBody(
84+
messages = messages,
85+
model = effectiveModel,
86+
stream = stream,
87+
temperature = temperature,
88+
maxTokens = maxTokens,
89+
systemPrompt = systemPrompt,
90+
)
7091
}

0 commit comments

Comments
 (0)