diff --git a/apps/opik-backend/pom.xml b/apps/opik-backend/pom.xml
index b03941e774..beeb06c10c 100644
--- a/apps/opik-backend/pom.xml
+++ b/apps/opik-backend/pom.xml
@@ -221,6 +221,10 @@
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-anthropic</artifactId>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-google-ai-gemini</artifactId>
+        </dependency>
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
index 3c025b9d02..327c4c318c 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
@@ -1,6 +1,7 @@
package com.comet.opik;
import com.comet.opik.api.error.JsonInvalidFormatExceptionMapper;
+import com.comet.opik.domain.llmproviders.LlmProviderClientModule;
import com.comet.opik.infrastructure.ConfigurationModule;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.OpikConfiguration;
@@ -72,7 +73,7 @@ public void initialize(Bootstrap bootstrap) {
.withPlugins(new SqlObjectPlugin(), new Jackson2Plugin()))
.modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(),
new RateLimitModule(), new NameGeneratorModule(), new HttpModule(), new EventModule(),
- new ConfigurationModule(), new BiModule())
+ new ConfigurationModule(), new BiModule(), new LlmProviderClientModule())
.installers(JobGuiceyInstaller.class)
.listen(new OpikGuiceyLifecycleEventListener())
.enableAutoConfig()
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
index 36de62b825..36642febb5 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
@@ -11,7 +11,9 @@
@RequiredArgsConstructor
public enum LlmProvider {
OPEN_AI("openai"),
- ANTHROPIC("anthropic");
+ ANTHROPIC("anthropic"),
+ GEMINI("gemini"),
+ ;
@JsonValue
private final String value;
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java
new file mode 100644
index 0000000000..d5f89a59a7
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java
@@ -0,0 +1,60 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Delta;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.output.Response;
+import lombok.NonNull;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+public record ChunkedResponseHandler(
+        @NonNull Consumer<ChatCompletionResponse> handleMessage,
+        @NonNull Runnable handleClose,
+        @NonNull Consumer<Throwable> handleError,
+        @NonNull String model) implements StreamingResponseHandler<AiMessage> {
+
+ @Override
+ public void onNext(@NonNull String content) {
+ handleMessage.accept(ChatCompletionResponse.builder()
+ .model(model)
+ .choices(List.of(ChatCompletionChoice.builder()
+ .delta(Delta.builder()
+ .content(content)
+ .role(Role.ASSISTANT)
+ .build())
+ .build()))
+ .build());
+ }
+
+ @Override
+    public void onComplete(@NonNull Response<AiMessage> response) {
+ handleMessage.accept(ChatCompletionResponse.builder()
+ .model(model)
+ .choices(List.of(ChatCompletionChoice.builder()
+ .delta(Delta.builder()
+ .content("")
+ .role(Role.ASSISTANT)
+ .build())
+ .build()))
+ .usage(Usage.builder()
+ .promptTokens(response.tokenUsage().inputTokenCount())
+ .completionTokens(response.tokenUsage().outputTokenCount())
+ .totalTokens(response.tokenUsage().totalTokenCount())
+ .build())
+ .id(Optional.ofNullable(response.metadata().get("id")).map(Object::toString).orElse(null))
+ .build());
+ handleClose.run();
+ }
+
+ @Override
+ public void onError(@NonNull Throwable throwable) {
+ handleError.accept(throwable);
+ }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java
new file mode 100644
index 0000000000..25ae957f23
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java
@@ -0,0 +1,25 @@
+package com.comet.opik.domain.llmproviders;
+
+import lombok.RequiredArgsConstructor;
+
+/*
+Langchain4j doesn't provide gemini models enum.
+This information is taken from: https://ai.google.dev/gemini-api/docs/models/gemini
+ */
+@RequiredArgsConstructor
+public enum GeminiModelName {
+ GEMINI_2_0_FLASH("gemini-2.0-flash-exp"),
+ GEMINI_1_5_FLASH("gemini-1.5-flash"),
+ GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b"),
+ GEMINI_1_5_PRO("gemini-1.5-pro"),
+ GEMINI_1_0_PRO("gemini-1.0-pro"),
+ TEXT_EMBEDDING("text-embedding-004"),
+ AQA("aqa");
+
+ private final String value;
+
+ @Override
+ public String toString() {
+ return value;
+ }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
index 6781b7fb9a..511cd250a9 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
@@ -1,68 +1,33 @@
package com.comet.opik.domain.llmproviders;
-import com.comet.opik.infrastructure.LlmProviderClientConfig;
-import dev.ai4j.openai4j.chat.AssistantMessage;
-import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
-import dev.ai4j.openai4j.chat.Delta;
-import dev.ai4j.openai4j.chat.Message;
-import dev.ai4j.openai4j.chat.Role;
-import dev.ai4j.openai4j.chat.SystemMessage;
-import dev.ai4j.openai4j.chat.UserMessage;
-import dev.ai4j.openai4j.shared.Usage;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.StreamingResponseHandler;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException;
-import dev.langchain4j.model.output.Response;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
-import org.apache.commons.lang3.StringUtils;
-import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import static com.comet.opik.domain.ChatCompletionService.ERROR_EMPTY_MESSAGES;
import static com.comet.opik.domain.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS;
+@RequiredArgsConstructor
@Slf4j
class LlmProviderAnthropic implements LlmProviderService {
- private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
private final @NonNull AnthropicClient anthropicClient;
- public LlmProviderAnthropic(@NonNull LlmProviderClientConfig llmProviderClientConfig, @NonNull String apiKey) {
- this.llmProviderClientConfig = llmProviderClientConfig;
- this.anthropicClient = newClient(apiKey);
- }
-
@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
- var response = anthropicClient.createMessage(toAnthropicCreateMessageRequest(request));
+ var mapper = LlmProviderAnthropicMapper.INSTANCE;
+ var response = anthropicClient.createMessage(mapper.toCreateMessageRequest(request));
- return ChatCompletionResponse.builder()
- .id(response.id)
- .model(response.model)
- .choices(response.content.stream().map(content -> toChatCompletionChoice(response, content))
- .toList())
- .usage(Usage.builder()
- .promptTokens(response.usage.inputTokens)
- .completionTokens(response.usage.outputTokens)
- .totalTokens(response.usage.inputTokens + response.usage.outputTokens)
- .build())
- .build();
+ return mapper.toResponse(response);
}
@Override
@@ -72,7 +37,7 @@ public void generateStream(
             @NonNull Consumer<ChatCompletionResponse> handleMessage,
             @NonNull Runnable handleClose, @NonNull Consumer<Throwable> handleError) {
validateRequest(request);
- anthropicClient.createMessage(toAnthropicCreateMessageRequest(request),
+ anthropicClient.createMessage(LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request),
new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
}
@@ -88,7 +53,7 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
}
@Override
-    public @NonNull Optional<ErrorMessage> getLlmProviderError(Throwable runtimeException) {
+    public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
if (runtimeException instanceof AnthropicHttpException anthropicHttpException) {
return Optional.of(new ErrorMessage(anthropicHttpException.statusCode(),
anthropicHttpException.getMessage()));
@@ -96,143 +61,4 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
return Optional.empty();
}
-
- private AnthropicCreateMessageRequest toAnthropicCreateMessageRequest(ChatCompletionRequest request) {
- var builder = AnthropicCreateMessageRequest.builder();
- Optional.ofNullable(request.toolChoice())
- .ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(
- request.toolChoice().toString())));
- return builder
- .stream(request.stream())
- .model(request.model())
- .messages(request.messages().stream()
- .filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
- .map(this::toMessage).toList())
- .system(request.messages().stream()
- .filter(message -> message.role() == Role.SYSTEM)
- .map(this::toSystemMessage).toList())
- .temperature(request.temperature())
- .topP(request.topP())
- .stopSequences(request.stop())
- .maxTokens(request.maxCompletionTokens())
- .build();
- }
-
- private AnthropicMessage toMessage(Message message) {
- if (message.role() == Role.ASSISTANT) {
- return AnthropicMessage.builder()
- .role(AnthropicRole.ASSISTANT)
- .content(List.of(new AnthropicTextContent(((AssistantMessage) message).content())))
- .build();
- }
-
- if (message.role() == Role.USER) {
- return AnthropicMessage.builder()
- .role(AnthropicRole.USER)
- .content(List.of(toAnthropicMessageContent(((UserMessage) message).content())))
- .build();
- }
-
- throw new BadRequestException("unexpected message role: " + message.role());
- }
-
- private AnthropicTextContent toSystemMessage(Message message) {
- if (message.role() != Role.SYSTEM) {
- throw new BadRequestException("expecting only system role, got: " + message.role());
- }
-
- return new AnthropicTextContent(((SystemMessage) message).content());
- }
-
- private AnthropicMessageContent toAnthropicMessageContent(Object rawContent) {
- if (rawContent instanceof String content) {
- return new AnthropicTextContent(content);
- }
-
- throw new BadRequestException("only text content is supported");
- }
-
- private ChatCompletionChoice toChatCompletionChoice(
- AnthropicCreateMessageResponse response, AnthropicContent content) {
- return ChatCompletionChoice.builder()
- .message(AssistantMessage.builder()
- .name(content.name)
- .content(content.text)
- .build())
- .finishReason(response.stopReason)
- .build();
- }
-
- private AnthropicClient newClient(String apiKey) {
- var anthropicClientBuilder = AnthropicClient.builder();
- Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
- .map(LlmProviderClientConfig.AnthropicClientConfig::url)
- .ifPresent(url -> {
- if (StringUtils.isNotEmpty(url)) {
- anthropicClientBuilder.baseUrl(url);
- }
- });
- Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
- .map(LlmProviderClientConfig.AnthropicClientConfig::version)
- .ifPresent(version -> {
- if (StringUtils.isNotBlank(version)) {
- anthropicClientBuilder.version(version);
- }
- });
- Optional.ofNullable(llmProviderClientConfig.getLogRequests())
- .ifPresent(anthropicClientBuilder::logRequests);
- Optional.ofNullable(llmProviderClientConfig.getLogResponses())
- .ifPresent(anthropicClientBuilder::logResponses);
- // anthropic client builder only receives one timeout variant
- Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
- .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
- return anthropicClientBuilder
- .apiKey(apiKey)
- .build();
- }
-
-    private record ChunkedResponseHandler(
-            Consumer<ChatCompletionResponse> handleMessage,
-            Runnable handleClose,
-            Consumer<Throwable> handleError,
-            String model) implements StreamingResponseHandler<AiMessage> {
-
- @Override
- public void onNext(String s) {
- handleMessage.accept(ChatCompletionResponse.builder()
- .model(model)
- .choices(List.of(ChatCompletionChoice.builder()
- .delta(Delta.builder()
- .content(s)
- .role(Role.ASSISTANT)
- .build())
- .build()))
- .build());
- }
-
- @Override
-        public void onComplete(Response<AiMessage> response) {
- handleMessage.accept(ChatCompletionResponse.builder()
- .model(model)
- .choices(List.of(ChatCompletionChoice.builder()
- .delta(Delta.builder()
- .content("")
- .role(Role.ASSISTANT)
- .build())
- .build()))
- .usage(Usage.builder()
- .promptTokens(response.tokenUsage().inputTokenCount())
- .completionTokens(response.tokenUsage().outputTokenCount())
- .totalTokens(response.tokenUsage().totalTokenCount())
- .build())
- .id((String) response.metadata().get("id"))
- .build());
- handleClose.run();
- }
-
- @Override
- public void onError(Throwable throwable) {
- handleError.accept(throwable);
- }
- }
}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java
new file mode 100644
index 0000000000..dcce42e8a0
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java
@@ -0,0 +1,118 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.chat.SystemMessage;
+import dev.ai4j.openai4j.chat.UserMessage;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
+import jakarta.ws.rs.BadRequestException;
+import lombok.NonNull;
+import org.mapstruct.Mapper;
+import org.mapstruct.Mapping;
+import org.mapstruct.Named;
+import org.mapstruct.factory.Mappers;
+
+import java.util.List;
+
+@Mapper
+public interface LlmProviderAnthropicMapper {
+ LlmProviderAnthropicMapper INSTANCE = Mappers.getMapper(LlmProviderAnthropicMapper.class);
+
+ @Mapping(source = "response", target = "choices", qualifiedByName = "mapToChoices")
+ @Mapping(source = "usage", target = "usage", qualifiedByName = "mapToUsage")
+ ChatCompletionResponse toResponse(@NonNull AnthropicCreateMessageResponse response);
+
+ @Mapping(source = "content", target = "message")
+ @Mapping(source = "response.stopReason", target = "finishReason")
+ ChatCompletionChoice toChoice(@NonNull AnthropicContent content, @NonNull AnthropicCreateMessageResponse response);
+
+ @Mapping(source = "text", target = "content")
+ AssistantMessage toAssistantMessage(@NonNull AnthropicContent content);
+
+ @Mapping(expression = "java(request.model())", target = "model")
+ @Mapping(expression = "java(request.stream())", target = "stream")
+ @Mapping(expression = "java(request.temperature())", target = "temperature")
+ @Mapping(expression = "java(request.topP())", target = "topP")
+ @Mapping(expression = "java(request.stop())", target = "stopSequences")
+ @Mapping(expression = "java(request.maxCompletionTokens())", target = "maxTokens")
+ @Mapping(source = "request", target = "messages", qualifiedByName = "mapToMessages")
+ @Mapping(source = "request", target = "system", qualifiedByName = "mapToSystemMessages")
+ AnthropicCreateMessageRequest toCreateMessageRequest(@NonNull ChatCompletionRequest request);
+
+ @Named("mapToChoices")
+    default List<ChatCompletionChoice> mapToChoices(@NonNull AnthropicCreateMessageResponse response) {
+ if (response.content == null || response.content.isEmpty()) {
+ return List.of();
+ }
+ return response.content.stream().map(content -> toChoice(content, response)).toList();
+ }
+
+ @Named("mapToUsage")
+ default Usage mapToUsage(AnthropicUsage usage) {
+ if (usage == null) {
+ return null;
+ }
+
+ return Usage.builder()
+ .promptTokens(usage.inputTokens)
+ .completionTokens(usage.outputTokens)
+ .totalTokens(usage.inputTokens + usage.outputTokens)
+ .build();
+ }
+
+ @Named("mapToMessages")
+    default List<AnthropicMessage> mapToMessages(@NonNull ChatCompletionRequest request) {
+ return request.messages().stream()
+ .filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
+ .map(this::mapToAnthropicMessage).toList();
+ }
+
+ @Named("mapToSystemMessages")
+    default List<AnthropicTextContent> mapToSystemMessages(@NonNull ChatCompletionRequest request) {
+ return request.messages().stream()
+ .filter(message -> message.role() == Role.SYSTEM)
+ .map(this::mapToSystemMessage).toList();
+ }
+
+ default AnthropicMessage mapToAnthropicMessage(@NonNull Message message) {
+ return switch (message) {
+ case AssistantMessage assistantMessage -> AnthropicMessage.builder()
+ .role(AnthropicRole.ASSISTANT)
+ .content(List.of(new AnthropicTextContent(assistantMessage.content())))
+ .build();
+ case UserMessage userMessage -> AnthropicMessage.builder()
+ .role(AnthropicRole.USER)
+ .content(List.of(toAnthropicMessageContent(userMessage.content())))
+ .build();
+ default -> throw new BadRequestException("unexpected message role: " + message.role());
+ };
+ }
+
+ default AnthropicMessageContent toAnthropicMessageContent(@NonNull Object rawContent) {
+ if (rawContent instanceof String content) {
+ return new AnthropicTextContent(content);
+ }
+
+ throw new BadRequestException("only text content is supported");
+ }
+
+ default AnthropicTextContent mapToSystemMessage(@NonNull Message message) {
+ if (message.role() != Role.SYSTEM) {
+ throw new BadRequestException("expecting only system role, got: " + message.role());
+ }
+
+ return new AnthropicTextContent(((SystemMessage) message).content());
+ }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java
new file mode 100644
index 0000000000..145b597271
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java
@@ -0,0 +1,72 @@
+package com.comet.opik.domain.llmproviders;
+
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
+import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.Optional;
+
+@RequiredArgsConstructor
+public class LlmProviderClientGenerator {
+ private static final int MAX_RETRIES = 1;
+
+ private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
+
+ public AnthropicClient newAnthropicClient(@NonNull String apiKey) {
+ var anthropicClientBuilder = AnthropicClient.builder();
+ Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+ .map(LlmProviderClientConfig.AnthropicClientConfig::url)
+ .filter(StringUtils::isNotEmpty)
+ .ifPresent(anthropicClientBuilder::baseUrl);
+ Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+ .map(LlmProviderClientConfig.AnthropicClientConfig::version)
+ .filter(StringUtils::isNotBlank)
+ .ifPresent(anthropicClientBuilder::version);
+ Optional.ofNullable(llmProviderClientConfig.getLogRequests())
+ .ifPresent(anthropicClientBuilder::logRequests);
+ Optional.ofNullable(llmProviderClientConfig.getLogResponses())
+ .ifPresent(anthropicClientBuilder::logResponses);
+ // anthropic client builder only receives one timeout variant
+ Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
+ .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
+ return anthropicClientBuilder
+ .apiKey(apiKey)
+ .build();
+ }
+
+ public OpenAiClient newOpenAiClient(@NonNull String apiKey) {
+ var openAiClientBuilder = OpenAiClient.builder();
+ Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
+ .map(LlmProviderClientConfig.OpenAiClientConfig::url)
+ .filter(StringUtils::isNotBlank)
+ .ifPresent(openAiClientBuilder::baseUrl);
+ Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
+ .ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
+ Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
+ .ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
+ Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
+ .ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
+ Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
+ .ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
+ return openAiClientBuilder
+ .openAiApiKey(apiKey)
+ .build();
+ }
+
+ public GoogleAiGeminiChatModel newGeminiClient(@NonNull String apiKey, @NonNull ChatCompletionRequest request) {
+ return LlmProviderGeminiMapper.INSTANCE.toGeminiChatModel(apiKey, request,
+ llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
+ }
+
+ public GoogleAiGeminiStreamingChatModel newGeminiStreamingClient(
+ @NonNull String apiKey, @NonNull ChatCompletionRequest request) {
+ return LlmProviderGeminiMapper.INSTANCE.toGeminiStreamingChatModel(apiKey, request,
+ llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
+ }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java
new file mode 100644
index 0000000000..f6ea5a4cb8
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java
@@ -0,0 +1,19 @@
+package com.comet.opik.domain.llmproviders;
+
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.OpikConfiguration;
+import com.google.inject.Provides;
+import jakarta.inject.Singleton;
+import lombok.NonNull;
+import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule;
+import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
+
+public class LlmProviderClientModule extends DropwizardAwareModule<OpikConfiguration> {
+
+ @Provides
+ @Singleton
+ public LlmProviderClientGenerator clientGenerator(
+ @NonNull @Config("llmProviderClient") LlmProviderClientConfig config) {
+ return new LlmProviderClientGenerator(config);
+ }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
index 033d3d6a49..2a92e113ac 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
@@ -3,7 +3,6 @@
import com.comet.opik.api.LlmProvider;
import com.comet.opik.domain.LlmProviderApiKeyService;
import com.comet.opik.infrastructure.EncryptionUtils;
-import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.chat.ChatCompletionModel;
import dev.langchain4j.model.anthropic.AnthropicChatModelName;
import jakarta.inject.Inject;
@@ -12,7 +11,6 @@
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.EnumUtils;
-import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
import java.util.function.Function;
@@ -21,16 +19,18 @@
public class LlmProviderFactory {
public static final String ERROR_MODEL_NOT_SUPPORTED = "model not supported %s";
- private final @NonNull @Config LlmProviderClientConfig llmProviderClientConfig;
private final @NonNull LlmProviderApiKeyService llmProviderApiKeyService;
+ private final @NonNull LlmProviderClientGenerator llmProviderClientGenerator;
public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) {
var llmProvider = getLlmProvider(model);
var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
return switch (llmProvider) {
- case LlmProvider.OPEN_AI -> new LlmProviderOpenAi(llmProviderClientConfig, apiKey);
- case LlmProvider.ANTHROPIC -> new LlmProviderAnthropic(llmProviderClientConfig, apiKey);
+ case LlmProvider.OPEN_AI -> new LlmProviderOpenAi(llmProviderClientGenerator.newOpenAiClient(apiKey));
+ case LlmProvider.ANTHROPIC ->
+ new LlmProviderAnthropic(llmProviderClientGenerator.newAnthropicClient(apiKey));
+ case LlmProvider.GEMINI -> new LlmProviderGemini(llmProviderClientGenerator, apiKey);
};
}
@@ -44,6 +44,9 @@ private LlmProvider getLlmProvider(String model) {
if (isModelBelongToProvider(model, AnthropicChatModelName.class, AnthropicChatModelName::toString)) {
return LlmProvider.ANTHROPIC;
}
+ if (isModelBelongToProvider(model, GeminiModelName.class, GeminiModelName::toString)) {
+ return LlmProvider.GEMINI;
+ }
throw new BadRequestException(ERROR_MODEL_NOT_SUPPORTED.formatted(model));
}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java
new file mode 100644
index 0000000000..3ae0c2efa9
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java
@@ -0,0 +1,44 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import io.dropwizard.jersey.errors.ErrorMessage;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+import java.util.Optional;
+import java.util.function.Consumer;
+
+@RequiredArgsConstructor
+public class LlmProviderGemini implements LlmProviderService {
+ private final @NonNull LlmProviderClientGenerator llmProviderClientGenerator;
+ private final @NonNull String apiKey;
+
+ @Override
+ public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
+ var mapper = LlmProviderGeminiMapper.INSTANCE;
+ var response = llmProviderClientGenerator.newGeminiClient(apiKey, request)
+ .generate(request.messages().stream().map(mapper::toChatMessage).toList());
+
+ return mapper.toChatCompletionResponse(request, response);
+ }
+
+ @Override
+ public void generateStream(@NonNull ChatCompletionRequest request, @NonNull String workspaceId,
+            @NonNull Consumer<ChatCompletionResponse> handleMessage, @NonNull Runnable handleClose,
+            @NonNull Consumer<Throwable> handleError) {
+ llmProviderClientGenerator.newGeminiStreamingClient(apiKey, request)
+ .generate(request.messages().stream().map(LlmProviderGeminiMapper.INSTANCE::toChatMessage).toList(),
+ new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
+ }
+
+ @Override
+ public void validateRequest(@NonNull ChatCompletionRequest request) {
+ }
+
+ @Override
+    public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
+ return Optional.empty();
+ }
+
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java
new file mode 100644
index 0000000000..65fc410a1c
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java
@@ -0,0 +1,98 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
+import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
+import dev.langchain4j.model.output.Response;
+import jakarta.ws.rs.BadRequestException;
+import lombok.NonNull;
+import org.mapstruct.Mapper;
+import org.mapstruct.Mapping;
+import org.mapstruct.Named;
+import org.mapstruct.factory.Mappers;
+
+import java.time.Duration;
+import java.util.List;
+
+@Mapper
+public interface LlmProviderGeminiMapper {
+ String ERR_UNEXPECTED_ROLE = "unexpected role '%s'";
+ String ERR_ROLE_MSG_TYPE_MISMATCH = "role and message instance are not matching, role: '%s', instance: '%s'";
+
+ LlmProviderGeminiMapper INSTANCE = Mappers.getMapper(LlmProviderGeminiMapper.class);
+
+ default ChatMessage toChatMessage(@NonNull Message message) {
+ if (!List.of(Role.ASSISTANT, Role.USER, Role.SYSTEM).contains(message.role())) {
+ throw new BadRequestException(ERR_UNEXPECTED_ROLE.formatted(message.role()));
+ }
+
+ switch (message.role()) {
+ case ASSISTANT -> {
+ if (message instanceof AssistantMessage assistantMessage) {
+ return AiMessage.from(assistantMessage.content());
+ }
+ }
+ case USER -> {
+ if (message instanceof dev.ai4j.openai4j.chat.UserMessage userMessage) {
+ return UserMessage.from(userMessage.content().toString());
+ }
+ }
+ case SYSTEM -> {
+ if (message instanceof dev.ai4j.openai4j.chat.SystemMessage systemMessage) {
+ return SystemMessage.from(systemMessage.content());
+ }
+ }
+ }
+
+ throw new BadRequestException(ERR_ROLE_MSG_TYPE_MISMATCH.formatted(message.role(),
+ message.getClass().getSimpleName()));
+ }
+
+ @Mapping(expression = "java(request.model())", target = "model")
+ @Mapping(source = "response", target = "choices", qualifiedByName = "mapToChoices")
+ @Mapping(source = "response", target = "usage", qualifiedByName = "mapToUsage")
+ ChatCompletionResponse toChatCompletionResponse(
+            @NonNull ChatCompletionRequest request, @NonNull Response<AiMessage> response);
+
+ @Named("mapToChoices")
+    default List<ChatCompletionChoice> mapToChoices(@NonNull Response<AiMessage> response) {
+ return List.of(ChatCompletionChoice.builder()
+ .message(AssistantMessage.builder().content(response.content().text()).build())
+ .build());
+ }
+
+ @Named("mapToUsage")
+    default Usage mapToUsage(@NonNull Response<AiMessage> response) {
+ return Usage.builder()
+ .promptTokens(response.tokenUsage().inputTokenCount())
+ .completionTokens(response.tokenUsage().outputTokenCount())
+ .totalTokens(response.tokenUsage().totalTokenCount())
+ .build();
+ }
+
+ @Mapping(expression = "java(request.model())", target = "modelName")
+ @Mapping(expression = "java(request.maxCompletionTokens())", target = "maxOutputTokens")
+ @Mapping(expression = "java(request.stop())", target = "stopSequences")
+ @Mapping(expression = "java(request.temperature())", target = "temperature")
+ @Mapping(expression = "java(request.topP())", target = "topP")
+ GoogleAiGeminiChatModel toGeminiChatModel(
+ @NonNull String apiKey, @NonNull ChatCompletionRequest request, @NonNull Duration timeout, int maxRetries);
+
+ @Mapping(expression = "java(request.model())", target = "modelName")
+ @Mapping(expression = "java(request.maxCompletionTokens())", target = "maxOutputTokens")
+ @Mapping(expression = "java(request.stop())", target = "stopSequences")
+ @Mapping(expression = "java(request.temperature())", target = "temperature")
+ @Mapping(expression = "java(request.topP())", target = "topP")
+ GoogleAiGeminiStreamingChatModel toGeminiStreamingChatModel(
+ @NonNull String apiKey, @NonNull ChatCompletionRequest request, @NonNull Duration timeout, int maxRetries);
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
index f8b85c4103..60c902d97a 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
@@ -1,29 +1,21 @@
package com.comet.opik.domain.llmproviders;
-import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
-import jakarta.inject.Inject;
import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.lang3.StringUtils;
import java.util.Optional;
import java.util.function.Consumer;
+@RequiredArgsConstructor
@Slf4j
class LlmProviderOpenAi implements LlmProviderService {
- private final LlmProviderClientConfig llmProviderClientConfig;
- private final OpenAiClient openAiClient;
-
- @Inject
- public LlmProviderOpenAi(LlmProviderClientConfig llmProviderClientConfig, String apiKey) {
- this.llmProviderClientConfig = llmProviderClientConfig;
- this.openAiClient = newOpenAiClient(apiKey);
- }
+ private final @NonNull OpenAiClient openAiClient;
@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
@@ -50,41 +42,11 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
}
@Override
- public @NonNull Optional getLlmProviderError(Throwable runtimeException) {
+ public Optional getLlmProviderError(@NonNull Throwable runtimeException) {
if (runtimeException instanceof OpenAiHttpException openAiHttpException) {
return Optional.of(new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage()));
}
return Optional.empty();
}
-
- /**
- * At the moment, openai4j client and also langchain4j wrappers, don't support dynamic API keys. That can imply
- * an important performance penalty for next phases. The following options should be evaluated:
- * - Cache clients, but can be unsafe.
- * - Find and evaluate other clients.
- * - Implement our own client.
- * TODO as part of : OPIK-522
- */
- private OpenAiClient newOpenAiClient(String apiKey) {
- var openAiClientBuilder = OpenAiClient.builder();
- Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
- .map(LlmProviderClientConfig.OpenAiClientConfig::url)
- .ifPresent(baseUrl -> {
- if (StringUtils.isNotBlank(baseUrl)) {
- openAiClientBuilder.baseUrl(baseUrl);
- }
- });
- Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
- .ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
- Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
- .ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
- Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
- .ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
- Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
- .ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
- return openAiClientBuilder
- .openAiApiKey(apiKey)
- .build();
- }
}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
index 851acf2034..83d9140f1d 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
@@ -22,5 +22,5 @@ void generateStream(
void validateRequest(@NonNull ChatCompletionRequest request);
- @NonNull Optional getLlmProviderError(Throwable runtimeException);
+ Optional getLlmProviderError(@NonNull Throwable runtimeException);
}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java
index e69e5d0964..f7d4d28f60 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java
@@ -11,6 +11,7 @@
import com.comet.opik.api.resources.utils.WireMockUtils;
import com.comet.opik.api.resources.utils.resources.ChatCompletionsClient;
import com.comet.opik.api.resources.utils.resources.LlmProviderApiKeyResourceClient;
+import com.comet.opik.domain.llmproviders.GeminiModelName;
import com.comet.opik.podam.PodamFactoryUtils;
import com.redis.testcontainers.RedisContainer;
import dev.ai4j.openai4j.chat.ChatCompletionModel;
@@ -223,7 +224,9 @@ private static Stream testModelsProvider() {
arguments(ChatCompletionModel.GPT_4O_MINI.toString(), LlmProvider.OPEN_AI,
UUID.randomUUID().toString()),
arguments(AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620.toString(), LlmProvider.ANTHROPIC,
- System.getenv("ANTHROPIC_API_KEY")));
+ System.getenv("ANTHROPIC_API_KEY")),
+ arguments(GeminiModelName.GEMINI_1_0_PRO.toString(), LlmProvider.GEMINI,
+ System.getenv("GEMINI_AI_KEY")));
}
@ParameterizedTest
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java
new file mode 100644
index 0000000000..37f073194b
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java
@@ -0,0 +1,127 @@
+package com.comet.opik.domain.llmproviders;
+
+import com.comet.opik.podam.PodamFactoryUtils;
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.output.FinishReason;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import uk.co.jemos.podam.api.PodamFactory;
+
+import java.util.List;
+import java.util.stream.Stream;
+
+import static dev.langchain4j.data.message.AiMessage.aiMessage;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+/**
+ * Unit tests for the provider-specific MapStruct mappers (Anthropic and Gemini),
+ * exercising request/response conversion against randomized Podam fixtures.
+ * NOTE(review): relies on the custom Anthropic/ChatCompletionRequest manufacturers registered in
+ * PodamFactoryUtils — confirm they stay registered if the factory setup changes.
+ */
+public class LlmProviderClientsMappersTest {
+ private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory();
+
+ // Mappers between openai4j chat-completion types and the Anthropic client API types.
+ @Nested
+ @TestInstance(TestInstance.Lifecycle.PER_CLASS)
+ class AnthropicMappers {
+ @Test
+ void testToResponse() {
+ var response = podamFactory.manufacturePojo(AnthropicCreateMessageResponse.class);
+
+ var actual = LlmProviderAnthropicMapper.INSTANCE.toResponse(response);
+ assertThat(actual).isNotNull();
+ assertThat(actual.id()).isEqualTo(response.id);
+ // First (and only, per the manufacturer) content entry becomes the assistant message.
+ assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder()
+ .message(AssistantMessage.builder()
+ .name(response.content.getFirst().name)
+ .content(response.content.getFirst().text)
+ .build())
+ .finishReason(response.stopReason)
+ .build()));
+ // Anthropic reports input/output separately; total is derived by summation.
+ assertThat(actual.usage()).isEqualTo(Usage.builder()
+ .promptTokens(response.usage.inputTokens)
+ .completionTokens(response.usage.outputTokens)
+ .totalTokens(response.usage.inputTokens + response.usage.outputTokens)
+ .build());
+ }
+
+ @Test
+ void toCreateMessage() {
+ var request = podamFactory.manufacturePojo(ChatCompletionRequest.class);
+
+ AnthropicCreateMessageRequest actual = LlmProviderAnthropicMapper.INSTANCE
+ .toCreateMessageRequest(request);
+
+ assertThat(actual).isNotNull();
+ assertThat(actual.model).isEqualTo(request.model());
+ assertThat(actual.stream).isEqualTo(request.stream());
+ assertThat(actual.temperature).isEqualTo(request.temperature());
+ assertThat(actual.topP).isEqualTo(request.topP());
+ assertThat(actual.stopSequences).isEqualTo(request.stop());
+ // USER/ASSISTANT messages go into the messages list; SYSTEM messages are split out
+ // into the dedicated system field, mirroring Anthropic's API shape.
+ assertThat(actual.messages).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo(
+ request.messages().stream()
+ .filter(message -> List.of(Role.USER, Role.ASSISTANT).contains(message.role()))
+ .map(LlmProviderAnthropicMapper.INSTANCE::mapToAnthropicMessage)
+ .toList());
+ assertThat(actual.system).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo(
+ request.messages().stream()
+ .filter(message -> message.role() == Role.SYSTEM)
+ .map(LlmProviderAnthropicMapper.INSTANCE::mapToSystemMessage)
+ .toList());
+ }
+ }
+
+ // Mappers between openai4j chat-completion types and langchain4j Gemini types.
+ @Nested
+ @TestInstance(TestInstance.Lifecycle.PER_CLASS)
+ class GeminiMappers {
+ @Test
+ void testToResponse() {
+ var request = ChatCompletionRequest.builder().model(podamFactory.manufacturePojo(String.class)).build();
+ var response = new Response<>(aiMessage(podamFactory.manufacturePojo(String.class)),
+ new TokenUsage(podamFactory.manufacturePojo(Integer.class),
+ podamFactory.manufacturePojo(Integer.class)),
+ FinishReason.STOP);
+ var actual = LlmProviderGeminiMapper.INSTANCE.toChatCompletionResponse(request, response);
+ assertThat(actual).isNotNull();
+ // Gemini responses don't echo the model, so the mapper copies it from the request.
+ assertThat(actual.model()).isEqualTo(request.model());
+ assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder()
+ .message(AssistantMessage.builder().content(response.content().text()).build())
+ .build()));
+ assertThat(actual.usage()).isEqualTo(Usage.builder()
+ .promptTokens(response.tokenUsage().inputTokenCount())
+ .completionTokens(response.tokenUsage().outputTokenCount())
+ .totalTokens(response.tokenUsage().totalTokenCount())
+ .build());
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void testToChatMessage(Message message, ChatMessage expected) {
+ ChatMessage actual = LlmProviderGeminiMapper.INSTANCE.toChatMessage(message);
+ assertThat(actual).isEqualTo(expected);
+ }
+
+ // Implicit @MethodSource: resolved by matching the test method's name. Being non-static,
+ // it requires the PER_CLASS lifecycle declared on this nested class.
+ private Stream testToChatMessage() {
+ var content = podamFactory.manufacturePojo(String.class);
+ return Stream.of(
+ arguments(AssistantMessage.builder().content(content).build(), AiMessage.from(content)),
+ arguments(dev.ai4j.openai4j.chat.UserMessage.builder().content(content).build(),
+ UserMessage.from(content)),
+ arguments(dev.ai4j.openai4j.chat.SystemMessage.builder().content(content).build(),
+ SystemMessage.from(content)));
+ }
+ }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
index 9250183902..c42e4489d0 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
@@ -25,6 +25,7 @@
import java.io.IOException;
import java.util.List;
import java.util.UUID;
+import java.util.function.Function;
import java.util.stream.Stream;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@@ -68,7 +69,8 @@ void testGetService(String model, LlmProvider llmProvider, Class extends LlmPr
.build());
// SUT
- var llmProviderFactory = new LlmProviderFactory(llmProviderClientConfig, llmProviderApiKeyService);
+ var llmProviderFactory = new LlmProviderFactory(llmProviderApiKeyService,
+ new LlmProviderClientGenerator(llmProviderClientConfig));
LlmProviderService actual = llmProviderFactory.getService(workspaceId, model);
@@ -81,7 +83,9 @@ private static Stream testGetService() {
.map(model -> arguments(model.toString(), LlmProvider.OPEN_AI, LlmProviderOpenAi.class));
var anthropicModels = EnumUtils.getEnumList(AnthropicChatModelName.class).stream()
.map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC, LlmProviderAnthropic.class));
+ var geminiModels = EnumUtils.getEnumList(GeminiModelName.class).stream()
+ .map(model -> arguments(model.toString(), LlmProvider.GEMINI, LlmProviderGemini.class));
- return Stream.concat(openAiModels, anthropicModels);
+ return Stream.of(openAiModels, anthropicModels, geminiModels).flatMap(Function.identity());
}
}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java
index b86d77d830..5e960fc186 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java
@@ -13,7 +13,15 @@
import com.comet.opik.podam.manufacturer.ProviderApiKeyManufacturer;
import com.comet.opik.podam.manufacturer.ProviderApiKeyUpdateManufacturer;
import com.comet.opik.podam.manufacturer.UUIDTypeManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.AnthropicContentManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.AnthropicCreateMessageResponseManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.AnthropicUsageManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.ChatCompletionRequestManufacturer;
import com.fasterxml.jackson.databind.JsonNode;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import jakarta.validation.constraints.DecimalMax;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.Pattern;
@@ -50,6 +58,12 @@ public static PodamFactory newPodamFactory() {
strategy.addOrReplaceTypeManufacturer(PromptVersion.class, PromptVersionManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(ProviderApiKey.class, ProviderApiKeyManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(ProviderApiKeyUpdate.class, ProviderApiKeyUpdateManufacturer.INSTANCE);
+ strategy.addOrReplaceTypeManufacturer(AnthropicContent.class, AnthropicContentManufacturer.INSTANCE);
+ strategy.addOrReplaceTypeManufacturer(AnthropicUsage.class, AnthropicUsageManufacturer.INSTANCE);
+ strategy.addOrReplaceTypeManufacturer(AnthropicCreateMessageResponse.class,
+ AnthropicCreateMessageResponseManufacturer.INSTANCE);
+ strategy.addOrReplaceTypeManufacturer(ChatCompletionRequest.class, ChatCompletionRequestManufacturer.INSTANCE);
+
return podamFactory;
}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicContentManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicContentManufacturer.java
new file mode 100644
index 0000000000..ae26ad926b
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicContentManufacturer.java
@@ -0,0 +1,22 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+/**
+ * Podam type manufacturer that produces AnthropicContent fixtures whose
+ * name, text and id fields are filled with randomly generated strings.
+ */
+public class AnthropicContentManufacturer extends AbstractTypeManufacturer {
+ public static final AnthropicContentManufacturer INSTANCE = new AnthropicContentManufacturer();
+
+ @Override
+ public AnthropicContent getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+ ManufacturingContext context) {
+ var result = new AnthropicContent();
+ result.name = randomString(strategy, metadata, context);
+ result.text = randomString(strategy, metadata, context);
+ result.id = randomString(strategy, metadata, context);
+ return result;
+ }
+
+ // Small helper to keep the field assignments above readable.
+ private static String randomString(DataProviderStrategy strategy, AttributeMetadata metadata,
+ ManufacturingContext context) {
+ return strategy.getTypeValue(metadata, context, String.class);
+ }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicCreateMessageResponseManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicCreateMessageResponseManufacturer.java
new file mode 100644
index 0000000000..fe7b6c7fb6
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicCreateMessageResponseManufacturer.java
@@ -0,0 +1,30 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+import java.util.List;
+
+/**
+ * Podam type manufacturer for AnthropicCreateMessageResponse fixtures: random id, model and
+ * stop reason, exactly one randomized content entry, and randomized usage figures.
+ */
+public class AnthropicCreateMessageResponseManufacturer
+ extends
+ AbstractTypeManufacturer {
+ public static final AnthropicCreateMessageResponseManufacturer INSTANCE = new AnthropicCreateMessageResponseManufacturer();
+
+ @Override
+ public AnthropicCreateMessageResponse getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+ ManufacturingContext context) {
+ var result = new AnthropicCreateMessageResponse();
+ result.id = strategy.getTypeValue(metadata, context, String.class);
+ result.model = strategy.getTypeValue(metadata, context, String.class);
+ result.stopReason = strategy.getTypeValue(metadata, context, String.class);
+ // Single-element list: downstream assertions read content.getFirst().
+ result.content = List.of(strategy.getTypeValue(metadata, context, AnthropicContent.class));
+ result.usage = strategy.getTypeValue(metadata, context, AnthropicUsage.class);
+ return result;
+ }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicUsageManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicUsageManufacturer.java
new file mode 100644
index 0000000000..dfdb1e7f8d
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicUsageManufacturer.java
@@ -0,0 +1,21 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+/**
+ * Podam type manufacturer producing AnthropicUsage fixtures with randomly
+ * generated input and output token counts.
+ */
+public class AnthropicUsageManufacturer extends AbstractTypeManufacturer {
+ public static final AnthropicUsageManufacturer INSTANCE = new AnthropicUsageManufacturer();
+
+ @Override
+ public AnthropicUsage getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+ ManufacturingContext context) {
+ var anthropicUsage = new AnthropicUsage();
+ anthropicUsage.inputTokens = strategy.getTypeValue(metadata, context, Integer.class);
+ anthropicUsage.outputTokens = strategy.getTypeValue(metadata, context, Integer.class);
+ return anthropicUsage;
+ }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/ChatCompletionRequestManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/ChatCompletionRequestManufacturer.java
new file mode 100644
index 0000000000..8e3278bd01
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/ChatCompletionRequestManufacturer.java
@@ -0,0 +1,31 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+/**
+ * Podam type manufacturer building a randomized ChatCompletionRequest that always carries
+ * exactly one user, one assistant and one system message alongside randomized scalar settings.
+ */
+public class ChatCompletionRequestManufacturer extends AbstractTypeManufacturer {
+ public static final ChatCompletionRequestManufacturer INSTANCE = new ChatCompletionRequestManufacturer();
+
+ @Override
+ public ChatCompletionRequest getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+ ManufacturingContext context) {
+ var userContent = strategy.getTypeValue(metadata, context, String.class);
+ var assistantContent = strategy.getTypeValue(metadata, context, String.class);
+ var systemContent = strategy.getTypeValue(metadata, context, String.class);
+
+ // Statement style instead of one long chain; call order matches the original.
+ var builder = ChatCompletionRequest.builder();
+ builder.model(strategy.getTypeValue(metadata, context, String.class));
+ builder.stream(strategy.getTypeValue(metadata, context, Boolean.class));
+ builder.temperature(strategy.getTypeValue(metadata, context, Double.class));
+ builder.topP(strategy.getTypeValue(metadata, context, Double.class));
+ builder.stop(strategy.getTypeValue(metadata, context, String.class));
+ builder.addUserMessage(userContent);
+ builder.addAssistantMessage(assistantContent);
+ builder.addSystemMessage(systemContent);
+ builder.maxCompletionTokens(strategy.getTypeValue(metadata, context, Integer.class));
+ return builder.build();
+ }
+}