Skip to content

Commit

Permalink
OPIK-611 pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
idoberko2 committed Jan 13, 2025
1 parent db57fa6 commit 138ca1b
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,17 @@ default List<AnthropicTextContent> mapToSystemMessages(@NonNull ChatCompletionRe
}

default AnthropicMessage mapToAnthropicMessage(@NonNull Message message) {
if (message instanceof AssistantMessage assistantMessage) {
return AnthropicMessage.builder()
return switch (message) {
case AssistantMessage assistantMessage -> AnthropicMessage.builder()
.role(AnthropicRole.ASSISTANT)
.content(List.of(new AnthropicTextContent(assistantMessage.content())))
.build();
}

if (message instanceof UserMessage userMessage) {
return AnthropicMessage.builder()
case UserMessage userMessage -> AnthropicMessage.builder()
.role(AnthropicRole.USER)
.content(List.of(toAnthropicMessageContent(userMessage.content())))
.build();
}

throw new BadRequestException("unexpected message role: " + message.role());
default -> throw new BadRequestException("unexpected message role: " + message.role());
};
}

default AnthropicMessageContent toAnthropicMessageContent(@NonNull Object rawContent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,12 @@ public AnthropicClient newAnthropicClient(@NonNull String apiKey) {
var anthropicClientBuilder = AnthropicClient.builder();
Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::url)
.ifPresent(url -> {
if (StringUtils.isNotEmpty(url)) {
anthropicClientBuilder.baseUrl(url);
}
});
.filter(StringUtils::isNotEmpty)
.ifPresent(anthropicClientBuilder::baseUrl);
Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::version)
.ifPresent(version -> {
if (StringUtils.isNotBlank(version)) {
anthropicClientBuilder.version(version);
}
});
.filter(StringUtils::isNotBlank)
.ifPresent(anthropicClientBuilder::version);
Optional.ofNullable(llmProviderClientConfig.getLogRequests())
.ifPresent(anthropicClientBuilder::logRequests);
Optional.ofNullable(llmProviderClientConfig.getLogResponses())
Expand All @@ -48,11 +42,8 @@ public OpenAiClient newOpenAiClient(@NonNull String apiKey) {
var openAiClientBuilder = OpenAiClient.builder();
Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
.map(LlmProviderClientConfig.OpenAiClientConfig::url)
.ifPresent(baseUrl -> {
if (StringUtils.isNotBlank(baseUrl)) {
openAiClientBuilder.baseUrl(baseUrl);
}
});
.filter(StringUtils::isNotBlank)
.ifPresent(openAiClientBuilder::baseUrl);
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
Expand All @@ -41,59 +37,30 @@ public class LlmProviderClientsMappersTest {
@Nested
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class AnthropicMappers {
// anthropic POJOs don't have setters or builders, therefore podam can't manufacture objects correctly

@Test
void testToResponse() {
var content = new AnthropicContent();
content.name = podamFactory.manufacturePojo(String.class);
content.text = podamFactory.manufacturePojo(String.class);
content.id = podamFactory.manufacturePojo(String.class);

var usage = new AnthropicUsage();
usage.inputTokens = podamFactory.manufacturePojo(Integer.class);
usage.outputTokens = podamFactory.manufacturePojo(Integer.class);

var response = new AnthropicCreateMessageResponse();
response.id = podamFactory.manufacturePojo(String.class);
response.model = podamFactory.manufacturePojo(String.class);
response.stopReason = podamFactory.manufacturePojo(String.class);
response.content = List.of(content);
response.usage = usage;
var response = podamFactory.manufacturePojo(AnthropicCreateMessageResponse.class);

var actual = LlmProviderAnthropicMapper.INSTANCE.toResponse(response);
assertThat(actual).isNotNull();
assertThat(actual.id()).isEqualTo(response.id);
assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.name(content.name)
.content(content.text)
.name(response.content.getFirst().name)
.content(response.content.getFirst().text)
.build())
.finishReason(response.stopReason)
.build()));
assertThat(actual.usage()).isEqualTo(Usage.builder()
.promptTokens(usage.inputTokens)
.completionTokens(usage.outputTokens)
.totalTokens(usage.inputTokens + usage.outputTokens)
.promptTokens(response.usage.inputTokens)
.completionTokens(response.usage.outputTokens)
.totalTokens(response.usage.inputTokens + response.usage.outputTokens)
.build());
}

@Test
void toCreateMessage() {
var userMessageContent = podamFactory.manufacturePojo(String.class);
var assistantMessageContent = podamFactory.manufacturePojo(String.class);
var systemMessageContent = podamFactory.manufacturePojo(String.class);
ChatCompletionRequest request = ChatCompletionRequest.builder()
.model(podamFactory.manufacturePojo(String.class))
.stream(podamFactory.manufacturePojo(Boolean.class))
.temperature(podamFactory.manufacturePojo(Double.class))
.topP(podamFactory.manufacturePojo(Double.class))
.stop(podamFactory.manufacturePojo(String.class))
.addUserMessage(userMessageContent)
.addAssistantMessage(assistantMessageContent)
.addSystemMessage(systemMessageContent)
.maxCompletionTokens(podamFactory.manufacturePojo(Integer.class))
.build();
var request = podamFactory.manufacturePojo(ChatCompletionRequest.class);

AnthropicCreateMessageRequest actual = LlmProviderAnthropicMapper.INSTANCE
.toCreateMessageRequest(request);
Expand All @@ -104,16 +71,16 @@ void toCreateMessage() {
assertThat(actual.temperature).isEqualTo(request.temperature());
assertThat(actual.topP).isEqualTo(request.topP());
assertThat(actual.stopSequences).isEqualTo(request.stop());
assertThat(actual.messages).containsExactlyInAnyOrder(
AnthropicMessage.builder()
.role(AnthropicRole.USER)
.content(List.of(new AnthropicTextContent(userMessageContent)))
.build(),
AnthropicMessage.builder()
.role(AnthropicRole.ASSISTANT)
.content(List.of(new AnthropicTextContent(assistantMessageContent)))
.build());
assertThat(actual.system).isEqualTo(List.of(new AnthropicTextContent(systemMessageContent)));
assertThat(actual.messages).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo(
request.messages().stream()
.filter(message -> List.of(Role.USER, Role.ASSISTANT).contains(message.role()))
.map(LlmProviderAnthropicMapper.INSTANCE::mapToAnthropicMessage)
.toList());
assertThat(actual.system).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo(
request.messages().stream()
.filter(message -> message.role() == Role.SYSTEM)
.map(LlmProviderAnthropicMapper.INSTANCE::mapToAnthropicMessage)
.toList());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
import com.comet.opik.podam.manufacturer.ProviderApiKeyManufacturer;
import com.comet.opik.podam.manufacturer.ProviderApiKeyUpdateManufacturer;
import com.comet.opik.podam.manufacturer.UUIDTypeManufacturer;
import com.comet.opik.podam.manufacturer.anthropic.AnthropicContentManufacturer;
import com.comet.opik.podam.manufacturer.anthropic.AnthropicCreateMessageResponseManufacturer;
import com.comet.opik.podam.manufacturer.anthropic.AnthropicUsageManufacturer;
import com.comet.opik.podam.manufacturer.anthropic.ChatCompletionRequestManufacturer;
import com.fasterxml.jackson.databind.JsonNode;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import jakarta.validation.constraints.DecimalMax;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.Pattern;
Expand Down Expand Up @@ -50,6 +58,12 @@ public static PodamFactory newPodamFactory() {
strategy.addOrReplaceTypeManufacturer(PromptVersion.class, PromptVersionManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(ProviderApiKey.class, ProviderApiKeyManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(ProviderApiKeyUpdate.class, ProviderApiKeyUpdateManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(AnthropicContent.class, AnthropicContentManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(AnthropicUsage.class, AnthropicUsageManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(AnthropicCreateMessageResponse.class,
AnthropicCreateMessageResponseManufacturer.INSTANCE);
strategy.addOrReplaceTypeManufacturer(ChatCompletionRequest.class, ChatCompletionRequestManufacturer.INSTANCE);

return podamFactory;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.comet.opik.podam.manufacturer.anthropic;

import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
import uk.co.jemos.podam.common.ManufacturingContext;
import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;

public class AnthropicContentManufacturer extends AbstractTypeManufacturer<AnthropicContent> {
public static final AnthropicContentManufacturer INSTANCE = new AnthropicContentManufacturer();

@Override
public AnthropicContent getType(DataProviderStrategy strategy, AttributeMetadata metadata,
ManufacturingContext context) {
var content = new AnthropicContent();
content.name = strategy.getTypeValue(metadata, context, String.class);
content.text = strategy.getTypeValue(metadata, context, String.class);
content.id = strategy.getTypeValue(metadata, context, String.class);

return content;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.comet.opik.podam.manufacturer.anthropic;

import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
import uk.co.jemos.podam.common.ManufacturingContext;
import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;

import java.util.List;

public class AnthropicCreateMessageResponseManufacturer
extends
AbstractTypeManufacturer<AnthropicCreateMessageResponse> {
public static final AnthropicCreateMessageResponseManufacturer INSTANCE = new AnthropicCreateMessageResponseManufacturer();

@Override
public AnthropicCreateMessageResponse getType(DataProviderStrategy strategy, AttributeMetadata metadata,
ManufacturingContext context) {
var response = new AnthropicCreateMessageResponse();
response.id = strategy.getTypeValue(metadata, context, String.class);
response.model = strategy.getTypeValue(metadata, context, String.class);
response.stopReason = strategy.getTypeValue(metadata, context, String.class);
response.content = List.of(strategy.getTypeValue(metadata, context, AnthropicContent.class));
response.usage = strategy.getTypeValue(metadata, context, AnthropicUsage.class);

return response;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.comet.opik.podam.manufacturer.anthropic;

import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
import uk.co.jemos.podam.common.ManufacturingContext;
import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;

public class AnthropicUsageManufacturer extends AbstractTypeManufacturer<AnthropicUsage> {
public static final AnthropicUsageManufacturer INSTANCE = new AnthropicUsageManufacturer();

@Override
public AnthropicUsage getType(DataProviderStrategy strategy, AttributeMetadata metadata,
ManufacturingContext context) {
var usage = new AnthropicUsage();
usage.inputTokens = strategy.getTypeValue(metadata, context, Integer.class);
usage.outputTokens = strategy.getTypeValue(metadata, context, Integer.class);

return usage;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.comet.opik.podam.manufacturer.anthropic;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
import uk.co.jemos.podam.common.ManufacturingContext;
import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer;

public class ChatCompletionRequestManufacturer extends AbstractTypeManufacturer<ChatCompletionRequest> {
public static final ChatCompletionRequestManufacturer INSTANCE = new ChatCompletionRequestManufacturer();

@Override
public ChatCompletionRequest getType(DataProviderStrategy strategy, AttributeMetadata metadata,
ManufacturingContext context) {
var userMessageContent = strategy.getTypeValue(metadata, context, String.class);
var assistantMessageContent = strategy.getTypeValue(metadata, context, String.class);
var systemMessageContent = strategy.getTypeValue(metadata, context, String.class);

return ChatCompletionRequest.builder()
.model(strategy.getTypeValue(metadata, context, String.class))
.stream(strategy.getTypeValue(metadata, context, Boolean.class))
.temperature(strategy.getTypeValue(metadata, context, Double.class))
.topP(strategy.getTypeValue(metadata, context, Double.class))
.stop(strategy.getTypeValue(metadata, context, String.class))
.addUserMessage(userMessageContent)
.addAssistantMessage(assistantMessageContent)
.addSystemMessage(systemMessageContent)
.maxCompletionTokens(strategy.getTypeValue(metadata, context, Integer.class))
.build();
}
}

0 comments on commit 138ca1b

Please sign in to comment.