Skip to content

Commit a3f7b35

Browse files
committed
⚡️ [AI]: Refactor message model to support multimodal and tools
- Replace simple Message class with polymorphic ChatMessage design - Add support for images, files, and tool calls - Implement reasoning content for extended thinking models - Update providers to handle new message format - Inspired by Vercel AI SDK's message structure
1 parent 5bfbf80 commit a3f7b35

8 files changed

Lines changed: 938 additions & 110 deletions

File tree

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/AnthropicModels.kt

Lines changed: 277 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
package xyz.junerver.compose.ai.usechat
22

3+
import kotlinx.serialization.KSerializer
34
import kotlinx.serialization.SerialName
45
import kotlinx.serialization.Serializable
6+
import kotlinx.serialization.descriptors.SerialDescriptor
7+
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
8+
import kotlinx.serialization.encoding.Decoder
9+
import kotlinx.serialization.encoding.Encoder
10+
import kotlinx.serialization.json.JsonEncoder
11+
import kotlinx.serialization.json.JsonObject
12+
import kotlinx.serialization.json.JsonPrimitive
13+
import kotlinx.serialization.json.buildJsonArray
14+
import kotlinx.serialization.json.buildJsonObject
15+
import kotlinx.serialization.json.put
516

617
/*
7-
Description: Anthropic API request/response models
18+
Description: Anthropic API request/response models (multimodal + tool use support)
819
Author: Junerver
920
Date: 2024
1021
Email: junerver@gmail.com
11-
Version: v2.0
22+
Version: v4.0
1223
*/
1324

1425
// region Request Models
@@ -30,14 +41,168 @@ internal data class AnthropicRequest(
3041
)
3142

3243
/**
33-
* Message format for Anthropic API.
44+
* Message format for Anthropic API with multimodal and tool support.
3445
*/
3546
@Serializable
3647
internal data class AnthropicMessage(
3748
val role: String,
38-
val content: String,
49+
@Serializable(with = AnthropicMessageContentSerializer::class)
50+
val content: AnthropicMessageContent,
3951
)
4052

53+
/**
54+
* Content for Anthropic messages - can be text string or array of content blocks.
55+
*/
56+
internal sealed class AnthropicMessageContent {
57+
data class Text(val text: String) : AnthropicMessageContent()
58+
59+
data class Parts(val parts: List<AnthropicContentPart>) : AnthropicMessageContent()
60+
}
61+
62+
/**
63+
* Content block types for Anthropic multimodal messages.
64+
*/
65+
internal sealed class AnthropicContentPart {
66+
data class Text(val text: String) : AnthropicContentPart()
67+
68+
data class Image(val source: AnthropicImageSource) : AnthropicContentPart()
69+
70+
data class Document(val source: AnthropicDocumentSource, val cacheControl: AnthropicCacheControl? = null) : AnthropicContentPart()
71+
72+
data class ToolUse(val id: String, val name: String, val input: JsonObject) : AnthropicContentPart()
73+
74+
data class ToolResult(val toolUseId: String, val content: String, val isError: Boolean = false) : AnthropicContentPart()
75+
76+
data class Thinking(val thinking: String) : AnthropicContentPart()
77+
}
78+
79+
/**
80+
* Image source for Anthropic API.
81+
*/
82+
internal data class AnthropicImageSource(
83+
val type: String, // "base64" or "url"
84+
val mediaType: String = "",
85+
val data: String = "",
86+
val url: String = "",
87+
) {
88+
companion object {
89+
fun fromBase64(base64: String, mediaType: String) = AnthropicImageSource(type = "base64", mediaType = mediaType, data = base64)
90+
91+
fun fromUrl(url: String) = AnthropicImageSource(type = "url", url = url)
92+
}
93+
}
94+
95+
/**
96+
* Document source for Anthropic API (PDF support).
97+
*/
98+
internal data class AnthropicDocumentSource(
99+
val type: String, // "base64"
100+
val mediaType: String, // "application/pdf"
101+
val data: String,
102+
)
103+
104+
/**
105+
* Cache control for Anthropic API.
106+
*/
107+
internal data class AnthropicCacheControl(
108+
val type: String = "ephemeral",
109+
)
110+
111+
/**
112+
* Custom serializer for AnthropicMessageContent.
113+
* Serializes Text as a plain string, Parts as a JSON array of content blocks.
114+
*/
115+
internal object AnthropicMessageContentSerializer : KSerializer<AnthropicMessageContent> {
116+
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("AnthropicMessageContent")
117+
118+
override fun serialize(encoder: Encoder, value: AnthropicMessageContent) {
119+
val jsonEncoder = encoder as JsonEncoder
120+
when (value) {
121+
is AnthropicMessageContent.Text -> jsonEncoder.encodeJsonElement(JsonPrimitive(value.text))
122+
is AnthropicMessageContent.Parts -> {
123+
val jsonArray = buildJsonArray {
124+
value.parts.forEach { part ->
125+
when (part) {
126+
is AnthropicContentPart.Text -> add(
127+
buildJsonObject {
128+
put("type", "text")
129+
put("text", part.text)
130+
},
131+
)
132+
is AnthropicContentPart.Image -> add(
133+
buildJsonObject {
134+
put("type", "image")
135+
put(
136+
"source",
137+
buildJsonObject {
138+
put("type", part.source.type)
139+
if (part.source.type == "base64") {
140+
put("media_type", part.source.mediaType)
141+
put("data", part.source.data)
142+
} else {
143+
put("url", part.source.url)
144+
}
145+
},
146+
)
147+
},
148+
)
149+
is AnthropicContentPart.Document -> add(
150+
buildJsonObject {
151+
put("type", "document")
152+
put(
153+
"source",
154+
buildJsonObject {
155+
put("type", part.source.type)
156+
put("media_type", part.source.mediaType)
157+
put("data", part.source.data)
158+
},
159+
)
160+
part.cacheControl?.let {
161+
put(
162+
"cache_control",
163+
buildJsonObject {
164+
put("type", it.type)
165+
},
166+
)
167+
}
168+
},
169+
)
170+
is AnthropicContentPart.ToolUse -> add(
171+
buildJsonObject {
172+
put("type", "tool_use")
173+
put("id", part.id)
174+
put("name", part.name)
175+
put("input", part.input)
176+
},
177+
)
178+
is AnthropicContentPart.ToolResult -> add(
179+
buildJsonObject {
180+
put("type", "tool_result")
181+
put("tool_use_id", part.toolUseId)
182+
put("content", part.content)
183+
if (part.isError) {
184+
put("is_error", true)
185+
}
186+
},
187+
)
188+
is AnthropicContentPart.Thinking -> add(
189+
buildJsonObject {
190+
put("type", "thinking")
191+
put("thinking", part.thinking)
192+
},
193+
)
194+
}
195+
}
196+
}
197+
jsonEncoder.encodeJsonElement(jsonArray)
198+
}
199+
}
200+
}
201+
202+
override fun deserialize(decoder: Decoder): AnthropicMessageContent =
203+
throw NotImplementedError("Deserialization not needed for request models")
204+
}
205+
41206
// endregion
42207

43208
// region Response Models (Non-streaming)
@@ -66,6 +231,10 @@ internal data class AnthropicResponse(
66231
internal data class AnthropicContentBlock(
67232
val type: String,
68233
val text: String? = null,
234+
val id: String? = null,
235+
val name: String? = null,
236+
val input: JsonObject? = null,
237+
val thinking: String? = null,
69238
)
70239

71240
/**
@@ -114,6 +283,9 @@ internal data class AnthropicStreamEvent(
114283
internal data class AnthropicDelta(
115284
val type: String? = null,
116285
val text: String? = null,
286+
val thinking: String? = null,
287+
@SerialName("partial_json")
288+
val partialJson: String? = null,
117289
@SerialName("stop_reason")
118290
val stopReason: String? = null,
119291
@SerialName("stop_sequence")
@@ -163,3 +335,104 @@ internal data class AnthropicError(
163335
)
164336

165337
// endregion
338+
339+
// region Internal Helpers
340+
341+
/**
342+
* Converts a list of ChatMessage to AnthropicMessage format for API calls.
343+
* Filters out system messages (handled separately) and supports multimodal content and tools.
344+
*/
345+
internal fun List<ChatMessage>.toAnthropicMessages(): List<AnthropicMessage> = filter { it !is SystemMessage }
346+
.map { msg ->
347+
when (msg) {
348+
is UserMessage -> AnthropicMessage(
349+
role = "user",
350+
content = msg.content.toAnthropicContent(),
351+
)
352+
is AssistantMessage -> {
353+
val parts = mutableListOf<AnthropicContentPart>()
354+
355+
// Add reasoning/thinking blocks first
356+
msg.content.filterIsInstance<ReasoningPart>().forEach {
357+
parts.add(AnthropicContentPart.Thinking(it.text))
358+
}
359+
360+
// Add text parts
361+
msg.content.filterIsInstance<TextPart>().forEach {
362+
parts.add(AnthropicContentPart.Text(it.text))
363+
}
364+
365+
// Add tool calls
366+
msg.toolCalls.forEach { tc ->
367+
parts.add(
368+
AnthropicContentPart.ToolUse(
369+
id = tc.toolCallId,
370+
name = tc.toolName,
371+
input = tc.args,
372+
),
373+
)
374+
}
375+
376+
AnthropicMessage(
377+
role = "assistant",
378+
content = if (parts.size == 1 && parts.first() is AnthropicContentPart.Text) {
379+
AnthropicMessageContent.Text((parts.first() as AnthropicContentPart.Text).text)
380+
} else {
381+
AnthropicMessageContent.Parts(parts)
382+
},
383+
)
384+
}
385+
is ToolMessage -> AnthropicMessage(
386+
role = "user",
387+
content = AnthropicMessageContent.Parts(
388+
msg.content.map { result ->
389+
AnthropicContentPart.ToolResult(
390+
toolUseId = result.toolCallId,
391+
content = result.result.toString(),
392+
isError = result.isError,
393+
)
394+
},
395+
),
396+
)
397+
is SystemMessage -> throw IllegalStateException("System messages should be filtered out")
398+
}
399+
}
400+
401+
/**
402+
* Converts UserContentPart list to Anthropic content format.
403+
*/
404+
private fun List<UserContentPart>.toAnthropicContent(): AnthropicMessageContent {
405+
// Single text part - use plain string for efficiency
406+
if (size == 1 && first() is TextPart) {
407+
return AnthropicMessageContent.Text((first() as TextPart).text)
408+
}
409+
410+
// Multiple parts or multimodal - use array format
411+
return AnthropicMessageContent.Parts(
412+
map { part ->
413+
when (part) {
414+
is TextPart -> AnthropicContentPart.Text(part.text)
415+
is ImagePart -> {
416+
val source = if (part.isUrl) {
417+
AnthropicImageSource.fromUrl(part.data)
418+
} else {
419+
AnthropicImageSource.fromBase64(part.data, part.mimeType)
420+
}
421+
AnthropicContentPart.Image(source)
422+
}
423+
is FilePart -> {
424+
// Anthropic supports PDF documents
425+
AnthropicContentPart.Document(
426+
source = AnthropicDocumentSource(
427+
type = "base64",
428+
mediaType = part.mimeType,
429+
data = part.data,
430+
),
431+
)
432+
}
433+
}
434+
},
435+
)
436+
}
437+
438+
// endregion

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/ChatClient.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ internal class ChatClient(private val options: ChatOptions) {
6060
ignoreUnknownKeys = true
6161
isLenient = true
6262
encodeDefaults = true
63+
explicitNulls = false
6364
}
6465

6566
private val httpClient = HttpClient {
@@ -83,7 +84,7 @@ internal class ChatClient(private val options: ChatOptions) {
8384
* @param messages The list of messages to send
8485
* @return A Flow emitting StreamEvent objects
8586
*/
86-
suspend fun streamChat(messages: List<Message>): Flow<StreamEvent> = flow {
87+
suspend fun streamChat(messages: List<ChatMessage>): Flow<StreamEvent> = flow {
8788
try {
8889
val requestBody = options.buildRequestBody(messages, stream = true)
8990

@@ -150,7 +151,7 @@ internal class ChatClient(private val options: ChatOptions) {
150151
* @param messages The list of messages to send
151152
* @return The complete assistant message
152153
*/
153-
suspend fun chat(messages: List<Message>): Message {
154+
suspend fun chat(messages: List<ChatMessage>): AssistantMessage {
154155
val requestBody = options.buildRequestBody(messages, stream = false)
155156

156157
val response: HttpResponse = httpClient.post(options.buildEndpoint()) {

ai/src/commonMain/kotlin/xyz/junerver/compose/ai/usechat/ChatOptions.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import xyz.junerver.compose.hooks.Options
1818
* Callback type definitions for chat events
1919
*/
2020
typealias OnResponseCallback = (response: HttpResponse) -> Unit
21-
typealias OnFinishCallback = (message: Message, usage: ChatUsage?, finishReason: FinishReason?) -> Unit
21+
typealias OnFinishCallback = (message: ChatMessage, usage: ChatUsage?, finishReason: FinishReason?) -> Unit
2222
typealias OnErrorCallback = (error: Throwable) -> Unit
2323
typealias OnStreamCallback = (delta: String) -> Unit
2424

@@ -44,7 +44,7 @@ data class ChatOptions internal constructor(
4444
var provider: ChatProvider = Providers.OpenAI(apiKey = ""),
4545
var model: String? = null,
4646
var systemPrompt: String? = null,
47-
var initialMessages: List<Message> = emptyList(),
47+
var initialMessages: List<ChatMessage> = emptyList(),
4848
var temperature: Float? = null,
4949
var maxTokens: Int? = null,
5050
var timeout: Duration = 60.seconds,
@@ -80,7 +80,7 @@ data class ChatOptions internal constructor(
8080
/**
8181
* Builds the request body using the provider.
8282
*/
83-
internal fun buildRequestBody(messages: List<Message>, stream: Boolean): String = provider.buildRequestBody(
83+
internal fun buildRequestBody(messages: List<ChatMessage>, stream: Boolean): String = provider.buildRequestBody(
8484
messages = messages,
8585
model = effectiveModel,
8686
stream = stream,

0 commit comments

Comments
 (0)