[OPIK-611] support gemini models in playground (#987)
* OPIK-611 gemini infra

* OPIK-611 get service gemini failing test

* OPIK-611 get service gemini failing test green

* OPIK-611 gemini e2e failing test

* OPIK-611 gemini e2e failing test green [WIP - only create]

* OPIK-611 refactor

* OPIK-611 gemini e2e failing tests green

* OPIK-611 refactor

* OPIK-611 post rebase adjustments

* OPIK-611 move client generation to a module

* OPIK-611 anthropic mappers

* OPIK-611 gemini mappers

* OPIK-611 mappers coverage

* OPIK-611 pr comments

* OPIK-611 pr comments

* OPIK-611 refactor

* OPIK-611 minor fixes

* OPIK-611 pr comments

* OPIK-611 changed gemini model creation to map struct

* OPIK-611 fix test
idoberko2 authored Jan 13, 2025
1 parent 13d5ee9 commit 1d5bdac
Showing 22 changed files with 720 additions and 234 deletions.
4 changes: 4 additions & 0 deletions apps/opik-backend/pom.xml
@@ -221,6 +221,10 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-anthropic</artifactId>
</dependency>
+<dependency>
+<groupId>dev.langchain4j</groupId>
+<artifactId>langchain4j-google-ai-gemini</artifactId>
+</dependency>

<!-- Test -->

@@ -1,6 +1,7 @@
package com.comet.opik;

import com.comet.opik.api.error.JsonInvalidFormatExceptionMapper;
+import com.comet.opik.domain.llmproviders.LlmProviderClientModule;
import com.comet.opik.infrastructure.ConfigurationModule;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.OpikConfiguration;
@@ -72,7 +73,7 @@ public void initialize(Bootstrap<OpikConfiguration> bootstrap) {
.withPlugins(new SqlObjectPlugin(), new Jackson2Plugin()))
.modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(),
new RateLimitModule(), new NameGeneratorModule(), new HttpModule(), new EventModule(),
-new ConfigurationModule(), new BiModule())
+new ConfigurationModule(), new BiModule(), new LlmProviderClientModule())
.installers(JobGuiceyInstaller.class)
.listen(new OpikGuiceyLifecycleEventListener())
.enableAutoConfig()
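
The LlmProviderClientModule registered above is added elsewhere in this commit and is not part of this excerpt. A minimal sketch of what such a Guice module could look like, assuming it owns the client construction that the "move client generation to a module" commit pulled out of the provider classes; the LlmProviderClientGenerator name, the binding of LlmProviderClientConfig, and the other details are assumptions, not code from this diff:

package com.comet.opik.domain.llmproviders;

import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
import jakarta.inject.Singleton;

// Hypothetical sketch only; the real module ships as a separate file in this commit.
public class LlmProviderClientModule extends AbstractModule {

    @Provides
    @Singleton
    public LlmProviderClientGenerator clientGenerator(LlmProviderClientConfig config) {
        // Assumed helper that builds the Anthropic/Gemini clients from the
        // per-provider configuration, with the API key supplied per request.
        return new LlmProviderClientGenerator(config);
    }
}

Centralizing construction this way would let each LlmProvider* class receive a ready client through its constructor instead of building one from a raw API key.
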
@@ -11,7 +11,9 @@
@RequiredArgsConstructor
public enum LlmProvider {
OPEN_AI("openai"),
ANTHROPIC("anthropic");
ANTHROPIC("anthropic"),
GEMINI("gemini"),
;

@JsonValue
private final String value;
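
Because the enum exposes its value through @JsonValue, the new constant travels over the wire as the bare string "gemini". A quick Jackson round-trip sketch; the import path for LlmProvider is an assumption, since the diff does not show the package line:

import com.comet.opik.api.LlmProvider; // assumed location of the enum
import com.fasterxml.jackson.databind.ObjectMapper;

public class LlmProviderJsonRoundTrip {
    public static void main(String[] args) throws Exception {
        var mapper = new ObjectMapper();
        // @JsonValue drives both directions: enum -> "gemini" -> enum
        System.out.println(mapper.writeValueAsString(LlmProvider.GEMINI)); // "gemini"
        System.out.println(mapper.readValue("\"gemini\"", LlmProvider.class)); // GEMINI
    }
}
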
@@ -0,0 +1,60 @@
package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import lombok.NonNull;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

public record ChunkedResponseHandler(
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError,
@NonNull String model) implements StreamingResponseHandler<AiMessage> {

@Override
public void onNext(@NonNull String content) {
handleMessage.accept(ChatCompletionResponse.builder()
.model(model)
.choices(List.of(ChatCompletionChoice.builder()
.delta(Delta.builder()
.content(content)
.role(Role.ASSISTANT)
.build())
.build()))
.build());
}

@Override
public void onComplete(@NonNull Response<AiMessage> response) {
handleMessage.accept(ChatCompletionResponse.builder()
.model(model)
.choices(List.of(ChatCompletionChoice.builder()
.delta(Delta.builder()
.content("")
.role(Role.ASSISTANT)
.build())
.build()))
.usage(Usage.builder()
.promptTokens(response.tokenUsage().inputTokenCount())
.completionTokens(response.tokenUsage().outputTokenCount())
.totalTokens(response.tokenUsage().totalTokenCount())
.build())
.id(Optional.ofNullable(response.metadata().get("id")).map(Object::toString).orElse(null))
.build());
handleClose.run();
}

@Override
public void onError(@NonNull Throwable throwable) {
handleError.accept(throwable);
}
}
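
The record above adapts langchain4j streaming callbacks into OpenAI-style ChatCompletionResponse chunks, so the Anthropic and Gemini providers can share one streaming path. A sketch of how it might be driven by the Gemini streaming model from the newly added langchain4j-google-ai-gemini dependency; the builder options and the generate(...) signature follow the 0.3x-era langchain4j API and should be treated as assumptions:

import com.comet.opik.domain.llmproviders.ChunkedResponseHandler;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;

import java.util.List;

public class GeminiStreamingDemo {
    public static void main(String[] args) {
        var model = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(System.getenv("GEMINI_API_KEY"))
                .modelName("gemini-1.5-flash")
                .build();

        List<ChatMessage> messages = List.of(UserMessage.from("Say hello"));

        // Each streamed token arrives as an OpenAI-style chunk; the final
        // chunk carries token usage, mirroring onComplete above.
        model.generate(messages, new ChunkedResponseHandler(
                chunk -> System.out.println(chunk),
                () -> System.out.println("[stream closed]"),
                Throwable::printStackTrace,
                "gemini-1.5-flash"));
    }
}
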
@@ -0,0 +1,25 @@
package com.comet.opik.domain.llmproviders;

import lombok.RequiredArgsConstructor;

/*
Langchain4j doesn't provide an enum of Gemini model names.
This information is taken from: https://ai.google.dev/gemini-api/docs/models/gemini
*/
@RequiredArgsConstructor
public enum GeminiModelName {
GEMINI_2_0_FLASH("gemini-2.0-flash-exp"),
GEMINI_1_5_FLASH("gemini-1.5-flash"),
GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b"),
GEMINI_1_5_PRO("gemini-1.5-pro"),
GEMINI_1_0_PRO("gemini-1.0-pro"),
TEXT_EMBEDDING("text-embedding-004"),
AQA("aqa");

private final String value;

@Override
public String toString() {
return value;
}
}
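
Since toString() returns the wire name, a constant can be passed anywhere the raw model string is expected, such as langchain4j's modelName(...) setter. A trivial illustration (assumes the caller lives in the same package as the enum):

public class GeminiModelNameDemo {
    public static void main(String[] args) {
        // Prints "gemini-1.5-flash", the identifier the Gemini API expects
        System.out.println(GeminiModelName.GEMINI_1_5_FLASH);
    }
}
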
@@ -1,68 +1,33 @@
package com.comet.opik.domain.llmproviders;

import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.chat.AssistantMessage;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.UserMessage;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException;
import dev.langchain4j.model.output.Response;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

import static com.comet.opik.domain.ChatCompletionService.ERROR_EMPTY_MESSAGES;
import static com.comet.opik.domain.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS;

@RequiredArgsConstructor
@Slf4j
class LlmProviderAnthropic implements LlmProviderService {
-private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
private final @NonNull AnthropicClient anthropicClient;

-public LlmProviderAnthropic(@NonNull LlmProviderClientConfig llmProviderClientConfig, @NonNull String apiKey) {
-this.llmProviderClientConfig = llmProviderClientConfig;
-this.anthropicClient = newClient(apiKey);
-}

@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
-var response = anthropicClient.createMessage(toAnthropicCreateMessageRequest(request));
+var mapper = LlmProviderAnthropicMapper.INSTANCE;
+var response = anthropicClient.createMessage(mapper.toCreateMessageRequest(request));

-return ChatCompletionResponse.builder()
-.id(response.id)
-.model(response.model)
-.choices(response.content.stream().map(content -> toChatCompletionChoice(response, content))
-.toList())
-.usage(Usage.builder()
-.promptTokens(response.usage.inputTokens)
-.completionTokens(response.usage.outputTokens)
-.totalTokens(response.usage.inputTokens + response.usage.outputTokens)
-.build())
-.build();
+return mapper.toResponse(response);
}

@Override
@@ -72,7 +37,7 @@ public void generateStream(
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose, @NonNull Consumer<Throwable> handleError) {
validateRequest(request);
-anthropicClient.createMessage(toAnthropicCreateMessageRequest(request),
+anthropicClient.createMessage(LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request),
new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
}

@@ -88,151 +53,12 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
}

@Override
-public @NonNull Optional<ErrorMessage> getLlmProviderError(Throwable runtimeException) {
+public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
if (runtimeException instanceof AnthropicHttpException anthropicHttpException) {
return Optional.of(new ErrorMessage(anthropicHttpException.statusCode(),
anthropicHttpException.getMessage()));
}

return Optional.empty();
}

-private AnthropicCreateMessageRequest toAnthropicCreateMessageRequest(ChatCompletionRequest request) {
-var builder = AnthropicCreateMessageRequest.builder();
-Optional.ofNullable(request.toolChoice())
-.ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(
-request.toolChoice().toString())));
-return builder
-.stream(request.stream())
-.model(request.model())
-.messages(request.messages().stream()
-.filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
-.map(this::toMessage).toList())
-.system(request.messages().stream()
-.filter(message -> message.role() == Role.SYSTEM)
-.map(this::toSystemMessage).toList())
-.temperature(request.temperature())
-.topP(request.topP())
-.stopSequences(request.stop())
-.maxTokens(request.maxCompletionTokens())
-.build();
-}
-
-private AnthropicMessage toMessage(Message message) {
-if (message.role() == Role.ASSISTANT) {
-return AnthropicMessage.builder()
-.role(AnthropicRole.ASSISTANT)
-.content(List.of(new AnthropicTextContent(((AssistantMessage) message).content())))
-.build();
-}
-
-if (message.role() == Role.USER) {
-return AnthropicMessage.builder()
-.role(AnthropicRole.USER)
-.content(List.of(toAnthropicMessageContent(((UserMessage) message).content())))
-.build();
-}
-
-throw new BadRequestException("unexpected message role: " + message.role());
-}
-
-private AnthropicTextContent toSystemMessage(Message message) {
-if (message.role() != Role.SYSTEM) {
-throw new BadRequestException("expecting only system role, got: " + message.role());
-}
-
-return new AnthropicTextContent(((SystemMessage) message).content());
-}
-
-private AnthropicMessageContent toAnthropicMessageContent(Object rawContent) {
-if (rawContent instanceof String content) {
-return new AnthropicTextContent(content);
-}
-
-throw new BadRequestException("only text content is supported");
-}
-
-private ChatCompletionChoice toChatCompletionChoice(
-AnthropicCreateMessageResponse response, AnthropicContent content) {
-return ChatCompletionChoice.builder()
-.message(AssistantMessage.builder()
-.name(content.name)
-.content(content.text)
-.build())
-.finishReason(response.stopReason)
-.build();
-}
-
-private AnthropicClient newClient(String apiKey) {
-var anthropicClientBuilder = AnthropicClient.builder();
-Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
-.map(LlmProviderClientConfig.AnthropicClientConfig::url)
-.ifPresent(url -> {
-if (StringUtils.isNotEmpty(url)) {
-anthropicClientBuilder.baseUrl(url);
-}
-});
-Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
-.map(LlmProviderClientConfig.AnthropicClientConfig::version)
-.ifPresent(version -> {
-if (StringUtils.isNotBlank(version)) {
-anthropicClientBuilder.version(version);
-}
-});
-Optional.ofNullable(llmProviderClientConfig.getLogRequests())
-.ifPresent(anthropicClientBuilder::logRequests);
-Optional.ofNullable(llmProviderClientConfig.getLogResponses())
-.ifPresent(anthropicClientBuilder::logResponses);
-// anthropic client builder only receives one timeout variant
-Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
-.ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
-return anthropicClientBuilder
-.apiKey(apiKey)
-.build();
-}
-
-private record ChunkedResponseHandler(
-Consumer<ChatCompletionResponse> handleMessage,
-Runnable handleClose,
-Consumer<Throwable> handleError,
-String model) implements StreamingResponseHandler<AiMessage> {
-
-@Override
-public void onNext(String s) {
-handleMessage.accept(ChatCompletionResponse.builder()
-.model(model)
-.choices(List.of(ChatCompletionChoice.builder()
-.delta(Delta.builder()
-.content(s)
-.role(Role.ASSISTANT)
-.build())
-.build()))
-.build());
-}
-
-@Override
-public void onComplete(Response<AiMessage> response) {
-handleMessage.accept(ChatCompletionResponse.builder()
-.model(model)
-.choices(List.of(ChatCompletionChoice.builder()
-.delta(Delta.builder()
-.content("")
-.role(Role.ASSISTANT)
-.build())
-.build()))
-.usage(Usage.builder()
-.promptTokens(response.tokenUsage().inputTokenCount())
-.completionTokens(response.tokenUsage().outputTokenCount())
-.totalTokens(response.tokenUsage().totalTokenCount())
-.build())
-.id((String) response.metadata().get("id"))
-.build());
-handleClose.run();
-}
-
-@Override
-public void onError(Throwable throwable) {
-handleError.accept(throwable);
-}
-}
}
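
The mapping logic deleted above moves into LlmProviderAnthropicMapper, which this diff references but does not include. The commit notes mention MapStruct, so a plausible reconstruction is a @Mapper interface whose default methods carry over the removed logic; the sketch below rests on that assumption, with toolChoice handling omitted for brevity:

package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.AssistantMessage;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.UserMessage;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import jakarta.ws.rs.BadRequestException;
import org.mapstruct.Mapper;
import org.mapstruct.factory.Mappers;

import java.util.List;

// Hypothetical sketch; the real mapper is a separate file in this commit.
// Method bodies mirror the code removed from LlmProviderAnthropic above.
@Mapper
public interface LlmProviderAnthropicMapper {

    LlmProviderAnthropicMapper INSTANCE = Mappers.getMapper(LlmProviderAnthropicMapper.class);

    default AnthropicCreateMessageRequest toCreateMessageRequest(ChatCompletionRequest request) {
        return AnthropicCreateMessageRequest.builder()
                .stream(request.stream())
                .model(request.model())
                .messages(request.messages().stream()
                        .filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
                        .map(this::toMessage)
                        .toList())
                .system(request.messages().stream()
                        .filter(message -> message.role() == Role.SYSTEM)
                        .map(message -> new AnthropicTextContent(((SystemMessage) message).content()))
                        .toList())
                .temperature(request.temperature())
                .topP(request.topP())
                .stopSequences(request.stop())
                .maxTokens(request.maxCompletionTokens())
                .build();
    }

    default ChatCompletionResponse toResponse(AnthropicCreateMessageResponse response) {
        return ChatCompletionResponse.builder()
                .id(response.id)
                .model(response.model)
                .choices(response.content.stream()
                        .map(content -> ChatCompletionChoice.builder()
                                .message(AssistantMessage.builder()
                                        .name(content.name)
                                        .content(content.text)
                                        .build())
                                .finishReason(response.stopReason)
                                .build())
                        .toList())
                .usage(Usage.builder()
                        .promptTokens(response.usage.inputTokens)
                        .completionTokens(response.usage.outputTokens)
                        .totalTokens(response.usage.inputTokens + response.usage.outputTokens)
                        .build())
                .build();
    }

    private AnthropicMessage toMessage(Message message) {
        if (message.role() == Role.ASSISTANT) {
            return AnthropicMessage.builder()
                    .role(AnthropicRole.ASSISTANT)
                    .content(List.of(new AnthropicTextContent(((AssistantMessage) message).content())))
                    .build();
        }
        // The deleted helper accepted only text content for user messages.
        if (message.role() == Role.USER && ((UserMessage) message).content() instanceof String content) {
            return AnthropicMessage.builder()
                    .role(AnthropicRole.USER)
                    .content(List.of(new AnthropicTextContent(content)))
                    .build();
        }
        throw new BadRequestException("unexpected message role or non-text content: " + message.role());
    }
}
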