diff --git a/apps/opik-backend/pom.xml b/apps/opik-backend/pom.xml
index b03941e774..beeb06c10c 100644
--- a/apps/opik-backend/pom.xml
+++ b/apps/opik-backend/pom.xml
@@ -221,6 +221,10 @@
         <dependency>
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-anthropic</artifactId>
         </dependency>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-google-ai-gemini</artifactId>
+        </dependency>
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
index 3c025b9d02..327c4c318c 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
@@ -1,6 +1,7 @@
 package com.comet.opik;
 
 import com.comet.opik.api.error.JsonInvalidFormatExceptionMapper;
+import com.comet.opik.domain.llmproviders.LlmProviderClientModule;
 import com.comet.opik.infrastructure.ConfigurationModule;
 import com.comet.opik.infrastructure.EncryptionUtils;
 import com.comet.opik.infrastructure.OpikConfiguration;
@@ -72,7 +73,7 @@ public void initialize(Bootstrap<OpikConfiguration> bootstrap) {
                         .withPlugins(new SqlObjectPlugin(), new Jackson2Plugin()))
                 .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(),
                         new RateLimitModule(), new NameGeneratorModule(), new HttpModule(), new EventModule(),
-                        new ConfigurationModule(), new BiModule())
+                        new ConfigurationModule(), new BiModule(), new LlmProviderClientModule())
                 .installers(JobGuiceyInstaller.class)
                 .listen(new OpikGuiceyLifecycleEventListener())
                 .enableAutoConfig()
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
index 36de62b825..36642febb5 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
@@ -11,7 +11,9 @@
 @RequiredArgsConstructor
 public enum LlmProvider {
     OPEN_AI("openai"),
-    ANTHROPIC("anthropic");
+    ANTHROPIC("anthropic"),
+    GEMINI("gemini"),
+    ;
 
     @JsonValue
     private final String value;
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java
new file mode 100644
index 0000000000..d5f89a59a7
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java
@@ -0,0 +1,60 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Delta;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.output.Response;
+import lombok.NonNull;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+public record ChunkedResponseHandler(
+        @NonNull Consumer<ChatCompletionResponse> handleMessage,
+        @NonNull Runnable handleClose,
+        @NonNull Consumer<Throwable> handleError,
+        @NonNull String model) implements StreamingResponseHandler<AiMessage> {
+
+    @Override
+    public void onNext(@NonNull String content) {
+        handleMessage.accept(ChatCompletionResponse.builder()
+                .model(model)
+                .choices(List.of(ChatCompletionChoice.builder()
+                        .delta(Delta.builder()
+                                .content(content)
+                                .role(Role.ASSISTANT)
+                                .build())
+                        .build()))
+                .build());
+    }
+
+    @Override
+    public void onComplete(@NonNull Response<AiMessage> response) {
+        handleMessage.accept(ChatCompletionResponse.builder()
+                .model(model)
+                .choices(List.of(ChatCompletionChoice.builder()
+                        .delta(Delta.builder()
+                                .content("")
+                                .role(Role.ASSISTANT)
+                                .build())
+                        .build()))
+                .usage(Usage.builder()
+                        .promptTokens(response.tokenUsage().inputTokenCount())
+                        .completionTokens(response.tokenUsage().outputTokenCount())
+                        .totalTokens(response.tokenUsage().totalTokenCount())
+                        .build())
+                .id(Optional.ofNullable(response.metadata().get("id")).map(Object::toString).orElse(null))
+                .build());
+        handleClose.run();
+    }
+
+    @Override
+    public void onError(@NonNull Throwable throwable) {
+        handleError.accept(throwable);
+    }
+}
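The handler above translates langchain4j streaming callbacks (`onNext`, `onComplete`) into OpenAI-shaped `ChatCompletionResponse` chunks, which keeps the SSE wire format identical across providers. A minimal sketch of that behavior, driving the callbacks by hand the way the PR's own mapper test constructs a `Response` (the model name and token counts are placeholders):

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

class ChunkedResponseHandlerDemo {
    public static void main(String[] args) {
        // Each onNext becomes one OpenAI-style chunk with a Delta payload.
        var handler = new ChunkedResponseHandler(
                chunk -> System.out.println("chunk: " + chunk.choices().getFirst().delta().content()),
                () -> System.out.println("stream closed"),
                error -> System.err.println("stream failed: " + error.getMessage()),
                "gemini-1.5-flash");

        // Simulates what a langchain4j streaming model would invoke.
        handler.onNext("Hello");
        handler.onNext(", world");
        // onComplete emits a final empty-content chunk carrying the usage totals.
        handler.onComplete(new Response<>(AiMessage.from("Hello, world"),
                new TokenUsage(12, 3), FinishReason.STOP));
    }
}
```

Because the record lives in the package (rather than nested inside the Anthropic provider, as before), the same instance type can now serve both the Anthropic and Gemini streaming paths.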
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java
new file mode 100644
index 0000000000..25ae957f23
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java
@@ -0,0 +1,25 @@
+package com.comet.opik.domain.llmproviders;
+
+import lombok.RequiredArgsConstructor;
+
+/*
+Langchain4j doesn't provide an enum of Gemini model names.
+The values below are taken from: https://ai.google.dev/gemini-api/docs/models/gemini
+ */
+@RequiredArgsConstructor
+public enum GeminiModelName {
+    GEMINI_2_0_FLASH("gemini-2.0-flash-exp"),
+    GEMINI_1_5_FLASH("gemini-1.5-flash"),
+    GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b"),
+    GEMINI_1_5_PRO("gemini-1.5-pro"),
+    GEMINI_1_0_PRO("gemini-1.0-pro"),
+    TEXT_EMBEDDING("text-embedding-004"),
+    AQA("aqa");
+
+    private final String value;
+
+    @Override
+    public String toString() {
+        return value;
+    }
+}
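The enum's `toString()` returns the wire-level model name, which is what `LlmProviderFactory` (further down) compares against incoming request models. A small sketch of that matching logic, mirroring the `EnumUtils`-based check the factory tests use:

```java
import org.apache.commons.lang3.EnumUtils;

class GeminiModelLookupDemo {
    public static void main(String[] args) {
        // "gemini-1.5-pro" is the wire name carried in ChatCompletionRequest.model().
        String requested = "gemini-1.5-pro";
        boolean isGemini = EnumUtils.getEnumList(GeminiModelName.class).stream()
                .anyMatch(model -> model.toString().equals(requested));
        System.out.println(requested + " handled by Gemini provider: " + isGemini); // true
    }
}
```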
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
index 6781b7fb9a..511cd250a9 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
@@ -1,68 +1,33 @@
 package com.comet.opik.domain.llmproviders;
 
-import com.comet.opik.infrastructure.LlmProviderClientConfig;
-import dev.ai4j.openai4j.chat.AssistantMessage;
-import dev.ai4j.openai4j.chat.ChatCompletionChoice;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
-import dev.ai4j.openai4j.chat.Delta;
-import dev.ai4j.openai4j.chat.Message;
-import dev.ai4j.openai4j.chat.Role;
-import dev.ai4j.openai4j.chat.SystemMessage;
-import dev.ai4j.openai4j.chat.UserMessage;
-import dev.ai4j.openai4j.shared.Usage;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.StreamingResponseHandler;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
-import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
 import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
 import dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException;
-import dev.langchain4j.model.output.Response;
 import io.dropwizard.jersey.errors.ErrorMessage;
 import jakarta.ws.rs.BadRequestException;
 import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
-import org.apache.commons.lang3.StringUtils;
 
-import java.util.List;
 import java.util.Optional;
 import java.util.function.Consumer;
 
 import static com.comet.opik.domain.ChatCompletionService.ERROR_EMPTY_MESSAGES;
 import static com.comet.opik.domain.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS;
 
+@RequiredArgsConstructor
 @Slf4j
 class LlmProviderAnthropic implements LlmProviderService {
-    private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
     private final @NonNull AnthropicClient anthropicClient;
 
-    public LlmProviderAnthropic(@NonNull LlmProviderClientConfig llmProviderClientConfig, @NonNull String apiKey) {
-        this.llmProviderClientConfig = llmProviderClientConfig;
-        this.anthropicClient = newClient(apiKey);
-    }
-
     @Override
     public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
-        var response = anthropicClient.createMessage(toAnthropicCreateMessageRequest(request));
+        var mapper = LlmProviderAnthropicMapper.INSTANCE;
+        var response = anthropicClient.createMessage(mapper.toCreateMessageRequest(request));
 
-        return ChatCompletionResponse.builder()
-                .id(response.id)
-                .model(response.model)
-                .choices(response.content.stream().map(content -> toChatCompletionChoice(response, content))
-                        .toList())
-                .usage(Usage.builder()
-                        .promptTokens(response.usage.inputTokens)
-                        .completionTokens(response.usage.outputTokens)
-                        .totalTokens(response.usage.inputTokens + response.usage.outputTokens)
-                        .build())
-                .build();
+        return mapper.toResponse(response);
     }
 
@@ -72,7 +37,7 @@ public void generateStream(
             @NonNull Consumer<ChatCompletionResponse> handleMessage,
             @NonNull Runnable handleClose,
             @NonNull Consumer<Throwable> handleError) {
         validateRequest(request);
-        anthropicClient.createMessage(toAnthropicCreateMessageRequest(request),
+        anthropicClient.createMessage(LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request),
                 new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
     }
 
@@ -88,7 +53,7 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
     }
 
     @Override
-    public @NonNull Optional<ErrorMessage> getLlmProviderError(Throwable runtimeException) {
+    public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
         if (runtimeException instanceof AnthropicHttpException anthropicHttpException) {
             return Optional.of(new ErrorMessage(anthropicHttpException.statusCode(),
                     anthropicHttpException.getMessage()));
@@ -96,143 +61,4 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
 
         return Optional.empty();
     }
-
-    private AnthropicCreateMessageRequest toAnthropicCreateMessageRequest(ChatCompletionRequest request) {
-        var builder = AnthropicCreateMessageRequest.builder();
-        Optional.ofNullable(request.toolChoice())
-                .ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(
-                        request.toolChoice().toString())));
-        return builder
-                .stream(request.stream())
-                .model(request.model())
-                .messages(request.messages().stream()
-                        .filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
-                        .map(this::toMessage).toList())
-                .system(request.messages().stream()
-                        .filter(message -> message.role() == Role.SYSTEM)
-                        .map(this::toSystemMessage).toList())
-                .temperature(request.temperature())
-                .topP(request.topP())
-                .stopSequences(request.stop())
-                .maxTokens(request.maxCompletionTokens())
-                .build();
-    }
-
-    private AnthropicMessage toMessage(Message message) {
-        if (message.role() == Role.ASSISTANT) {
-            return AnthropicMessage.builder()
-                    .role(AnthropicRole.ASSISTANT)
-                    .content(List.of(new AnthropicTextContent(((AssistantMessage) message).content())))
-                    .build();
-        }
-
-        if (message.role() == Role.USER) {
-            return AnthropicMessage.builder()
-                    .role(AnthropicRole.USER)
-                    .content(List.of(toAnthropicMessageContent(((UserMessage) message).content())))
-                    .build();
-        }
-
-        throw new BadRequestException("unexpected message role: " + message.role());
-    }
-
-    private AnthropicTextContent toSystemMessage(Message message) {
-        if (message.role() != Role.SYSTEM) {
-            throw new BadRequestException("expecting only system role, got: " + message.role());
-        }
-
-        return new AnthropicTextContent(((SystemMessage) message).content());
-    }
-
-    private AnthropicMessageContent toAnthropicMessageContent(Object rawContent) {
-        if (rawContent instanceof String content) {
-            return new AnthropicTextContent(content);
-        }
-
-        throw new BadRequestException("only text content is supported");
-    }
-
-    private ChatCompletionChoice toChatCompletionChoice(
-            AnthropicCreateMessageResponse response, AnthropicContent content) {
-        return ChatCompletionChoice.builder()
-                .message(AssistantMessage.builder()
-                        .name(content.name)
-                        .content(content.text)
-                        .build())
-                .finishReason(response.stopReason)
-                .build();
-    }
-
-    private AnthropicClient newClient(String apiKey) {
-        var anthropicClientBuilder = AnthropicClient.builder();
-        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
-                .map(LlmProviderClientConfig.AnthropicClientConfig::url)
-                .ifPresent(url -> {
-                    if (StringUtils.isNotEmpty(url)) {
-                        anthropicClientBuilder.baseUrl(url);
-                    }
-                });
-        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
-                .map(LlmProviderClientConfig.AnthropicClientConfig::version)
-                .ifPresent(version -> {
-                    if (StringUtils.isNotBlank(version)) {
-                        anthropicClientBuilder.version(version);
-                    }
-                });
-        Optional.ofNullable(llmProviderClientConfig.getLogRequests())
-                .ifPresent(anthropicClientBuilder::logRequests);
-        Optional.ofNullable(llmProviderClientConfig.getLogResponses())
-                .ifPresent(anthropicClientBuilder::logResponses);
-        // anthropic client builder only receives one timeout variant
-        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
-                .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
-        return anthropicClientBuilder
-                .apiKey(apiKey)
-                .build();
-    }
-
-    private record ChunkedResponseHandler(
-            Consumer<ChatCompletionResponse> handleMessage,
-            Runnable handleClose,
-            Consumer<Throwable> handleError,
-            String model) implements StreamingResponseHandler<AiMessage> {
-
-        @Override
-        public void onNext(String s) {
-            handleMessage.accept(ChatCompletionResponse.builder()
-                    .model(model)
-                    .choices(List.of(ChatCompletionChoice.builder()
-                            .delta(Delta.builder()
-                                    .content(s)
-                                    .role(Role.ASSISTANT)
-                                    .build())
-                            .build()))
-                    .build());
-        }
-
-        @Override
-        public void onComplete(Response<AiMessage> response) {
-            handleMessage.accept(ChatCompletionResponse.builder()
-                    .model(model)
-                    .choices(List.of(ChatCompletionChoice.builder()
-                            .delta(Delta.builder()
-                                    .content("")
-                                    .role(Role.ASSISTANT)
-                                    .build())
-                            .build()))
-                    .usage(Usage.builder()
-                            .promptTokens(response.tokenUsage().inputTokenCount())
-                            .completionTokens(response.tokenUsage().outputTokenCount())
-                            .totalTokens(response.tokenUsage().totalTokenCount())
-                            .build())
-                    .id((String) response.metadata().get("id"))
-                    .build());
-            handleClose.run();
-        }
-
-        @Override
-        public void onError(Throwable throwable) {
-            handleError.accept(throwable);
-        }
-    }
 }
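With the request/response translation moved into the MapStruct mapper, the provider shrinks to transport plus error mapping. A brief sketch of how callers can consume the `Optional<ErrorMessage>` contract (the `handle` helper is illustrative, not part of the PR):

```java
import io.dropwizard.jersey.errors.ErrorMessage;

import java.util.Optional;

class AnthropicErrorMappingDemo {
    static void handle(LlmProviderService service, Throwable failure) {
        // Provider-specific HTTP failures keep their upstream status code;
        // anything else falls through to the caller as an empty Optional.
        Optional<ErrorMessage> mapped = service.getLlmProviderError(failure);
        mapped.ifPresentOrElse(
                error -> System.out.printf("upstream %d: %s%n", error.getCode(), error.getMessage()),
                () -> System.out.println("not a provider error: " + failure));
    }
}
```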
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java
new file mode 100644
index 0000000000..dcce42e8a0
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java
@@ -0,0 +1,118 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.chat.SystemMessage;
+import dev.ai4j.openai4j.chat.UserMessage;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
+import jakarta.ws.rs.BadRequestException;
+import lombok.NonNull;
+import org.mapstruct.Mapper;
+import org.mapstruct.Mapping;
+import org.mapstruct.Named;
+import org.mapstruct.factory.Mappers;
+
+import java.util.List;
+
+@Mapper
+public interface LlmProviderAnthropicMapper {
+    LlmProviderAnthropicMapper INSTANCE = Mappers.getMapper(LlmProviderAnthropicMapper.class);
+
+    @Mapping(source = "response", target = "choices", qualifiedByName = "mapToChoices")
+    @Mapping(source = "usage", target = "usage", qualifiedByName = "mapToUsage")
+    ChatCompletionResponse toResponse(@NonNull AnthropicCreateMessageResponse response);
+
+    @Mapping(source = "content", target = "message")
+    @Mapping(source = "response.stopReason", target = "finishReason")
+    ChatCompletionChoice toChoice(@NonNull AnthropicContent content, @NonNull AnthropicCreateMessageResponse response);
+
+    @Mapping(source = "text", target = "content")
+    AssistantMessage toAssistantMessage(@NonNull AnthropicContent content);
+
+    @Mapping(expression = "java(request.model())", target = "model")
+    @Mapping(expression = "java(request.stream())", target = "stream")
+    @Mapping(expression = "java(request.temperature())", target = "temperature")
+    @Mapping(expression = "java(request.topP())", target = "topP")
+    @Mapping(expression = "java(request.stop())", target = "stopSequences")
+    @Mapping(expression = "java(request.maxCompletionTokens())", target = "maxTokens")
+    @Mapping(source = "request", target = "messages", qualifiedByName = "mapToMessages")
+    @Mapping(source = "request", target = "system", qualifiedByName = "mapToSystemMessages")
+    AnthropicCreateMessageRequest toCreateMessageRequest(@NonNull ChatCompletionRequest request);
+
+    @Named("mapToChoices")
+    default List<ChatCompletionChoice> mapToChoices(@NonNull AnthropicCreateMessageResponse response) {
+        if (response.content == null || response.content.isEmpty()) {
+            return List.of();
+        }
+
+        return response.content.stream().map(content -> toChoice(content, response)).toList();
+    }
+
+    @Named("mapToUsage")
+    default Usage mapToUsage(AnthropicUsage usage) {
+        if (usage == null) {
+            return null;
+        }
+
+        return Usage.builder()
+                .promptTokens(usage.inputTokens)
+                .completionTokens(usage.outputTokens)
+                .totalTokens(usage.inputTokens + usage.outputTokens)
+                .build();
+    }
+
+    @Named("mapToMessages")
+    default List<AnthropicMessage> mapToMessages(@NonNull ChatCompletionRequest request) {
+        return request.messages().stream()
+                .filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
+                .map(this::mapToAnthropicMessage).toList();
+    }
+
+    @Named("mapToSystemMessages")
+    default List<AnthropicTextContent> mapToSystemMessages(@NonNull ChatCompletionRequest request) {
+        return request.messages().stream()
+                .filter(message -> message.role() == Role.SYSTEM)
+                .map(this::mapToSystemMessage).toList();
+    }
+
+    default AnthropicMessage mapToAnthropicMessage(@NonNull Message message) {
+        return switch (message) {
+            case AssistantMessage assistantMessage -> AnthropicMessage.builder()
+                    .role(AnthropicRole.ASSISTANT)
+                    .content(List.of(new AnthropicTextContent(assistantMessage.content())))
+                    .build();
+            case UserMessage userMessage -> AnthropicMessage.builder()
+                    .role(AnthropicRole.USER)
+                    .content(List.of(toAnthropicMessageContent(userMessage.content())))
+                    .build();
+            default -> throw new BadRequestException("unexpected message role: " + message.role());
+        };
+    }
+
+    default AnthropicMessageContent toAnthropicMessageContent(@NonNull Object rawContent) {
+        if (rawContent instanceof String content) {
+            return new AnthropicTextContent(content);
+        }
+
+        throw new BadRequestException("only text content is supported");
+    }
+
+    default AnthropicTextContent mapToSystemMessage(@NonNull Message message) {
+        if (message.role() != Role.SYSTEM) {
+            throw new BadRequestException("expecting only system role, got: " + message.role());
+        }
+
+        return new AnthropicTextContent(((SystemMessage) message).content());
+    }
+}
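One design point worth noting: Anthropic separates the system prompt from the conversation, so the mapper routes SYSTEM messages into the dedicated `system` field and keeps only user/assistant turns in `messages`. A short sketch of the mapping, built with the same `ChatCompletionRequest` builder methods the PR's Podam manufacturer uses (model name and prompts are placeholders):

```java
import dev.ai4j.openai4j.chat.ChatCompletionRequest;

class AnthropicMapperDemo {
    public static void main(String[] args) {
        var request = ChatCompletionRequest.builder()
                .model("claude-3-5-sonnet-20240620")
                .addSystemMessage("You are terse.")
                .addUserMessage("Summarize Opik in one line.")
                .maxCompletionTokens(128)
                .build();

        var anthropicRequest = LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request);
        System.out.println(anthropicRequest.system.size());   // 1: the system prompt
        System.out.println(anthropicRequest.messages.size()); // 1: the user turn
    }
}
```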
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java
new file mode 100644
index 0000000000..145b597271
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java
@@ -0,0 +1,72 @@
+package com.comet.opik.domain.llmproviders;
+
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
+import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.Optional;
+
+@RequiredArgsConstructor
+public class LlmProviderClientGenerator {
+    private static final int MAX_RETRIES = 1;
+
+    private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
+
+    public AnthropicClient newAnthropicClient(@NonNull String apiKey) {
+        var anthropicClientBuilder = AnthropicClient.builder();
+        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+                .map(LlmProviderClientConfig.AnthropicClientConfig::url)
+                .filter(StringUtils::isNotEmpty)
+                .ifPresent(anthropicClientBuilder::baseUrl);
+        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+                .map(LlmProviderClientConfig.AnthropicClientConfig::version)
+                .filter(StringUtils::isNotBlank)
+                .ifPresent(anthropicClientBuilder::version);
+        Optional.ofNullable(llmProviderClientConfig.getLogRequests())
+                .ifPresent(anthropicClientBuilder::logRequests);
+        Optional.ofNullable(llmProviderClientConfig.getLogResponses())
+                .ifPresent(anthropicClientBuilder::logResponses);
+        // anthropic client builder only receives one timeout variant
+        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
+                .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
+        return anthropicClientBuilder
+                .apiKey(apiKey)
+                .build();
+    }
+
+    public OpenAiClient newOpenAiClient(@NonNull String apiKey) {
+        var openAiClientBuilder = OpenAiClient.builder();
+        Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
+                .map(LlmProviderClientConfig.OpenAiClientConfig::url)
+                .filter(StringUtils::isNotBlank)
+                .ifPresent(openAiClientBuilder::baseUrl);
+        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
+                .ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
+        Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
+                .ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
+        Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
+                .ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
+        Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
+                .ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
+        return openAiClientBuilder
+                .openAiApiKey(apiKey)
+                .build();
+    }
+
+    public GoogleAiGeminiChatModel newGeminiClient(@NonNull String apiKey, @NonNull ChatCompletionRequest request) {
+        return LlmProviderGeminiMapper.INSTANCE.toGeminiChatModel(apiKey, request,
+                llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
+    }
+
+    public GoogleAiGeminiStreamingChatModel newGeminiStreamingClient(
+            @NonNull String apiKey, @NonNull ChatCompletionRequest request) {
+        return LlmProviderGeminiMapper.INSTANCE.toGeminiStreamingChatModel(apiKey, request,
+                llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
+    }
+}
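The generator centralizes client construction that previously lived inside each provider, so the per-provider timeout and logging knobs are applied in one place. A minimal usage sketch, assuming a `LlmProviderClientConfig` already deserialized from the `llmProviderClient` block of the Dropwizard YAML (the API keys are placeholders):

```java
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.OpenAiClient;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;

class ClientGeneratorDemo {
    // `config` is assumed to be injected/bound by Dropwizard; see the Guice
    // module below for how the PR wires it.
    static void buildClients(LlmProviderClientConfig config) {
        var generator = new LlmProviderClientGenerator(config);
        AnthropicClient anthropic = generator.newAnthropicClient("sk-ant-placeholder");
        OpenAiClient openAi = generator.newOpenAiClient("sk-placeholder");
    }
}
```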
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java
new file mode 100644
index 0000000000..f6ea5a4cb8
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java
@@ -0,0 +1,19 @@
+package com.comet.opik.domain.llmproviders;
+
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.OpikConfiguration;
+import com.google.inject.Provides;
+import jakarta.inject.Singleton;
+import lombok.NonNull;
+import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule;
+import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
+
+public class LlmProviderClientModule extends DropwizardAwareModule<OpikConfiguration> {
+
+    @Provides
+    @Singleton
+    public LlmProviderClientGenerator clientGenerator(
+            @NonNull @Config("llmProviderClient") LlmProviderClientConfig config) {
+        return new LlmProviderClientGenerator(config);
+    }
+}
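Because the module is registered in `OpikApplication` and the provider method is a `@Singleton`, downstream classes can simply constructor-inject the generator. A sketch of a hypothetical consumer (the class name is illustrative):

```java
import jakarta.inject.Inject;

// With LlmProviderClientModule installed, Guice supplies the singleton
// generator wherever provider clients need to be built.
class NeedsClients {
    private final LlmProviderClientGenerator generator;

    @Inject
    NeedsClients(LlmProviderClientGenerator generator) {
        this.generator = generator;
    }
}
```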
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
index 033d3d6a49..2a92e113ac 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
@@ -3,7 +3,6 @@
 import com.comet.opik.api.LlmProvider;
 import com.comet.opik.domain.LlmProviderApiKeyService;
 import com.comet.opik.infrastructure.EncryptionUtils;
-import com.comet.opik.infrastructure.LlmProviderClientConfig;
 import dev.ai4j.openai4j.chat.ChatCompletionModel;
 import dev.langchain4j.model.anthropic.AnthropicChatModelName;
 import jakarta.inject.Inject;
@@ -12,7 +11,6 @@
 import lombok.NonNull;
 import lombok.RequiredArgsConstructor;
 import org.apache.commons.lang3.EnumUtils;
-import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
 
 import java.util.function.Function;
 
@@ -21,16 +19,18 @@ public class LlmProviderFactory {
     public static final String ERROR_MODEL_NOT_SUPPORTED = "model not supported %s";
 
-    private final @NonNull @Config LlmProviderClientConfig llmProviderClientConfig;
     private final @NonNull LlmProviderApiKeyService llmProviderApiKeyService;
+    private final @NonNull LlmProviderClientGenerator llmProviderClientGenerator;
 
     public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) {
         var llmProvider = getLlmProvider(model);
         var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
 
         return switch (llmProvider) {
-            case LlmProvider.OPEN_AI -> new LlmProviderOpenAi(llmProviderClientConfig, apiKey);
-            case LlmProvider.ANTHROPIC -> new LlmProviderAnthropic(llmProviderClientConfig, apiKey);
+            case LlmProvider.OPEN_AI -> new LlmProviderOpenAi(llmProviderClientGenerator.newOpenAiClient(apiKey));
+            case LlmProvider.ANTHROPIC ->
+                new LlmProviderAnthropic(llmProviderClientGenerator.newAnthropicClient(apiKey));
+            case LlmProvider.GEMINI -> new LlmProviderGemini(llmProviderClientGenerator, apiKey);
         };
     }
 
@@ -44,6 +44,9 @@ private LlmProvider getLlmProvider(String model) {
         if (isModelBelongToProvider(model, AnthropicChatModelName.class, AnthropicChatModelName::toString)) {
             return LlmProvider.ANTHROPIC;
         }
+        if (isModelBelongToProvider(model, GeminiModelName.class, GeminiModelName::toString)) {
+            return LlmProvider.GEMINI;
+        }
 
         throw new BadRequestException(ERROR_MODEL_NOT_SUPPORTED.formatted(model));
     }
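The factory is the single dispatch point: the requested model name is matched against each provider's enum, the workspace's encrypted API key is decrypted, and the matching provider is built. A sketch of the resulting behavior (the model names are examples):

```java
import jakarta.ws.rs.BadRequestException;

class FactoryDispatchDemo {
    static void dispatch(LlmProviderFactory factory, String workspaceId) {
        // "gemini-1.5-flash" matches GeminiModelName, so this returns an
        // LlmProviderGemini backed by the workspace's stored Gemini key.
        LlmProviderService gemini = factory.getService(workspaceId, "gemini-1.5-flash");

        try {
            factory.getService(workspaceId, "not-a-real-model");
        } catch (BadRequestException expected) {
            // Unknown models fail fast with ERROR_MODEL_NOT_SUPPORTED.
        }
    }
}
```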
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java
new file mode 100644
index 0000000000..3ae0c2efa9
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java
@@ -0,0 +1,44 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import io.dropwizard.jersey.errors.ErrorMessage;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+import java.util.Optional;
+import java.util.function.Consumer;
+
+@RequiredArgsConstructor
+public class LlmProviderGemini implements LlmProviderService {
+    private final @NonNull LlmProviderClientGenerator llmProviderClientGenerator;
+    private final @NonNull String apiKey;
+
+    @Override
+    public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
+        var mapper = LlmProviderGeminiMapper.INSTANCE;
+        var response = llmProviderClientGenerator.newGeminiClient(apiKey, request)
+                .generate(request.messages().stream().map(mapper::toChatMessage).toList());
+
+        return mapper.toChatCompletionResponse(request, response);
+    }
+
+    @Override
+    public void generateStream(@NonNull ChatCompletionRequest request, @NonNull String workspaceId,
+            @NonNull Consumer<ChatCompletionResponse> handleMessage, @NonNull Runnable handleClose,
+            @NonNull Consumer<Throwable> handleError) {
+        llmProviderClientGenerator.newGeminiStreamingClient(apiKey, request)
+                .generate(request.messages().stream().map(LlmProviderGeminiMapper.INSTANCE::toChatMessage).toList(),
+                        new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
+    }
+
+    @Override
+    public void validateRequest(@NonNull ChatCompletionRequest request) {
+    }
+
+    @Override
+    public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
+        return Optional.empty();
+    }
+
+}
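Note that the Gemini provider reuses the shared `ChunkedResponseHandler`, so streaming callers see the same OpenAI-shaped chunks regardless of backend. A usage sketch under those assumptions (model and prompt are placeholders):

```java
import dev.ai4j.openai4j.chat.ChatCompletionRequest;

class GeminiStreamingDemo {
    static void stream(LlmProviderService gemini, String workspaceId) {
        var request = ChatCompletionRequest.builder()
                .model("gemini-1.5-flash")
                .addUserMessage("Write a haiku about tracing.")
                .stream(true)
                .build();

        // Deltas arrive per token batch; the final chunk carries usage totals.
        gemini.generateStream(request, workspaceId,
                chunk -> System.out.print(chunk.choices().getFirst().delta().content()),
                () -> System.out.println("\n[done]"),
                error -> System.err.println("failed: " + error.getMessage()));
    }
}
```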
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java
new file mode 100644
index 0000000000..65fc410a1c
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java
@@ -0,0 +1,98 @@
+package com.comet.opik.domain.llmproviders;
+
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
+import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
+import dev.langchain4j.model.output.Response;
+import jakarta.ws.rs.BadRequestException;
+import lombok.NonNull;
+import org.mapstruct.Mapper;
+import org.mapstruct.Mapping;
+import org.mapstruct.Named;
+import org.mapstruct.factory.Mappers;
+
+import java.time.Duration;
+import java.util.List;
+
+@Mapper
+public interface LlmProviderGeminiMapper {
+    String ERR_UNEXPECTED_ROLE = "unexpected role '%s'";
+    String ERR_ROLE_MSG_TYPE_MISMATCH = "role and message instance are not matching, role: '%s', instance: '%s'";
+
+    LlmProviderGeminiMapper INSTANCE = Mappers.getMapper(LlmProviderGeminiMapper.class);
+
+    default ChatMessage toChatMessage(@NonNull Message message) {
+        if (!List.of(Role.ASSISTANT, Role.USER, Role.SYSTEM).contains(message.role())) {
+            throw new BadRequestException(ERR_UNEXPECTED_ROLE.formatted(message.role()));
+        }
+
+        switch (message.role()) {
+            case ASSISTANT -> {
+                if (message instanceof AssistantMessage assistantMessage) {
+                    return AiMessage.from(assistantMessage.content());
+                }
+            }
+            case USER -> {
+                if (message instanceof dev.ai4j.openai4j.chat.UserMessage userMessage) {
+                    return UserMessage.from(userMessage.content().toString());
+                }
+            }
+            case SYSTEM -> {
+                if (message instanceof dev.ai4j.openai4j.chat.SystemMessage systemMessage) {
+                    return SystemMessage.from(systemMessage.content());
+                }
+            }
+        }
+
+        throw new BadRequestException(ERR_ROLE_MSG_TYPE_MISMATCH.formatted(message.role(),
+                message.getClass().getSimpleName()));
+    }
+
+    @Mapping(expression = "java(request.model())", target = "model")
+    @Mapping(source = "response", target = "choices", qualifiedByName = "mapToChoices")
+    @Mapping(source = "response", target = "usage", qualifiedByName = "mapToUsage")
+    ChatCompletionResponse toChatCompletionResponse(
+            @NonNull ChatCompletionRequest request, @NonNull Response<AiMessage> response);
+
+    @Named("mapToChoices")
+    default List<ChatCompletionChoice> mapToChoices(@NonNull Response<AiMessage> response) {
+        return List.of(ChatCompletionChoice.builder()
+                .message(AssistantMessage.builder().content(response.content().text()).build())
+                .build());
+    }
+
+    @Named("mapToUsage")
+    default Usage mapToUsage(@NonNull Response<AiMessage> response) {
+        return Usage.builder()
+                .promptTokens(response.tokenUsage().inputTokenCount())
+                .completionTokens(response.tokenUsage().outputTokenCount())
+                .totalTokens(response.tokenUsage().totalTokenCount())
+                .build();
+    }
+
+    @Mapping(expression = "java(request.model())", target = "modelName")
+    @Mapping(expression = "java(request.maxCompletionTokens())", target = "maxOutputTokens")
+    @Mapping(expression = "java(request.stop())", target = "stopSequences")
+    @Mapping(expression = "java(request.temperature())", target = "temperature")
+    @Mapping(expression = "java(request.topP())", target = "topP")
+    GoogleAiGeminiChatModel toGeminiChatModel(
+            @NonNull String apiKey, @NonNull ChatCompletionRequest request, @NonNull Duration timeout, int maxRetries);
+
+    @Mapping(expression = "java(request.model())", target = "modelName")
+    @Mapping(expression = "java(request.maxCompletionTokens())", target = "maxOutputTokens")
+    @Mapping(expression = "java(request.stop())", target = "stopSequences")
+    @Mapping(expression = "java(request.temperature())", target = "temperature")
+    @Mapping(expression = "java(request.topP())", target = "topP")
+    GoogleAiGeminiStreamingChatModel toGeminiStreamingChatModel(
+            @NonNull String apiKey, @NonNull ChatCompletionRequest request, @NonNull Duration timeout, int maxRetries);
+}
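The mapper converts openai4j message types one-to-one into langchain4j counterparts and fails fast with a 400 on anything else. A sketch of the role mapping, mirroring the cases covered by the new mapper test further down (message contents are placeholders):

```java
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.UserMessage;
import dev.langchain4j.data.message.ChatMessage;

import java.util.List;

class GeminiMessageMappingDemo {
    public static void main(String[] args) {
        List<ChatMessage> mapped = List.of(
                LlmProviderGeminiMapper.INSTANCE.toChatMessage(
                        SystemMessage.builder().content("Be concise.").build()),
                LlmProviderGeminiMapper.INSTANCE.toChatMessage(
                        UserMessage.builder().content("Hello Gemini").build()));
        // Prints SYSTEM and USER langchain4j messages respectively.
        mapped.forEach(message -> System.out.println(message.type() + ": " + message));
    }
}
```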
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
index f8b85c4103..60c902d97a 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
@@ -1,29 +1,21 @@
 package com.comet.opik.domain.llmproviders;
 
-import com.comet.opik.infrastructure.LlmProviderClientConfig;
 import dev.ai4j.openai4j.OpenAiClient;
 import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import io.dropwizard.jersey.errors.ErrorMessage;
-import jakarta.inject.Inject;
 import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.lang3.StringUtils;
 
 import java.util.Optional;
 import java.util.function.Consumer;
 
+@RequiredArgsConstructor
 @Slf4j
 class LlmProviderOpenAi implements LlmProviderService {
-    private final LlmProviderClientConfig llmProviderClientConfig;
-    private final OpenAiClient openAiClient;
-
-    @Inject
-    public LlmProviderOpenAi(LlmProviderClientConfig llmProviderClientConfig, String apiKey) {
-        this.llmProviderClientConfig = llmProviderClientConfig;
-        this.openAiClient = newOpenAiClient(apiKey);
-    }
+    private final @NonNull OpenAiClient openAiClient;
 
     @Override
     public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
@@ -50,41 +42,11 @@ public void
validateRequest(@NonNull ChatCompletionRequest request) {
     }
 
     @Override
-    public @NonNull Optional<ErrorMessage> getLlmProviderError(Throwable runtimeException) {
+    public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
         if (runtimeException instanceof OpenAiHttpException openAiHttpException) {
             return Optional.of(new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage()));
         }
 
         return Optional.empty();
     }
-
-    /**
-     * At the moment, openai4j client and also langchain4j wrappers, don't support dynamic API keys. That can imply
-     * an important performance penalty for next phases. The following options should be evaluated:
-     * - Cache clients, but can be unsafe.
-     * - Find and evaluate other clients.
-     * - Implement our own client.
-     * TODO as part of : OPIK-522
-     */
-    private OpenAiClient newOpenAiClient(String apiKey) {
-        var openAiClientBuilder = OpenAiClient.builder();
-        Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
-                .map(LlmProviderClientConfig.OpenAiClientConfig::url)
-                .ifPresent(baseUrl -> {
-                    if (StringUtils.isNotBlank(baseUrl)) {
-                        openAiClientBuilder.baseUrl(baseUrl);
-                    }
-                });
-        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
-                .ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
-        Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
-                .ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
-        Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
-                .ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
-        Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
-                .ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
-        return openAiClientBuilder
-                .openAiApiKey(apiKey)
-                .build();
-    }
 }
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
index 851acf2034..83d9140f1d 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
@@ -22,5 +22,5 @@ void generateStream(
 
     void validateRequest(@NonNull ChatCompletionRequest request);
 
-    @NonNull Optional<ErrorMessage> getLlmProviderError(Throwable runtimeException);
+    Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException);
 }
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java
index e69e5d0964..f7d4d28f60 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java
@@ -11,6 +11,7 @@
 import com.comet.opik.api.resources.utils.WireMockUtils;
 import com.comet.opik.api.resources.utils.resources.ChatCompletionsClient;
 import com.comet.opik.api.resources.utils.resources.LlmProviderApiKeyResourceClient;
+import com.comet.opik.domain.llmproviders.GeminiModelName;
 import com.comet.opik.podam.PodamFactoryUtils;
 import com.redis.testcontainers.RedisContainer;
 import dev.ai4j.openai4j.chat.ChatCompletionModel;
@@ -223,7 +224,9 @@ private static Stream<Arguments> testModelsProvider() {
                 arguments(ChatCompletionModel.GPT_4O_MINI.toString(),
LlmProvider.OPEN_AI,
                         UUID.randomUUID().toString()),
                 arguments(AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620.toString(), LlmProvider.ANTHROPIC,
-                        System.getenv("ANTHROPIC_API_KEY")));
+                        System.getenv("ANTHROPIC_API_KEY")),
+                arguments(GeminiModelName.GEMINI_1_0_PRO.toString(), LlmProvider.GEMINI,
+                        System.getenv("GEMINI_AI_KEY")));
     }
 
     @ParameterizedTest
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java
new file mode 100644
index 0000000000..37f073194b
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java
@@ -0,0 +1,127 @@
+package com.comet.opik.domain.llmproviders;
+
+import com.comet.opik.podam.PodamFactoryUtils;
+import dev.ai4j.openai4j.chat.AssistantMessage;
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.output.FinishReason;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import uk.co.jemos.podam.api.PodamFactory;
+
+import java.util.List;
+import java.util.stream.Stream;
+
+import static dev.langchain4j.data.message.AiMessage.aiMessage;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+public class LlmProviderClientsMappersTest {
+    private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory();
+
+    @Nested
+    @TestInstance(TestInstance.Lifecycle.PER_CLASS)
+    class AnthropicMappers {
+        @Test
+        void testToResponse() {
+            var response = podamFactory.manufacturePojo(AnthropicCreateMessageResponse.class);
+
+            var actual = LlmProviderAnthropicMapper.INSTANCE.toResponse(response);
+            assertThat(actual).isNotNull();
+            assertThat(actual.id()).isEqualTo(response.id);
+            assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder()
+                    .message(AssistantMessage.builder()
+                            .name(response.content.getFirst().name)
+                            .content(response.content.getFirst().text)
+                            .build())
+                    .finishReason(response.stopReason)
+                    .build()));
+            assertThat(actual.usage()).isEqualTo(Usage.builder()
+                    .promptTokens(response.usage.inputTokens)
+                    .completionTokens(response.usage.outputTokens)
+                    .totalTokens(response.usage.inputTokens + response.usage.outputTokens)
+                    .build());
+        }
+
+        @Test
+        void toCreateMessage() {
+            var request = podamFactory.manufacturePojo(ChatCompletionRequest.class);
+
+            AnthropicCreateMessageRequest actual = LlmProviderAnthropicMapper.INSTANCE
+                    .toCreateMessageRequest(request);
+
+            assertThat(actual).isNotNull();
+            assertThat(actual.model).isEqualTo(request.model());
+            assertThat(actual.stream).isEqualTo(request.stream());
+            assertThat(actual.temperature).isEqualTo(request.temperature());
+            assertThat(actual.topP).isEqualTo(request.topP());
+            assertThat(actual.stopSequences).isEqualTo(request.stop());
+            assertThat(actual.messages).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo(
+                    request.messages().stream()
+                            .filter(message -> List.of(Role.USER, Role.ASSISTANT).contains(message.role()))
+                            .map(LlmProviderAnthropicMapper.INSTANCE::mapToAnthropicMessage)
+                            .toList());
+            assertThat(actual.system).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo(
+                    request.messages().stream()
+                            .filter(message -> message.role() == Role.SYSTEM)
+                            .map(LlmProviderAnthropicMapper.INSTANCE::mapToSystemMessage)
+                            .toList());
+        }
+    }
+
+    @Nested
+    @TestInstance(TestInstance.Lifecycle.PER_CLASS)
+    class GeminiMappers {
+        @Test
+        void testToResponse() {
+            var request = ChatCompletionRequest.builder().model(podamFactory.manufacturePojo(String.class)).build();
+            var response = new Response<>(aiMessage(podamFactory.manufacturePojo(String.class)),
+                    new TokenUsage(podamFactory.manufacturePojo(Integer.class),
+                            podamFactory.manufacturePojo(Integer.class)),
+                    FinishReason.STOP);
+            var actual = LlmProviderGeminiMapper.INSTANCE.toChatCompletionResponse(request, response);
+            assertThat(actual).isNotNull();
+            assertThat(actual.model()).isEqualTo(request.model());
+            assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder()
+                    .message(AssistantMessage.builder().content(response.content().text()).build())
+                    .build()));
+            assertThat(actual.usage()).isEqualTo(Usage.builder()
+                    .promptTokens(response.tokenUsage().inputTokenCount())
+                    .completionTokens(response.tokenUsage().outputTokenCount())
+                    .totalTokens(response.tokenUsage().totalTokenCount())
+                    .build());
+        }
+
+        @ParameterizedTest
+        @MethodSource
+        void testToChatMessage(Message message, ChatMessage expected) {
+            ChatMessage actual = LlmProviderGeminiMapper.INSTANCE.toChatMessage(message);
+            assertThat(actual).isEqualTo(expected);
+        }
+
+        private Stream<Arguments> testToChatMessage() {
+            var content = podamFactory.manufacturePojo(String.class);
+            return Stream.of(
+                    arguments(AssistantMessage.builder().content(content).build(), AiMessage.from(content)),
+                    arguments(dev.ai4j.openai4j.chat.UserMessage.builder().content(content).build(),
+                            UserMessage.from(content)),
+                    arguments(dev.ai4j.openai4j.chat.SystemMessage.builder().content(content).build(),
+                            SystemMessage.from(content)));
+        }
+    }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
index 9250183902..c42e4489d0 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
@@ -25,6 +25,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.UUID;
+import java.util.function.Function;
 import java.util.stream.Stream;
 
 import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@@ -68,7 +69,8 @@ void testGetService(String model, LlmProvider llmProvider, Class<? extends LlmProviderService> expected
     private static Stream<Arguments> testGetService() {
         var openAiModels = EnumUtils.getEnumList(ChatCompletionModel.class).stream()
                 .map(model -> arguments(model.toString(), LlmProvider.OPEN_AI, LlmProviderOpenAi.class));
         var anthropicModels = EnumUtils.getEnumList(AnthropicChatModelName.class).stream()
                 .map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC,
LlmProviderAnthropic.class));
+        var geminiModels = EnumUtils.getEnumList(GeminiModelName.class).stream()
+                .map(model -> arguments(model.toString(), LlmProvider.GEMINI, LlmProviderGemini.class));
 
-        return Stream.concat(openAiModels, anthropicModels);
+        return Stream.of(openAiModels, anthropicModels, geminiModels).flatMap(Function.identity());
     }
 }
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java
index b86d77d830..5e960fc186 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java
@@ -13,7 +13,15 @@
 import com.comet.opik.podam.manufacturer.ProviderApiKeyManufacturer;
 import com.comet.opik.podam.manufacturer.ProviderApiKeyUpdateManufacturer;
 import com.comet.opik.podam.manufacturer.UUIDTypeManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.AnthropicContentManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.AnthropicCreateMessageResponseManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.AnthropicUsageManufacturer;
+import com.comet.opik.podam.manufacturer.anthropic.ChatCompletionRequestManufacturer;
 import com.fasterxml.jackson.databind.JsonNode;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
 import jakarta.validation.constraints.DecimalMax;
 import jakarta.validation.constraints.DecimalMin;
 import jakarta.validation.constraints.Pattern;
@@ -50,6 +58,12 @@ public static PodamFactory newPodamFactory() {
         strategy.addOrReplaceTypeManufacturer(PromptVersion.class, PromptVersionManufacturer.INSTANCE);
         strategy.addOrReplaceTypeManufacturer(ProviderApiKey.class, ProviderApiKeyManufacturer.INSTANCE);
         strategy.addOrReplaceTypeManufacturer(ProviderApiKeyUpdate.class, ProviderApiKeyUpdateManufacturer.INSTANCE);
+        strategy.addOrReplaceTypeManufacturer(AnthropicContent.class, AnthropicContentManufacturer.INSTANCE);
+        strategy.addOrReplaceTypeManufacturer(AnthropicUsage.class, AnthropicUsageManufacturer.INSTANCE);
+        strategy.addOrReplaceTypeManufacturer(AnthropicCreateMessageResponse.class,
+                AnthropicCreateMessageResponseManufacturer.INSTANCE);
+        strategy.addOrReplaceTypeManufacturer(ChatCompletionRequest.class, ChatCompletionRequestManufacturer.INSTANCE);
+
         return podamFactory;
     }
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicContentManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicContentManufacturer.java
new file mode 100644
index 0000000000..ae26ad926b
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicContentManufacturer.java
@@ -0,0 +1,22 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+public class AnthropicContentManufacturer extends AbstractTypeManufacturer<AnthropicContent> {
+    public static final AnthropicContentManufacturer INSTANCE = new AnthropicContentManufacturer();
+
+    @Override
+    public AnthropicContent getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+            ManufacturingContext context) {
+        var content = new AnthropicContent();
+        content.name = strategy.getTypeValue(metadata, context, String.class);
+        content.text = strategy.getTypeValue(metadata, context, String.class);
+        content.id = strategy.getTypeValue(metadata, context, String.class);
+
+        return content;
+    }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicCreateMessageResponseManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicCreateMessageResponseManufacturer.java
new file mode 100644
index 0000000000..fe7b6c7fb6
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicCreateMessageResponseManufacturer.java
@@ -0,0 +1,30 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+import java.util.List;
+
+public class AnthropicCreateMessageResponseManufacturer
+        extends
+            AbstractTypeManufacturer<AnthropicCreateMessageResponse> {
+    public static final AnthropicCreateMessageResponseManufacturer INSTANCE = new AnthropicCreateMessageResponseManufacturer();
+
+    @Override
+    public AnthropicCreateMessageResponse getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+            ManufacturingContext context) {
+        var response = new AnthropicCreateMessageResponse();
+        response.id = strategy.getTypeValue(metadata, context, String.class);
+        response.model = strategy.getTypeValue(metadata, context, String.class);
+        response.stopReason = strategy.getTypeValue(metadata, context, String.class);
+        response.content = List.of(strategy.getTypeValue(metadata, context, AnthropicContent.class));
+        response.usage = strategy.getTypeValue(metadata, context, AnthropicUsage.class);
+
+        return response;
+    }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicUsageManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicUsageManufacturer.java
new file mode 100644
index 0000000000..dfdb1e7f8d
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/AnthropicUsageManufacturer.java
@@ -0,0 +1,21 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+public class AnthropicUsageManufacturer extends AbstractTypeManufacturer<AnthropicUsage> {
+    public static final AnthropicUsageManufacturer INSTANCE = new AnthropicUsageManufacturer();
+
+    @Override
+    public AnthropicUsage getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+            ManufacturingContext context) {
+        var usage = new AnthropicUsage();
+        usage.inputTokens = strategy.getTypeValue(metadata, context, Integer.class);
+        usage.outputTokens = strategy.getTypeValue(metadata, context, Integer.class);
+
+        return usage;
+    }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/ChatCompletionRequestManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/ChatCompletionRequestManufacturer.java
new file mode 100644
index 0000000000..8e3278bd01
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/anthropic/ChatCompletionRequestManufacturer.java
@@ -0,0 +1,31 @@
+package com.comet.opik.podam.manufacturer.anthropic;
+
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import uk.co.jemos.podam.api.AttributeMetadata;
+import uk.co.jemos.podam.api.DataProviderStrategy;
+import uk.co.jemos.podam.common.ManufacturingContext;
+import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;
+
+public class ChatCompletionRequestManufacturer extends AbstractTypeManufacturer<ChatCompletionRequest> {
+    public static final ChatCompletionRequestManufacturer INSTANCE = new ChatCompletionRequestManufacturer();
+
+    @Override
+    public ChatCompletionRequest getType(DataProviderStrategy strategy, AttributeMetadata metadata,
+            ManufacturingContext context) {
+        var userMessageContent = strategy.getTypeValue(metadata, context, String.class);
+        var assistantMessageContent = strategy.getTypeValue(metadata, context, String.class);
+        var systemMessageContent = strategy.getTypeValue(metadata, context, String.class);
+
+        return ChatCompletionRequest.builder()
+                .model(strategy.getTypeValue(metadata, context, String.class))
+                .stream(strategy.getTypeValue(metadata, context, Boolean.class))
+                .temperature(strategy.getTypeValue(metadata, context, Double.class))
+                .topP(strategy.getTypeValue(metadata, context, Double.class))
+                .stop(strategy.getTypeValue(metadata, context, String.class))
+                .addUserMessage(userMessageContent)
+                .addAssistantMessage(assistantMessageContent)
+                .addSystemMessage(systemMessageContent)
+                .maxCompletionTokens(strategy.getTypeValue(metadata, context, Integer.class))
+                .build();
+    }
+}