Skip to content

Commit

Permalink
[OPIK-582] Get experiment by name endpoint (#891)
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisTkachenko authored Dec 16, 2024
1 parent 4e02d4e commit 956003b
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 13 deletions.
13 changes: 13 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/api/Identifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record Identifier(@NotBlank String name) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.comet.opik.api.ExperimentsDelete;
import com.comet.opik.api.FeedbackDefinition;
import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Identifier;
import com.comet.opik.api.resources.v1.priv.validate.ExperimentParamsValidator;
import com.comet.opik.domain.ExperimentItemService;
import com.comet.opik.domain.ExperimentService;
Expand Down Expand Up @@ -159,6 +160,28 @@ public Response deleteExperimentsById(
return Response.noContent().build();
}

@POST
@Path("/retrieve")
@Operation(operationId = "getExperimentByName", summary = "Get experiment by name", description = "Get experiment by name", responses = {
@ApiResponse(responseCode = "200", description = "Experiments resource", content = @Content(schema = @Schema(implementation = Experiment.class))),
@ApiResponse(responseCode = "404", description = "Not found", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))
})
@JsonView(Experiment.View.Public.class)
public Response getExperimentByName(
@RequestBody(content = @Content(schema = @Schema(implementation = Identifier.class))) @NotNull @Valid Identifier identifier) {

String workspaceId = requestContext.get().getWorkspaceId();
String name = identifier.name();

log.info("Finding experiment by name '{}' on workspace_id '{}'", name, workspaceId);
var experiment = experimentService.getByName(name)
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();
log.info("Found experiment by name '{}' on workspace_id '{}'", name, workspaceId);

return Response.ok(experiment).build();
}

// Experiment Item Resources

@GET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.comet.opik.api.FeedbackScoreAverage;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import io.opentelemetry.instrumentation.annotations.WithSpan;
import io.r2dbc.spi.Connection;
Expand Down Expand Up @@ -98,7 +99,7 @@ LEFT JOIN (
;
""";

private static final String SELECT_BY_ID = """
private static final String SELECT_BY = """
SELECT
e.workspace_id as workspace_id,
e.dataset_id as dataset_id,
Expand Down Expand Up @@ -161,8 +162,9 @@ LEFT JOIN (
SELECT
*
FROM experiments
WHERE id = :id
AND workspace_id = :workspace_id
WHERE workspace_id = :workspace_id
<if(id)> AND id = :id <endif>
<if(name)> AND name = :name <endif>
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) AS e
Expand Down Expand Up @@ -473,18 +475,31 @@ private String getOrDefault(JsonNode jsonNode) {

@WithSpan
Mono<Experiment> getById(@NonNull UUID id) {
log.info("Getting experiment by id '{}'", id);
var template = new ST(SELECT_BY);
template.add("id", id.toString());
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> getById(id, connection))
.flatMapMany(connection -> get(template.render(), connection, statement -> statement.bind("id", id)))
.flatMap(this::mapToDto)
.singleOrEmpty();
}

private Publisher<? extends Result> getById(UUID id, Connection connection) {
log.info("Getting experiment by id '{}'", id);
var statement = connection.createStatement(SELECT_BY_ID)
.bind("id", id)
@WithSpan
Mono<Experiment> getByName(@NonNull String name) {
log.info("Getting experiment by name '{}'", name);
var template = new ST(SELECT_BY);
template.add("name", name);
return Mono.from(connectionFactory.create())
.flatMapMany(
connection -> get(template.render(), connection, statement -> statement.bind("name", name)))
.flatMap(this::mapToDto)
.singleOrEmpty();
}

private Publisher<? extends Result> get(String query, Connection connection, Function<Statement, Statement> bind) {
var statement = connection.createStatement(query)
.bind("entity_type", FeedbackScoreDAO.EntityType.TRACE.getType());
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
return makeFluxContextAware(bindWorkspaceIdToFlux(bind.apply(statement)));
}

private Publisher<Experiment> mapToDto(Result result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,18 @@ public Flux<Experiment> findByName(String name) {
@WithSpan
public Mono<Experiment> getById(@NonNull UUID id) {
log.info("Getting experiment by id '{}'", id);
return experimentDAO.getById(id)
.switchIfEmpty(Mono.defer(() -> Mono.error(newNotFoundException(id))))
return enrichExperiment(experimentDAO.getById(id), "Not found experiment with id '%s'".formatted(id));
}

@WithSpan
public Mono<Experiment> getByName(@NonNull String name) {
log.info("Getting experiment by name '{}'", name);
return enrichExperiment(experimentDAO.getByName(name), "Not found experiment with name '%s'".formatted(name));
}

private Mono<Experiment> enrichExperiment(Mono<Experiment> experimentMono, String errorMsg) {
return experimentMono
.switchIfEmpty(Mono.defer(() -> Mono.error(newNotFoundException(errorMsg))))
.flatMap(experiment -> Mono.deferContextual(ctx -> {
String workspaceId = ctx.get(RequestContext.WORKSPACE_ID);
Set<UUID> promptVersionIds = experiment.promptVersion() != null
Expand Down Expand Up @@ -283,8 +293,7 @@ private ClientErrorException newConflictException(UUID id) {
return new ClientErrorException(message, Response.Status.CONFLICT);
}

private NotFoundException newNotFoundException(UUID id) {
String message = "Not found experiment with id '%s'".formatted(id);
private NotFoundException newNotFoundException(String message) {
log.info(message);
return new NotFoundException(message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.comet.opik.api.FeedbackScoreBatch;
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Identifier;
import com.comet.opik.api.Project;
import com.comet.opik.api.Prompt;
import com.comet.opik.api.PromptVersion;
Expand Down Expand Up @@ -1659,6 +1660,50 @@ void createAndGet() {

}

@Test
void createAndGetByName() {
var expectedExperiment = generateExperiment();
createAndAssert(expectedExperiment, API_KEY, TEST_WORKSPACE);

getAndAssert(expectedExperiment.id(), expectedExperiment, TEST_WORKSPACE, API_KEY);

try (var actualResponse = client.target(getExperimentsPath())
.path("retrieve")
.request()
.header(HttpHeaders.AUTHORIZATION, API_KEY)
.header(WORKSPACE_HEADER, TEST_WORKSPACE)
.post(Entity.json(new Identifier(expectedExperiment.name())))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200);
var actualExperiment = actualResponse.readEntity(Experiment.class);
assertThat(actualExperiment.id()).isEqualTo(expectedExperiment.id());

assertThat(actualExperiment)
.usingRecursiveComparison()
.ignoringFields(EXPERIMENT_IGNORED_FIELDS)
.isEqualTo(expectedExperiment);
}
}

@Test
void getByNameNotFound() {
String name = UUID.randomUUID().toString();
var expectedError = new ErrorMessage(404, "Not found experiment with name '%s'".formatted(name));
try (var actualResponse = client.target(getExperimentsPath())
.path("retrieve")
.request()
.header(HttpHeaders.AUTHORIZATION, API_KEY)
.header(WORKSPACE_HEADER, TEST_WORKSPACE)
.post(Entity.json(new Identifier(name)))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(404);

var actualError = actualResponse.readEntity(ErrorMessage.class);

assertThat(actualError).isEqualTo(expectedError);
}
}

@Test
void createAndGetFeedbackAvg() {
var expectedExperiment = podamFactory.manufacturePojo(Experiment.class).toBuilder()
Expand Down

0 comments on commit 956003b

Please sign in to comment.