Skip to content

Commit

Permalink
Fix tests, address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko authored and Borys Tkachenko committed Jan 10, 2025
1 parent 50028b1 commit 1b455fe
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 5 deletions.
3 changes: 1 addition & 2 deletions apps/opik-backend/src/main/java/com/comet/opik/api/Span.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ public record Span(
@JsonView({Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores,
@JsonView({Span.View.Public.class,
Span.View.Write.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @DecimalMin("0.0") BigDecimal totalEstimatedCost,
@JsonView({Span.View.Public.class, Span.View.Write.class}) @DecimalMin("0.0") BigDecimal totalEstimatedCost,
String totalEstimatedCostVersion,
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY, description = "Duration in milliseconds as a decimal number to support sub-millisecond precision") Double duration){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ private void bindUpdateParams(SpanUpdate spanUpdate, Statement statement, boolea
Optional.ofNullable(spanUpdate.errorInfo())
.ifPresent(errorInfo -> statement.bind("error_info", JsonUtils.readTree(errorInfo).toString()));

if (Objects.nonNull(spanUpdate.totalEstimatedCost())) {
if (spanUpdate.totalEstimatedCost() != null) {
// Update with new manually set cost
statement.bind("total_estimated_cost", spanUpdate.totalEstimatedCost().toString());
statement.bind("total_estimated_cost_version", "");
Expand Down Expand Up @@ -1032,9 +1032,11 @@ private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql, boolean isManual
.ifPresent(usage -> template.add("usage", usage.toString()));
Optional.ofNullable(spanUpdate.errorInfo())
.ifPresent(errorInfo -> template.add("error_info", JsonUtils.readTree(errorInfo).toString()));

// If we have manual cost in update OR if we can calculate it and user didn't set manual cost before
if ((!isManualCostExist && StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage()))
|| Objects.nonNull(spanUpdate.totalEstimatedCost())) {
boolean shouldRecalculateEstimatedCost = !isManualCostExist && StringUtils.isNotBlank(spanUpdate.model())
&& spanUpdate.usage() != null;
if (spanUpdate.totalEstimatedCost() != null || shouldRecalculateEstimatedCost) {
template.add("total_estimated_cost", "total_estimated_cost");
template.add("total_estimated_cost_version", "total_estimated_cost_version");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,7 @@ private BigDecimal createSpans(
"prompt_tokens", RANDOM.nextInt(),
"completion_tokens", RANDOM.nextInt()))
.traceId(trace.id())
.totalEstimatedCost(null)
.build())
.toList();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,7 @@ private Project buildProjectStats(Project project, String apiKey, String workspa
.model(spanResourceClient.randomModelPrice().getName())
.traceId(trace.id())
.projectName(trace.projectName())
.totalEstimatedCost(null)
.build())
.toList();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3509,6 +3509,7 @@ void getTraceWithCost(String model) {
.usage(Map.of("completion_tokens", Math.abs(factory.manufacturePojo(Integer.class)),
"prompt_tokens", Math.abs(factory.manufacturePojo(Integer.class))))
.model(model)
.totalEstimatedCost(null)
.build())
.collect(Collectors.toList());

Expand Down Expand Up @@ -5237,6 +5238,7 @@ void findWithUsage() {
.map(span -> span.toBuilder()
.projectName(projectName)
.traceId(trace.id())
.totalEstimatedCost(null)
.build()))
.collect(Collectors.groupingBy(Span::traceId));
batchCreateSpansAndAssert(
Expand Down Expand Up @@ -5281,6 +5283,7 @@ void findWithoutUsage() {
.traceId(trace.id())
.startTime(trace.startTime())
.usage(null)
.totalEstimatedCost(null)
.build()))
.toList();
batchCreateSpansAndAssert(spans, apiKey, workspaceName);
Expand Down Expand Up @@ -5402,6 +5405,7 @@ void getTraceStats__whenTracesHaveCostEstimation__thenReturnTotalCostEstimation(
.traceId(trace.id())
.projectName(projectName)
.feedbackScores(null)
.totalEstimatedCost(null)
.build())
.toList();

Expand Down Expand Up @@ -6967,6 +6971,7 @@ void getTraceStats__whenFilterUsageEqual__thenReturnTracesFiltered(String usageK
.projectName(projectName)
.traceId(trace.id())
.usage(Map.of(usageKey, otherUsageValue))
.totalEstimatedCost(null)
.build())
.collect(Collectors.toMap(Span::traceId, Function.identity()));
traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder()
Expand Down Expand Up @@ -7022,6 +7027,7 @@ void getTraceStats__whenFilterUsageGreaterThan__thenReturnTracesFiltered(String
.projectName(projectName)
.traceId(trace.id())
.usage(Map.of(usageKey, 123))
.totalEstimatedCost(null)
.build())
.collect(Collectors.toMap(Span::traceId, Function.identity()));
traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder()
Expand Down Expand Up @@ -7072,6 +7078,7 @@ void getTraceStats__whenFilterUsageGreaterThanEqual__thenReturnTracesFiltered(St
.projectName(projectName)
.traceId(trace.id())
.usage(Map.of(usageKey, 123))
.totalEstimatedCost(null)
.build())
.collect(Collectors.toMap(Span::traceId, Function.identity()));
traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder()
Expand Down Expand Up @@ -7122,6 +7129,7 @@ void getTraceStats__whenFilterUsageLessThan__thenReturnTracesFiltered(String usa
.projectName(projectName)
.traceId(trace.id())
.usage(Map.of(usageKey, 456))
.totalEstimatedCost(null)
.build())
.collect(Collectors.toMap(Span::traceId, Function.identity()));
traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder()
Expand Down Expand Up @@ -7172,6 +7180,7 @@ void getTraceStats__whenFilterUsageLessThanEqual__thenReturnTracesFiltered(Strin
.projectName(projectName)
.traceId(trace.id())
.usage(Map.of(usageKey, 456))
.totalEstimatedCost(null)
.build())
.collect(Collectors.toMap(Span::traceId, Function.identity()));
traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder()
Expand Down

0 comments on commit 1b455fe

Please sign in to comment.