From a9d5c4b86e2816d2a4b0f42f9b6d16079613b5e2 Mon Sep 17 00:00:00 2001 From: slam Date: Wed, 8 Jan 2025 18:26:07 +0800 Subject: [PATCH] refactor(gemini-api): adjust message content handling Refactored the handling of message parts by introducing specific `TextPart` class usage instead of generic `Part`. Updated related functions and tests to accommodate these changes, ensuring clarity in message conversion processes and enhancing type safety. --- .../api/GenerateContentResponse.kt | 4 +- .../com/tddworks/gemini/api/GeminiITest.kt | 331 +++++++++++++++++- .../api/textGeneration/api/Extensions.kt | 55 +-- .../api/textGeneration/api/ExtensionsTest.kt | 21 +- 4 files changed, 379 insertions(+), 32 deletions(-) diff --git a/gemini-client/gemini-client-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/GenerateContentResponse.kt b/gemini-client/gemini-client-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/GenerateContentResponse.kt index 5aad699..01edfab 100644 --- a/gemini-client/gemini-client-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/GenerateContentResponse.kt +++ b/gemini-client/gemini-client-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/GenerateContentResponse.kt @@ -3,7 +3,7 @@ package com.tddworks.gemini.api.textGeneration.api import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable -internal typealias Base64 = String +internal typealias Base64String = String @Serializable data class GenerateContentResponse( @@ -55,7 +55,7 @@ sealed interface Part { @Serializable data class InlineData( @SerialName("mime_type") val mimeType: String, - val data: Base64 + val data: Base64String ) } } diff --git a/gemini-client/gemini-client-core/src/jvmTest/kotlin/com/tddworks/gemini/api/GeminiITest.kt b/gemini-client/gemini-client-core/src/jvmTest/kotlin/com/tddworks/gemini/api/GeminiITest.kt index eb23a77..85a0099 100644 --- a/gemini-client/gemini-client-core/src/jvmTest/kotlin/com/tddworks/gemini/api/GeminiITest.kt +++ b/gemini-client/gemini-client-core/src/jvmTest/kotlin/com/tddworks/gemini/api/GeminiITest.kt @@ -11,10 +11,12 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import org.koin.test.junit5.AutoCloseKoinTest import kotlin.time.Duration.Companion.seconds +import java.util.* @EnabledIfEnvironmentVariable(named = "GEMINI_API_KEY", matches = ".+") class GeminiITest : AutoCloseKoinTest() { + fun ByteArray.toBase64() = Base64.getEncoder().encodeToString(this) @BeforeEach fun setUp() { @@ -25,6 +27,35 @@ class GeminiITest : AutoCloseKoinTest() { )) } + @Test + fun `should return generateContent response with image`() = runTest { + val gemini = getInstance() + val image = javaClass.classLoader.getResource("origami.png")?.readBytes() + ?.toBase64() + val response = gemini.generateContent( + GenerateContentRequest( + contents = listOf( + Content( + parts = listOf( + Part.TextPart(text = "Detect items, with no more than 20 items. Output a json list where each entry contains the 2D bounding box in \"box_2d\" and a text label in \"label\"."), + Part.InlineDataPart( + inlineData = Part.InlineDataPart.InlineData( + "image/png", + image!! + ) + ) + ) + ) + ), + stream = false + ) + ) + + println(response) + + assertNotNull(response) + } + @Test fun `should return stream generateContent response`() = runTest { val gemini = getInstance() @@ -53,4 +84,302 @@ class GeminiITest : AutoCloseKoinTest() { assertNotNull(response) } -} \ No newline at end of file +} + +// jpg +//[ +// { +// "box_2d": { +// "x_min": 346, +// "y_min": 345, +// "x_max": 500, +// "y_max": 507 +// }, +// "name": "origami fox" +// }, +// { +// "box_2d": { +// "x_min": 783, +// "y_min": 338, +// "x_max": 999, +// "y_max": 472 +// }, +// "name": "origami armadillo" +// } +//] + + +/** + * png + * [ + * { + * "box_2d": { + * "x_min": 348, + * "y_min": 343, + * "x_max": 500, + * "y_max": 506 + * }, + * "label": "origami fox" + * }, + * { + * "box_2d": { + * "x_min": 506, + * "y_min": 338, + * "x_max": 655, + * "y_max": 473 + * }, + * "label": "origami armadillo" + * }, + * { + * "box_2d": { + * "x_min": 436, + * "y_min": 474, + * "x_max": 540, + * "y_max": 502 + * }, + * "label": "window sill" + * }, + * { + * "box_2d": { + * "x_min": 406, + * "y_min": 492, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "window" + * }, + * { + * "box_2d": { + * "x_min": 0, + * "y_min": 0, + * "x_max": 500, + * "y_max": 999 + * }, + * "label": "wall" + * }, + * { + * "box_2d": { + * "x_min": 461, + * "y_min": 0, + * "x_max": 616, + * "y_max": 999 + * }, + * "label": "window frame" + * }, + * { + * "box_2d": { + * "x_min": 474, + * "y_min": 40, + * "x_max": 670, + * "y_max": 999 + * }, + * "label": "window" + * }, + * { + * "box_2d": { + * "x_min": 627, + * "y_min": 645, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "house" + * }, + * { + * "box_2d": { + * "x_min": 660, + * "y_min": 675, + * "x_max": 765, + * "y_max": 770 + * }, + * "label": "bush" + * }, + * { + * "box_2d": { + * "x_min": 770, + * "y_min": 700, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "trees" + * }, + * { + * "box_2d": { + * "x_min": 640, + * "y_min": 800, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "ground" + * }, + * { + * "box_2d": { + * "x_min": 450, + * "y_min": 0, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "sky" + * }, + * { + * "box_2d": { + * "x_min": 0, + * "y_min": 450, + * "x_max": 325, + * "y_max": 700 + * }, + * "label": "shadow" + * } + * ] + */ + +/** + * [ + * { + * "box_2d": { + * "x_min": 343, + * "y_min": 345, + * "x_max": 500, + * "y_max": 510 + * }, + * "label": "origami fox" + * }, + * { + * "box_2d": { + * "x_min": 525, + * "y_min": 330, + * "x_max": 681, + * "y_max": 472 + * }, + * "label": "origami armadillo" + * }, + * { + * "box_2d": { + * "x_min": 412, + * "y_min": 435, + * "x_max": 630, + * "y_max": 523 + * }, + * "label": "windowsill" + * }, + * { + * "box_2d": { + * "x_min": 460, + * "y_min": 495, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "window" + * }, + * { + * "box_2d": { + * "x_min": 0, + * "y_min": 430, + * "x_max": 385, + * "y_max": 999 + * }, + * "label": "wall" + * }, + * { + * "box_2d": { + * "x_min": 424, + * "y_min": 0, + * "x_max": 673, + * "y_max": 505 + * }, + * "label": "window frame" + * }, + * { + * "box_2d": { + * "x_min": 440, + * "y_min": 450, + * "x_max": 700, + * "y_max": 490 + * }, + * "label": "window ledge" + * }, + * { + * "box_2d": { + * "x_min": 450, + * "y_min": 480, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "glass" + * }, + * { + * "box_2d": { + * "x_min": 660, + * "y_min": 660, + * "x_max": 850, + * "y_max": 800 + * }, + * "label": "house" + * }, + * { + * "box_2d": { + * "x_min": 670, + * "y_min": 750, + * "x_max": 999, + * "y_max": 999 + * }, + * "label": "yard" + * }, + * { + * "box_2d": { + * "x_min": 0, + * "y_min": 400, + * "x_max": 350, + * "y_max": 600 + * }, + * "label": "shadow" + * } + * ] + */ + + +//[ +// {"box_2d": [474,360,731,624], "name": "lens"}, +// {"box_2d": [12,308,487,643], "name": "lens hood"}, +// {"box_2d": [253,308,487,643], "name": "lens hood"}, +// {"box_2d": [376,358,731,624], "name": "camera lens"}, +// {"box_2d": [539,6,634,164], "name": "clamp"}, +// {"box_2d": [500,781,595,979], "name": "clamp"}, +// {"box_2d": [421,378,653,589], "name": "lens support"}, +// {"box_2d": [725,296,885,767], "name": "camera"}, +// {"box_2d": [743,596,818,656], "name": "dial"}, +// {"box_2d": [757,690,839,754], "name": "dial"}, +// {"box_2d": [725,296,885,571], "name": "camera body"}, +// {"box_2d": [474,360,731,624], "name": "lens barrel"}, +// {"box_2d": [421,378,653,589], "name": "lens collar"}, +// {"box_2d": [725,296,885,767], "name": "camera with lens"}, +// {"box_2d": [329,43,948,1000], "name": "table"}, +// {"box_2d": [725,296,885,767], "name": "camera with lens and accessories"}, +// {"box_2d": [725,296,885,767], "name": "camera with lens and tripod mount"}, +// {"box_2d": [725,296,885,767], "name": "camera with lens and tripod"}, +// {"box_2d": [725,296,885,767], "name": "camera with lens and tripod accessories"}, +// {"box_2d": [725,296,885,767], "name": "camera with lens and tripod and accessories"} +//] + + +//[ +// {"box_2d": [474,360,731,624], "name": "镜头"}, +// {"box_2d": [12,308,487,643], "name": "遮光罩"}, +// {"box_2d": [253,308,487,643], "name": "遮光罩"}, +// {"box_2d": [376,358,731,624], "name": "相机镜头"}, +// {"box_2d": [539,6,634,164], "name": "夹子"}, +// {"box_2d": [500,781,595,979], "name": "夹子"}, +// {"box_2d": [421,378,653,589], "name": "镜头支架"}, +// {"box_2d": [725,296,885,767], "name": "相机"}, +// {"box_2d": [743,596,818,656], "name": "拨盘"}, +// {"box_2d": [757,690,839,754], "name": "拨盘"}, +// {"box_2d": [725,296,885,571], "name": "相机机身"}, +// {"box_2d": [474,360,731,624], "name": "镜头筒"}, +// {"box_2d": [421,378,653,589], "name": "镜头环"}, +// {"box_2d": [725,296,885,767], "name": "带镜头的相机"}, +// {"box_2d": [329,43,948,1000], "name": "桌子"}, +// {"box_2d": [725,296,885,767], "name": "带镜头的相机和配件"}, +// {"box_2d": [725,296,885,767], "name": "带镜头的相机和三脚架底座"}, +// {"box_2d": [725,296,885,767], "name": "带镜头的相机和三脚架"}, +// {"box_2d": [725,296,885,767], "name": "带镜头的相机和三脚架配件"}, +// {"box_2d": [725,296,885,767], "name": "带镜头的相机和三脚架及配件"} +//] \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/Extensions.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/Extensions.kt index 99eb19c..20eb6e2 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/Extensions.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/gemini/api/textGeneration/api/Extensions.kt @@ -14,17 +14,27 @@ import com.tddworks.openai.api.chat.api.Role as OpenAIRole fun GenerateContentResponse.toOpenAIChatCompletion(): OpenAIChatCompletion { - val id = "chatcmpl-gemini-123" + val completionId = "chatcmpl-gemini-123" + val creationTimestamp = 1L + + val chatChoices = candidates.mapIndexed { index, candidate -> + val textPart = candidate.content.parts.first() as Part.TextPart + val message = ChatMessage.assistant(textPart.text) + val reason = candidate.finishReason?.let { FinishReason(it) } + + ChatChoice( + message = message, + index = index, + finishReason = reason + ) + } + return ChatCompletion( - id = id, - created = 1L, + id = completionId, + created = creationTimestamp, model = modelVersion, - choices = candidates.mapIndexed { index, candidate -> - ChatChoice( - message = ChatMessage.assistant(candidate.content.parts.first().text), - index = index, - finishReason = candidate.finishReason?.let { FinishReason(it) }) - }) + choices = chatChoices + ) } @@ -32,15 +42,16 @@ fun GenerateContentResponse.toOpenAIChatCompletionChunk(): OpenAIChatCompletionC val id = "chatcmpl-gemini-123" val created = 1L - val chatChunkList = listOf( - ChatChunk( - index = 0, - delta = ChatDelta( - role = OpenAIRole.Assistant, - content = candidates.firstOrNull()?.content?.parts?.firstOrNull()?.text - ), - finishReason = candidates.firstOrNull()?.finishReason, - ) + val firstCandidate = candidates.firstOrNull() + val firstTextPart = firstCandidate?.content?.parts?.firstOrNull() as? Part.TextPart + + val chatChunk = ChatChunk( + index = 0, + delta = ChatDelta( + role = OpenAIRole.Assistant, + content = firstTextPart?.text + ), + finishReason = firstCandidate?.finishReason ) return OpenAIChatCompletionChunk( @@ -48,7 +59,7 @@ fun GenerateContentResponse.toOpenAIChatCompletionChunk(): OpenAIChatCompletionC `object` = "gemini-chunk", created = created, model = modelVersion, - choices = chatChunkList + choices = listOf(chatChunk) ) } @@ -79,13 +90,13 @@ private fun OpenAIMessage.toGeminiMessageOrNull(): Content? { private fun OpenAIUserMessage.toGeminiMessage(): Content { return Content( - parts = listOf(Part(text = content)), role = OpenAIRole.User.name + parts = listOf(Part.TextPart(text = content)), role = OpenAIRole.User.name ) } private fun OpenAIAssistantMessage.toGeminiMessage(): Content { return Content( - parts = listOf(Part(text = content)), role = "model" + parts = listOf(Part.TextPart(text = content)), role = "model" ) } @@ -106,6 +117,6 @@ private fun OpenAIAssistantMessage.toGeminiMessage(): Content { */ private fun OpenAISystemMessage.toGeminiMessage(): Content { return Content( - parts = listOf(Part(text = content)) + parts = listOf(Part.TextPart(text = content)) ) } \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/gemini/api/textGeneration/api/ExtensionsTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/gemini/api/textGeneration/api/ExtensionsTest.kt index c0f9f74..d008f2e 100644 --- a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/gemini/api/textGeneration/api/ExtensionsTest.kt +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/gemini/api/textGeneration/api/ExtensionsTest.kt @@ -17,7 +17,7 @@ class ExtensionsTest { candidates = listOf( Candidate( content = Content( - parts = listOf(Part("some-text")), + parts = listOf(Part.TextPart("some-text")), role = "model" ) ) @@ -30,7 +30,8 @@ class ExtensionsTest { modelVersion = "gemini-1.5-flash" ) - val openAIChatCompletionChunk = generateContentResponse.toOpenAIChatCompletionChunk() + val openAIChatCompletionChunk = + generateContentResponse.toOpenAIChatCompletionChunk() assertEquals("gemini-1.5-flash", openAIChatCompletionChunk.model) assertEquals(1, openAIChatCompletionChunk.choices.size) @@ -46,7 +47,7 @@ class ExtensionsTest { candidates = listOf( Candidate( content = Content( - parts = listOf(Part("some-text")), + parts = listOf(Part.TextPart("some-text")), role = "model" ) ) @@ -74,7 +75,7 @@ class ExtensionsTest { candidates = listOf( Candidate( content = Content( - parts = listOf(Part("some-text")), + parts = listOf(Part.TextPart("some-text")), role = "model" ), finishReason = "STOP" @@ -117,14 +118,20 @@ class ExtensionsTest { assertEquals(GeminiModel.GEMINI_1_5_FLASH, generateContentRequest.model) assertEquals( "How are you?", - generateContentRequest.systemInstruction?.parts?.get(0)?.text + (generateContentRequest.systemInstruction?.parts?.get(0) as? Part.TextPart)?.text ) assertNull(generateContentRequest.systemInstruction?.role) assertEquals(2, generateContentRequest.contents.size) assertEquals("user", generateContentRequest.contents[0].role) - assertEquals("Hello", generateContentRequest.contents[0].parts[0].text) + assertEquals( + "Hello", + (generateContentRequest.contents[0].parts[0] as Part.TextPart).text + ) assertEquals("model", generateContentRequest.contents[1].role) - assertEquals("Hi there", generateContentRequest.contents[1].parts[0].text) + assertEquals( + "Hi there", + (generateContentRequest.contents[1].parts[0] as Part.TextPart).text + ) } } \ No newline at end of file