diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Span.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Span.java index 0c01d03d06..d1bfd1b361 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Span.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Span.java @@ -52,8 +52,7 @@ public record Span( @JsonView({Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, @JsonView({ Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List feedbackScores, - @JsonView({Span.View.Public.class, - Span.View.Write.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @DecimalMin("0.0") BigDecimal totalEstimatedCost, + @JsonView({Span.View.Public.class, Span.View.Write.class}) @DecimalMin("0.0") BigDecimal totalEstimatedCost, String totalEstimatedCostVersion, @JsonView({ Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY, description = "Duration in milliseconds as a decimal number to support sub-millisecond precision") Double duration){ diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java index 6336f80f95..8930fe5a0d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java @@ -997,7 +997,7 @@ private void bindUpdateParams(SpanUpdate spanUpdate, Statement statement, boolea Optional.ofNullable(spanUpdate.errorInfo()) .ifPresent(errorInfo -> statement.bind("error_info", JsonUtils.readTree(errorInfo).toString())); - if (Objects.nonNull(spanUpdate.totalEstimatedCost())) { + if (spanUpdate.totalEstimatedCost() != null) { // Update with new manually set cost statement.bind("total_estimated_cost", spanUpdate.totalEstimatedCost().toString()); statement.bind("total_estimated_cost_version", ""); @@ -1032,9 +1032,11 @@ private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql, boolean isManual .ifPresent(usage -> template.add("usage", usage.toString())); Optional.ofNullable(spanUpdate.errorInfo()) .ifPresent(errorInfo -> template.add("error_info", JsonUtils.readTree(errorInfo).toString())); + // If we have manual cost in update OR if we can calculate it and user didn't set manual cost before - if ((!isManualCostExist && StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage())) - || Objects.nonNull(spanUpdate.totalEstimatedCost())) { + boolean shouldRecalculateEstimatedCost = !isManualCostExist && StringUtils.isNotBlank(spanUpdate.model()) + && spanUpdate.usage() != null; + if (spanUpdate.totalEstimatedCost() != null || shouldRecalculateEstimatedCost) { template.add("total_estimated_cost", "total_estimated_cost"); template.add("total_estimated_cost_version", "total_estimated_cost_version"); } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java index 9ccb370369..0ffaa1936e 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java @@ -708,6 +708,7 @@ private BigDecimal createSpans( "prompt_tokens", RANDOM.nextInt(), "completion_tokens", RANDOM.nextInt())) .traceId(trace.id()) + .totalEstimatedCost(null) .build()) .toList(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index 6317b2b556..5d48a4dc88 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -1313,6 +1313,7 @@ private Project buildProjectStats(Project project, String apiKey, String workspa .model(spanResourceClient.randomModelPrice().getName()) .traceId(trace.id()) .projectName(trace.projectName()) + .totalEstimatedCost(null) .build()) .toList(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 7f26f22b6a..dcd7434cbe 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -3509,6 +3509,7 @@ void getTraceWithCost(String model) { .usage(Map.of("completion_tokens", Math.abs(factory.manufacturePojo(Integer.class)), "prompt_tokens", Math.abs(factory.manufacturePojo(Integer.class)))) .model(model) + .totalEstimatedCost(null) .build()) .collect(Collectors.toList()); @@ -5237,6 +5238,7 @@ void findWithUsage() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .totalEstimatedCost(null) .build())) .collect(Collectors.groupingBy(Span::traceId)); batchCreateSpansAndAssert( @@ -5281,6 +5283,7 @@ void findWithoutUsage() { .traceId(trace.id()) .startTime(trace.startTime()) .usage(null) + .totalEstimatedCost(null) .build())) .toList(); batchCreateSpansAndAssert(spans, apiKey, workspaceName); @@ -5402,6 +5405,7 @@ void getTraceStats__whenTracesHaveCostEstimation__thenReturnTotalCostEstimation( .traceId(trace.id()) .projectName(projectName) .feedbackScores(null) + .totalEstimatedCost(null) .build()) .toList(); @@ -6967,6 +6971,7 @@ void getTraceStats__whenFilterUsageEqual__thenReturnTracesFiltered(String usageK .projectName(projectName) .traceId(trace.id()) .usage(Map.of(usageKey, otherUsageValue)) + .totalEstimatedCost(null) .build()) .collect(Collectors.toMap(Span::traceId, Function.identity())); traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder() @@ -7022,6 +7027,7 @@ void getTraceStats__whenFilterUsageGreaterThan__thenReturnTracesFiltered(String .projectName(projectName) .traceId(trace.id()) .usage(Map.of(usageKey, 123)) + .totalEstimatedCost(null) .build()) .collect(Collectors.toMap(Span::traceId, Function.identity())); traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder() @@ -7072,6 +7078,7 @@ void getTraceStats__whenFilterUsageGreaterThanEqual__thenReturnTracesFiltered(St .projectName(projectName) .traceId(trace.id()) .usage(Map.of(usageKey, 123)) + .totalEstimatedCost(null) .build()) .collect(Collectors.toMap(Span::traceId, Function.identity())); traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder() @@ -7122,6 +7129,7 @@ void getTraceStats__whenFilterUsageLessThan__thenReturnTracesFiltered(String usa .projectName(projectName) .traceId(trace.id()) .usage(Map.of(usageKey, 456)) + .totalEstimatedCost(null) .build()) .collect(Collectors.toMap(Span::traceId, Function.identity())); traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder() @@ -7172,6 +7180,7 @@ void getTraceStats__whenFilterUsageLessThanEqual__thenReturnTracesFiltered(Strin .projectName(projectName) .traceId(trace.id()) .usage(Map.of(usageKey, 456)) + .totalEstimatedCost(null) .build()) .collect(Collectors.toMap(Span::traceId, Function.identity())); traceIdToSpanMap.put(traces.getFirst().id(), traceIdToSpanMap.get(traces.getFirst().id()).toBuilder()