Skip to content

Commit

Permalink
feat(BE-215): ollama fim api support
Browse files Browse the repository at this point in the history
 - add missing ut
  • Loading branch information
hanrw committed Jun 19, 2024
1 parent bc61824 commit 4a5a8a3
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
package com.tddworks.openai.api.chat.api.vision

import kotlinx.serialization.*
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlin.jvm.JvmInline

@Serializable(with = VisionMessageContentSerializer::class)
Expand Down Expand Up @@ -53,21 +49,8 @@ sealed interface VisionMessageContent {
) : VisionMessageContent
}

object VisionMessageContentSerializer :
JsonContentPolymorphicSerializer<VisionMessageContent>(VisionMessageContent::class) {
override fun selectDeserializer(element: JsonElement): KSerializer<out VisionMessageContent> {
val type = element.jsonObject["type"]?.jsonPrimitive?.content

return when (type) {
ContentType.TEXT.value -> VisionMessageContent.TextContent.serializer()
ContentType.IMAGE.value -> VisionMessageContent.ImageContent.serializer()
else -> throw IllegalArgumentException("Unknown type")
}
}
}

@Serializable
class ImageUrl(
data class ImageUrl(
@SerialName("url")
val value: String,
@EncodeDefault(EncodeDefault.Mode.ALWAYS)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.tddworks.openai.api.chat.api.vision

import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive

object VisionMessageContentSerializer :
JsonContentPolymorphicSerializer<VisionMessageContent>(VisionMessageContent::class) {
override fun selectDeserializer(element: JsonElement): KSerializer<out VisionMessageContent> {
val typeObject = element.jsonObject["type"]
val jsonPrimitive = typeObject?.jsonPrimitive
val type = jsonPrimitive?.content

return when (type) {
ContentType.TEXT.value -> VisionMessageContent.TextContent.serializer()
ContentType.IMAGE.value -> VisionMessageContent.ImageContent.serializer()
else -> throw IllegalArgumentException("Unknown type")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,99 @@ import com.tddworks.openai.api.common.prettyJson
import kotlinx.serialization.encodeToString
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

class VisionMessageContentTest {

@Test
fun `should throw IllegalArgumentException when type was not jsonPrimitive`() {
val json = """
{
"type": ["text"],
"text": "Hello, how may I assist you today?"
}
""".trimIndent()

assertThrows<IllegalArgumentException> {
prettyJson.decodeFromString(
VisionMessageContent.serializer(), json
)
}
}


@Test
fun `should throw IllegalArgumentException for unknown json`() {
val json = """
{
"text": "Hello, how may I assist you today?"
}
""".trimIndent()

assertThrows<IllegalArgumentException> {
prettyJson.decodeFromString(
VisionMessageContent.serializer(), json
)
}
}

@Test
fun `should throw IllegalArgumentException for unknown type`() {
val json = """
{
"type": "unknown",
"text": "Hello, how may I assist you today?"
}
""".trimIndent()

assertThrows<IllegalArgumentException> {
prettyJson.decodeFromString(
VisionMessageContent.serializer(), json
)
}
}

@Test
fun `should convert json to image VisionMessageContent`() {
val json = """
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.jpg",
"detail": "auto"
}
}
""".trimIndent()

val visionMessageContent =
prettyJson.decodeFromString(VisionMessageContent.serializer(), json)

assertEquals(
VisionMessageContent.ImageContent(
imageUrl = ImageUrl("https://example.com/image.jpg")
), visionMessageContent
)
}

@Test
fun `should convert json to text VisionMessageContent`() {
val json = """
{
"type": "text",
"text": "Hello, how may I assist you today?"
}
""".trimIndent()

val visionMessageContent =
prettyJson.decodeFromString(VisionMessageContent.serializer(), json)

assertEquals(
VisionMessageContent.TextContent(content = "Hello, how may I assist you today?"),
visionMessageContent
)
}


@Test
fun `should return correct json for image content`() {
val imageUrl = ImageUrl("https://example.com/image.jpg")
Expand All @@ -21,14 +112,14 @@ class VisionMessageContentTest {
""".trimIndent()

assertEquals(
expectedJson,
prettyJson.encodeToString(imageContent)
expectedJson, prettyJson.encodeToString(imageContent)
)
}

@Test
fun `should return correct json for text content`() {
val textContent = VisionMessageContent.TextContent(content = "Hello, how may I assist you today?")
val textContent =
VisionMessageContent.TextContent(content = "Hello, how may I assist you today?")
val expectedJson = """
{
"type": "text",
Expand Down

0 comments on commit 4a5a8a3

Please sign in to comment.