[ML] Inference Service Refactoring results format #102429

Merged

jonathan-buttner
Contributor

@jonathan-buttner jonathan-buttner commented Nov 21, 2023

Specifies the results format for the inference plugin's services. This should match the results proposed here: elastic/elasticsearch-specification#2329

Notable changes

  • Created a new interface, InferenceServiceResults, that the formats implement to give us more freedom (rather than using the ml plugin's InferenceResults)
  • Created a LegacyTextEmbeddingResults to represent the format of the OpenAI response when it used InferenceResults
  • Created a TextEmbeddingResults that adheres to the new InferenceServiceResults interface
  • Created a SparseEmbeddingResults that adheres to the new InferenceServiceResults interface
  • The ELSER and Hugging Face services use SparseEmbeddingResults; the OpenAI service uses TextEmbeddingResults

LegacyTextEmbeddingResults

LegacyTextEmbeddingResults is functionally identical to the new TextEmbeddingResults. The legacy one implements InferenceResults and the new one implements InferenceServiceResults. I wasn't sure whether it made sense to separate them or to have a single class, say TextEmbeddingResults, that implements both InferenceResults and InferenceServiceResults. I'm open to either approach. I figured that marking the legacy class as deprecated would make it clearer that we're moving away from InferenceResults.
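To make the two options concrete, here is a rough sketch. The interfaces below are simplified stand-ins invented for illustration (the real InferenceResults and InferenceServiceResults declare different methods), and all class and method names in the block are assumptions.

```java
import java.util.List;

// Stand-in for the ml plugin's InferenceResults; the method is hypothetical.
interface InferenceResultsSketch {
    String getResultsField();
}

// Stand-in for the new InferenceServiceResults; the method is hypothetical.
interface InferenceServiceResultsSketch {
    List<List<Float>> embeddingValues();
}

final class ResultsOptions {
    // Option A (the approach taken in this PR): two separate records,
    // with the legacy one marked deprecated.
    @Deprecated
    record LegacyTextEmbedding(List<List<Float>> embeddings) implements InferenceResultsSketch {
        @Override
        public String getResultsField() {
            return "text_embedding";
        }
    }

    record TextEmbedding(List<List<Float>> embeddings) implements InferenceServiceResultsSketch {
        @Override
        public List<List<Float>> embeddingValues() {
            return embeddings;
        }
    }

    // Option B: a single record that implements both interfaces.
    record CombinedTextEmbedding(List<List<Float>> embeddings)
        implements InferenceResultsSketch, InferenceServiceResultsSketch {

        @Override
        public String getResultsField() {
            return "text_embedding";
        }

        @Override
        public List<List<Float>> embeddingValues() {
            return embeddings;
        }
    }
}
```

With option A the deprecation is visible on the type itself; with option B the class stays tied to both interfaces for as long as either one is in use.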

Format

TextEmbeddingResults

{
  "text_embedding": [
    {
      "embedding": [
        0.1
      ]
    },
    {
      "embedding": [
        0.2
      ]
    }
  ]
}
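As an illustration of how a record could be rendered into the structure above, here is a minimal sketch that builds the equivalent nested map. The class name, the asMap helpers and the constants are assumptions made for this example; this is not the serialization code from the PR.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class TextEmbeddingFormat {
    static final String TEXT_EMBEDDING = "text_embedding";
    static final String EMBEDDING = "embedding";

    // One embedding per input text.
    record Embedding(List<Float> values) {
        Map<String, Object> asMap() {
            return Map.of(EMBEDDING, values);
        }
    }

    // Produces {"text_embedding": [{"embedding": [...]}, ...]} as nested maps and lists.
    static Map<String, Object> asMap(List<Embedding> embeddings) {
        Map<String, Object> root = new LinkedHashMap<>();
        root.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).toList());
        return root;
    }
}
```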

SparseEmbeddingResults

{
  "sparse_embedding": {
    "is_truncated": false,
    "embedding": [
      {
        "token": 0.1
      }
    ]
  }
}

@jonathan-buttner jonathan-buttner added >non-issue :ml Machine learning Team:ML Meta label for the ML team v8.12.0 labels Nov 21, 2023
@@ -183,45 +188,99 @@ public Request build() {

public static class Response extends ActionResponse implements ToXContentObject {

- private final List<? extends InferenceResults> results;
+ private final InferenceServiceResults results;
Contributor Author

The changes in this class need a thorough review 😬

* @deprecated use {@link TextEmbeddingResults} instead
*/
@Deprecated
public record LegacyTextEmbeddingResults(List<Embedding> embeddings) implements InferenceResults {
Contributor Author

I think we could fold this into the TextEmbeddingResults class by simply making it implement InferenceResults as well. I wasn't sure whether that would be less clear, though, or how it would affect adding optional fields later, or what it would mean for us if the InferenceResults interface changes in the future. I'm open to other ideas.

Member

It's good to keep this as a separate class rather than cluttering up TextEmbeddingResults

*/
@Deprecated
public record LegacyTextEmbeddingResults(List<Embedding> embeddings) implements InferenceResults {
public static final String NAME = "text_embedding_results";
Contributor Author

This is the name that was used in the openai PR here


for (InferenceResults result : results) {
if (result instanceof TextExpansionResults expansionResults) {
isTruncated |= expansionResults.isTruncated();
Contributor Author

If for some reason only 1 of the results is truncated we'll mark them all as truncated.
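A small standalone sketch of that behaviour, using a hypothetical Result record in place of TextExpansionResults: OR-ing the flags collapses them into a single value, so one truncated input marks the whole response as truncated.

```java
import java.util.List;

final class TruncationFold {
    // Hypothetical stand-in for a per-input result that carries a truncation flag.
    record Result(boolean isTruncated) {}

    // Mirrors the `isTruncated |= ...` loop: true if any single result was truncated.
    static boolean anyTruncated(List<Result> results) {
        boolean isTruncated = false;
        for (Result result : results) {
            isTruncated |= result.isTruncated();
        }
        return isTruncated;
    }
}
```

The per-input information is lost here: a caller cannot tell which of the inputs was actually truncated.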

Member

I thought the structure would be a list of objects each with is_truncated and embedding properties.

{
  "sparse_embedding": [
    {
      "is_truncated": false,
      "embedding": [
        {
          "token": 0.1
        },
        ...   
      ]
    },
    {
      "is_truncated": true,
      "embedding": [
        {
          "token": 2.0
        },
        ...   
      ]
    }
  ]
}

This matches the text_embedding result structure.
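For illustration, the list-of-objects shape could be modelled roughly as below, with is_truncated kept on each entry instead of once for the whole response. The record names and asMap helpers are assumptions for this sketch and not the classes from the PR.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class SparseEmbeddingListFormat {
    record WeightedToken(String token, float weight) {}

    // One entry per input text, each carrying its own truncation flag.
    record SparseEmbedding(List<WeightedToken> tokens, boolean isTruncated) {
        Map<String, Object> asMap() {
            List<Map<String, Object>> embedding = new ArrayList<>();
            for (WeightedToken token : tokens) {
                embedding.add(Map.<String, Object>of(token.token(), token.weight()));
            }
            Map<String, Object> map = new LinkedHashMap<>();
            map.put("is_truncated", isTruncated);
            map.put("embedding", embedding);
            return map;
        }
    }

    // Produces {"sparse_embedding": [{"is_truncated": ..., "embedding": [...]}, ...]}.
    static Map<String, Object> asMap(List<SparseEmbedding> embeddings) {
        return Map.of("sparse_embedding", embeddings.stream().map(SparseEmbedding::asMap).toList());
    }
}
```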

public static final String NAME = "text_embedding_results";
public record TextEmbeddingResults(List<Embedding> embeddings) implements InferenceServiceResults {
// TODO: what should the name be here?
public static final String NAME = "text_embedding_results_v2";
Contributor Author

@jonathan-buttner jonathan-buttner Nov 21, 2023

Thoughts on what to use for the name?

Member

Suggested change
- public static final String NAME = "text_embedding_results_v2";
+ public static final String NAME = "text_embedding_service_results";

🤷

@@ -35,7 +36,12 @@ protected Writeable.Reader<InferenceAction.Response> instanceReader() {

@Override
protected InferenceAction.Response createTestInstance() {
- return new InferenceAction.Response(List.of(TextExpansionResultsTests.createRandomResults()));
+ var result = switch (randomIntBetween(0, 1)) {
Contributor Author

Are there other tests that we should have to get coverage over that if-block in InferenceAction.Response?

Member

The if-block can be tested explicitly if you use AbstractBWCWireSerializationTestCase as the test base class.

Implement the method mutateInstanceForVersion and you can simulate mixed version transport comms and assert on the expected output.

AbstractBWCWireSerializationTestCase is in xpack core and should be accessible here
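The idea behind that kind of test can be sketched without the Elasticsearch test framework. The example below assumes the if-block is a version check on the wire format; it uses plain java.io streams, made-up integer versions and a hypothetical mutateForVersion helper, and it is not the AbstractBWCWireSerializationTestCase API.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

final class VersionGatedWire {
    static final int V_OLD = 1; // hypothetical version without the truncation flag
    static final int V_NEW = 2; // hypothetical version that added the flag

    record Response(List<Float> embedding, boolean isTruncated) {
        void writeTo(DataOutputStream out, int peerVersion) throws IOException {
            out.writeInt(embedding.size());
            for (float value : embedding) {
                out.writeFloat(value);
            }
            // Only write the newer field when the peer can read it.
            if (peerVersion >= V_NEW) {
                out.writeBoolean(isTruncated);
            }
        }

        static Response readFrom(DataInputStream in, int peerVersion) throws IOException {
            int size = in.readInt();
            List<Float> values = new ArrayList<>(size);
            for (int i = 0; i < size; i++) {
                values.add(in.readFloat());
            }
            boolean truncated = peerVersion >= V_NEW && in.readBoolean();
            return new Response(values, truncated);
        }
    }

    // Serializes and deserializes as if talking to a node on peerVersion.
    static Response roundTrip(Response original, int peerVersion) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(bytes)) {
                original.writeTo(out, peerVersion);
            }
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return Response.readFrom(in, peerVersion);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // The instance we expect back from an older peer: the unknown field falls back to its default.
    static Response mutateForVersion(Response original, int peerVersion) {
        return peerVersion >= V_NEW ? original : new Response(original.embedding(), false);
    }
}
```

A test would then assert, for each version, that roundTrip(instance, version) equals mutateForVersion(instance, version), which exercises both branches of the version check.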

@jonathan-buttner jonathan-buttner marked this pull request as ready for review November 21, 2023 19:30
@elasticsearchmachine
Collaborator

Pinging @elastic/ml-core (Team:ML)

@droberts195 droberts195 changed the title [M] Inference Service Refactoring results format [ML] Inference Service Refactoring results format Nov 21, 2023
@davidkyle
Member

@elasticmachine test this please

this(in.readCollectionAsList(Embedding::new), in.readBoolean());
}

public static SparseEmbeddingResults create(List<? extends InferenceResults> results) {
Member

Suggested change
- public static SparseEmbeddingResults create(List<? extends InferenceResults> results) {
+ public static SparseEmbeddingResults of(List<? extends InferenceResults> results) {

of is more idiomatic in Java
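For context, the JDK itself uses this naming for static factories (List.of, Optional.of, EnumSet.of). A minimal sketch of the pattern follows; the class and its field are hypothetical and unrelated to the real SparseEmbeddingResults.

```java
import java.util.List;

final class SparseResults {
    private final List<String> tokens;

    // Private constructor: callers go through the static factory.
    private SparseResults(List<String> tokens) {
        this.tokens = List.copyOf(tokens);
    }

    // Named `of`, following the JDK convention for static factories.
    static SparseResults of(List<String> tokens) {
        return new SparseResults(tokens);
    }

    List<String> tokens() {
        return tokens;
    }
}
```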


public class TestUtils {

public static String toJsonString(ToXContentFragment entity) throws IOException {
Member

Strings.toString handles ToXContentFragment and pretty printing, is there a reason not to use that?

https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/common/Strings.java#L803

Member

@davidkyle davidkyle left a comment

LGTM

Comment on lines 80 to 81
// Map<String, Object> sparseEmbeddingMap = new LinkedHashMap<>();
// sparseEmbeddingMap.put(EMBEDDING, embeddingList);
Member

Suggested change
// Map<String, Object> sparseEmbeddingMap = new LinkedHashMap<>();
// sparseEmbeddingMap.put(EMBEDDING, embeddingList);

Contributor Author

Oops thanks.

@davidkyle
Member

@elasticmachine update branch

@jonathan-buttner jonathan-buttner merged commit a022483 into elastic:main Nov 28, 2023
@jonathan-buttner jonathan-buttner deleted the ml-infer-results-format branch November 28, 2023 14:39
timgrein pushed a commit to timgrein/elasticsearch that referenced this pull request Nov 30, 2023
* Adding results

* Fixing merge issues

* Understanding the complexity

* Making progress on tests

* Tests working

* Some comments

* More comments

* Addressing pr feedback

* Fixing test

* Fixing test

* Fixing up comments and dead code

---------

Co-authored-by: Elastic Machine <[email protected]>
Labels
:ml Machine learning >non-issue Team:ML Meta label for the ML team v8.12.0

4 participants