Skip to content

Commit

Permalink
finishing PR review requests
Browse files Browse the repository at this point in the history
  • Loading branch information
ldaugusto committed Jan 8, 2025
1 parent 8f20678 commit 830a5be
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 255 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotNull;
import lombok.Builder;
import org.jdbi.v3.json.Json;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record AutomationRuleEvaluatorUpdate(
@NotNull String name,
@Json @NotNull AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode code,
@NotNull AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode code,
@NotNull Float samplingRate) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package com.comet.opik.api.resources.v1.events;

import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge;
import com.comet.opik.api.Trace;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.jayway.jsonpath.JsonPath;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.UserMessage;
import lombok.Builder;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.text.StringSubstitutor;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

@UtilityClass
@Slf4j
public class LlmAsJudgeMessageRender {

/**
* Render the rule evaluator message template using the values from an actual trace.
*
* As the rule my consist in multiple messages, we check each one of them for variables to fill.
* Then we go through every variable template to replace them for the value from the trace.
*
* @param trace the trace with value to use to replace template variables
* @param evaluatorCode the evaluator
* @return a list of AI messages, with templates rendered
*/
public static List<Message> renderMessages(Trace trace,
AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode evaluatorCode) {
// prepare the map of replacements to use in all messages
var parsedVariables = variableMapping(evaluatorCode.variables());

// extract the actual value from the Trace
var replacements = parsedVariables.stream().map(mapper -> {
var traceSection = switch (mapper.traceSection) {
case INPUT -> trace.input();
case OUTPUT -> trace.output();
case METADATA -> trace.metadata();
};

return mapper.toBuilder()
.valueToReplace(extractFromJson(traceSection, mapper.jsonPath()))
.build();
})
.filter(mapper -> mapper.valueToReplace() != null)
.collect(
Collectors.toMap(LlmAsJudgeMessageRender.MessageVariableMapping::variableName,
LlmAsJudgeMessageRender.MessageVariableMapping::valueToReplace));

// will convert all '{{key}}' into 'value'
// TODO: replace with Mustache Java to be in confirm with FE
var templateRenderer = new StringSubstitutor(replacements, "{{", "}}");

// render the message templates from evaluator rule
return evaluatorCode.messages().stream()
.map(templateMessage -> {
var renderedMessage = templateRenderer.replace(templateMessage.content());

return switch (templateMessage.role()) {
case USER -> UserMessage.from(renderedMessage);
case SYSTEM -> SystemMessage.from(renderedMessage);
default -> {
log.info("No mapping for message role type {}", templateMessage.role());
yield null;
}
};
})
.filter(Objects::nonNull)
.toList();
}

/**
* Parse evaluator\'s variable mapper into an usable list of
*
* @param evaluatorVariables a map with variables and a path into a trace input/output/metadata to replace
* @return a parsed list of mappings, easier to use for the template rendering
*/
public static List<MessageVariableMapping> variableMapping(Map<String, String> evaluatorVariables) {
return evaluatorVariables.entrySet().stream()
.map(mapper -> {
var templateVariable = mapper.getKey();
var tracePath = mapper.getValue();

var builder = MessageVariableMapping.builder().variableName(templateVariable);

if (tracePath.startsWith("input.")) {
builder.traceSection(TraceSection.INPUT)
.jsonPath("$" + tracePath.substring("input".length()));
} else if (tracePath.startsWith("output.")) {
builder.traceSection(TraceSection.OUTPUT)
.jsonPath("$" + tracePath.substring("output".length()));
} else if (tracePath.startsWith("metadata.")) {
builder.traceSection(TraceSection.METADATA)
.jsonPath("$" + tracePath.substring("metadata".length()));
} else {
log.info("Couldn't map trace path '{}' into a input/output/metadata path", tracePath);
return null;
}

return builder.build();
})
.filter(Objects::nonNull)
.toList();
}

final ObjectMapper objectMapper = new ObjectMapper();

String extractFromJson(JsonNode json, String path) {
try {
// JsonPath didnt work with JsonNode, even explicitly using JacksonJsonProvider, so we convert to a Map
var forcedObject = objectMapper.convertValue(json, Map.class);
return JsonPath.parse(forcedObject).read(path);
} catch (Exception e) {
log.debug("Couldn't find path '{}' inside json {}: {}", path, json, e.getMessage());
return null;
}
}

public enum TraceSection {
INPUT,
OUTPUT,
METADATA
}

@Builder(toBuilder = true)
public record MessageVariableMapping(TraceSection traceSection, String variableName, String jsonPath,
String valueToReplace) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,15 @@
import com.comet.opik.domain.AutomationRuleEvaluatorService;
import com.comet.opik.domain.ChatCompletionService;
import com.comet.opik.domain.FeedbackScoreService;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.eventbus.EventBus;
import com.google.common.eventbus.Subscribe;
import com.jayway.jsonpath.JsonPath;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.UserMessage;
import jakarta.inject.Inject;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.text.StringSubstitutor;
import ru.vyarus.dropwizard.guice.module.installer.feature.eager.EagerSingleton;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.UUID;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -99,118 +90,12 @@ private void score(Trace trace, String workspaceId, AutomationRuleEvaluatorLlmAs
var baseRequestBuilder = ChatCompletionRequest.builder()
.model(evaluator.getCode().model().name())
.temperature(evaluator.getCode().model().temperature())
.messages(renderMessages(trace, evaluator.getCode()))
.messages(LlmAsJudgeMessageRender.renderMessages(trace, evaluator.getCode()))
.build();

// TODO: call AI Proxy and parse response into 1+ FeedbackScore

// TODO: store FeedbackScores
}

final ObjectMapper objectMapper = new ObjectMapper();

String extractFromJson(JsonNode json, String path) {
try {
// JsonPath didnt work with JsonNode, even explicitly using JacksonJsonProvider, so we convert to a Map
var forcedObject = objectMapper.convertValue(json, Map.class);
return JsonPath.parse(forcedObject).read(path);
}
catch (Exception e) {
log.debug("Couldn't find path '{}' inside json {}: {}", path, json, e.getMessage());
return null;
}
}

/**
* Render the rule evaluator message template using the values from an actual trace.
*
* As the rule my consist in multiple messages, we check each one of them for variables to fill.
* Then we go through every variable template to replace them for the value from the trace.
*
* @param trace the trace with value to use to replace template variables
* @param evaluatorCode the evaluator
* @return a list of AI messages, with templates rendered
*/
List<Message> renderMessages(Trace trace, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode evaluatorCode) {
// prepare the map of replacements to use in all messages, extracting the actual value from the Trace
var parsedVariables = variableMapping(evaluatorCode.variables());
var replacements = parsedVariables.stream().map(mapper -> {
var traceSection = switch (mapper.traceSection) {
case INPUT -> trace.input();
case OUTPUT -> trace.output();
case METADATA -> trace.metadata();
};

return mapper.toBuilder()
.valueToReplace(extractFromJson(traceSection, mapper.jsonPath()))
.build();
})
.filter(mapper -> mapper.valueToReplace() != null)
.collect(
Collectors.toMap(MessageVariableMapping::variableName, MessageVariableMapping::valueToReplace));

// will convert all '{{key}}' into 'value'
var templateRenderer = new StringSubstitutor(replacements, "{{", "}}");

// render the message templates from evaluator rule
return evaluatorCode.messages().stream()
.map(templateMessage -> {
var renderedMessage = templateRenderer.replace(templateMessage.content());

return switch (templateMessage.role()) {
case USER -> UserMessage.from(renderedMessage);
case SYSTEM -> SystemMessage.from(renderedMessage);
default -> {
log.info("No mapping for message role type {}", templateMessage.role());
yield null;
}
};
})
.filter(Objects::nonNull)
.toList();
}

/**
* Parse evaluator\'s variable mapper into an usable list of
*
* @param evaluatorVariables a map with variables and a path into a trace input/output/metadata to replace
* @return a parsed list of mappings, easier to use for the template rendering
*/
List<MessageVariableMapping> variableMapping(Map<String, String> evaluatorVariables) {
return evaluatorVariables.entrySet().stream()
.map(mapper -> {
var templateVariable = mapper.getKey();
var tracePath = mapper.getValue();

var builder = MessageVariableMapping.builder().variableName(templateVariable);

if (tracePath.startsWith("input.")) {
builder.traceSection(TraceSection.INPUT)
.jsonPath("$" + tracePath.substring("input".length()));
} else if (tracePath.startsWith("output.")) {
builder.traceSection(TraceSection.OUTPUT)
.jsonPath("$" + tracePath.substring("output".length()));
} else if (tracePath.startsWith("metadata.")) {
builder.traceSection(TraceSection.METADATA)
.jsonPath("$" + tracePath.substring("metadata".length()));
} else {
log.info("Couldn't map trace path '{}' into a input/output/metadata path", tracePath);
return null;
}

return builder.build();
})
.filter(Objects::nonNull)
.toList();
}

enum TraceSection {
INPUT,
OUTPUT,
METADATA
}
@Builder(toBuilder = true)
record MessageVariableMapping(TraceSection traceSection, String variableName, String jsonPath,
String valueToReplace) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ interface AutomationModelEvaluatorMapper {
LlmAsJudgeAutomationRuleEvaluatorModel map(AutomationRuleEvaluatorLlmAsJudge dto);

AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode map(LlmAsJudgeAutomationRuleEvaluatorModel.LlmAsJudgeCode detail);

LlmAsJudgeAutomationRuleEvaluatorModel.LlmAsJudgeCode map(AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode code);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.comet.opik.api.AutomationRule;
import com.comet.opik.api.AutomationRuleEvaluatorCriteria;
import com.comet.opik.api.AutomationRuleEvaluatorType;
import com.comet.opik.api.AutomationRuleEvaluatorUpdate;
import com.comet.opik.infrastructure.db.JsonNodeArgumentFactory;
import com.comet.opik.infrastructure.db.UUIDArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory;
Expand Down Expand Up @@ -35,11 +34,10 @@ public interface AutomationRuleEvaluatorDAO extends AutomationRuleDAO {
@SqlUpdate("""
UPDATE automation_rule_evaluators
SET code = :rule.code,
last_updated_by = :userName
last_updated_by = :rule.lastUpdatedBy
WHERE id = :id
""")
int updateEvaluator(@Bind("id") UUID id, @BindMethods("rule") AutomationRuleEvaluatorUpdate ruleUpdate,
@Bind("userName") String userName);
<T> int updateEvaluator(@Bind("id") UUID id, @BindMethods("rule") AutomationRuleEvaluatorModel<T> rule);

@SqlQuery("""
SELECT rule.id, rule.project_id, rule.action, rule.name, rule.sampling_rate, evaluator.type, evaluator.code,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ public void update(@NonNull UUID id, @NonNull UUID projectId, @NonNull String wo
try {
int resultBase = dao.updateBaseRule(id, projectId, workspaceId, evaluatorUpdate.name(),
evaluatorUpdate.samplingRate(), userName);
int resultEval = dao.updateEvaluator(id, evaluatorUpdate, userName);

var modelUpdate = LlmAsJudgeAutomationRuleEvaluatorModel.builder()
.code(AutomationModelEvaluatorMapper.INSTANCE.map(evaluatorUpdate.code()))
.lastUpdatedBy(userName)
.build();

int resultEval = dao.updateEvaluator(id, modelUpdate);

if (resultEval == 0 || resultBase == 0) {
throw newNotFoundException();
Expand Down
Loading

0 comments on commit 830a5be

Please sign in to comment.