Add support for highlighting the new format of the semantic text field (elastic#119604)

This change adapts the semantic highlighter to work with the new format introduced in elastic#119183.

Co-authored-by: Kathleen DeRusso <[email protected]>
jimczi and kderusso committed Jan 7, 2025
1 parent d4d5b1c commit b7ec719
Showing 8 changed files with 377 additions and 4,366 deletions.
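
For orientation before the per-file diffs: in the legacy semantic_text format each chunk's text is stored verbatim in the nested chunk source, while the new format only records character offsets into the original source field, and the highlighter reloads that field and slices the snippet out. The sketch below illustrates that difference with hypothetical inputs; the method and key names are simplified stand-ins, not the mapper's actual API.

// Sketch (not part of this commit): how a chunk's text is resolved in each format.
import java.util.List;
import java.util.Map;

class ChunkResolutionSketch {
    // Legacy format: the chunk text was stored verbatim in the nested chunk source,
    // addressed by chunk index ("text" stands in for SemanticTextField.CHUNKED_TEXT_FIELD).
    static String legacyChunkText(List<Map<String, Object>> nestedChunks, int chunkIndex) {
        return (String) nestedChunks.get(chunkIndex).get("text");
    }

    // New format: a chunk only carries start/end offsets into the original source field,
    // so the highlighter reloads that field's content and extracts the snippet by substring.
    static String newFormatChunkText(String fieldContent, int start, int end) {
        return fieldContent.substring(start, end);
    }

    public static void main(String[] args) {
        String body = "Elasticsearch is a distributed search engine. It also supports semantic search.";
        int start = body.indexOf("It also");
        System.out.println(newFormatChunkText(body, start, body.length()));

        Map<String, Object> chunk = Map.of("text", "It also supports semantic search.");
        System.out.println(legacyChunkText(List.of(chunk), 0));
    }
}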
@@ -227,15 +227,15 @@ protected static BreakIterator getBreakIterator(SearchHighlightContext.Field fie
}
}

protected static String convertFieldValue(MappedFieldType type, Object value) {
public static String convertFieldValue(MappedFieldType type, Object value) {
if (value instanceof BytesRef) {
return type.valueForDisplay(value).toString();
} else {
return value.toString();
}
}

protected static String mergeFieldValues(List<Object> fieldValues, char valuesSeparator) {
public static String mergeFieldValues(List<Object> fieldValues, char valuesSeparator) {
// postings highlighter accepts all values in a single string, as offsets etc. need to match with content
// loaded from stored fields, we merge all values using a proper separator
String rawValue = Strings.collectionToDelimitedString(fieldValues, String.valueOf(valuesSeparator));
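
Context for why convertFieldValue and mergeFieldValues are widened from protected to public: the semantic highlighter's new-format path reuses them to rebuild the full field content, joining multi-valued fields with a separator so the recorded character offsets still line up with the loaded text. A minimal sketch of that idea; the separator constant is a hypothetical stand-in for CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR, not necessarily its real value.

// Sketch (not part of this commit): merging multi-valued field content so offsets stay valid.
import java.util.List;
import java.util.stream.Collectors;

class MergeFieldValuesSketch {
    // Hypothetical stand-in for CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR.
    static final char SEPARATOR = '\u0000';

    // Join all values of a multi-valued field into one string, mirroring mergeFieldValues.
    static String merge(List<String> values) {
        return values.stream().collect(Collectors.joining(String.valueOf(SEPARATOR)));
    }

    public static void main(String[] args) {
        String merged = merge(List.of("first paragraph", "second paragraph"));
        // An offset pair recorded against this merged representation can be applied directly:
        int start = merged.indexOf("second");
        System.out.println(merged.substring(start, start + "second paragraph".length()));
    }
}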
@@ -26,20 +26,31 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.fetch.subphase.highlight.DefaultHighlighter;
import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightUtils;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceField;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SemanticTextFieldType;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;

/**
* A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.
@@ -49,20 +60,19 @@
public class SemanticTextHighlighter implements Highlighter {
public static final String NAME = "semantic";

private record OffsetAndScore(int offset, float score) {}
private record OffsetAndScore(int index, OffsetSourceFieldMapper.OffsetSource offset, float score) {}

@Override
public boolean canHighlight(MappedFieldType fieldType) {
if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
// TODO: Implement highlighting when using inference metadata fields
return semanticTextFieldType.useLegacyFormat();
}
return false;
return fieldType instanceof SemanticTextFieldType;
}

@Override
public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;
if (canHighlight(fieldContext.fieldType) == false) {
return null;
}
SemanticTextFieldType fieldType = (SemanticTextFieldType) fieldContext.fieldType;
if (fieldType.getEmbeddingsField() == null) {
// nothing indexed yet
return null;
@@ -105,28 +115,36 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOExc
int size = Math.min(chunks.size(), numberOfFragments);
if (fieldContext.field.fieldOptions().scoreOrdered() == false) {
chunks = chunks.subList(0, size);
chunks.sort(Comparator.comparingInt(c -> c.offset));
chunks.sort(Comparator.comparingInt(c -> c.index));
}
Text[] snippets = new Text[size];
List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
fieldType.getChunksField().fullPath(),
fieldContext.hitContext.source().source()
);
final Function<OffsetAndScore, String> offsetToContent;
if (fieldType.useLegacyFormat()) {
List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
fieldType.getChunksField().fullPath(),
fieldContext.hitContext.source().source()
);
offsetToContent = entry -> getContentFromLegacyNestedSources(fieldType.name(), entry, nestedSources);
} else {
Map<String, String> fieldToContent = new HashMap<>();
offsetToContent = entry -> {
String content = fieldToContent.computeIfAbsent(entry.offset().field(), key -> {
try {
return extractFieldContent(
fieldContext.context.getSearchExecutionContext(),
fieldContext.hitContext,
entry.offset.field()
);
} catch (IOException e) {
throw new UncheckedIOException("Error extracting field content from field " + entry.offset.field(), e);
}
});
return content.substring(entry.offset().start(), entry.offset().end());
};
}
for (int i = 0; i < size; i++) {
var chunk = chunks.get(i);
if (nestedSources.size() <= chunk.offset) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"Invalid content detected for field [%s]: the chunks size is [%d], "
+ "but a reference to offset [%d] was found in the result.",
fieldType.name(),
nestedSources.size(),
chunk.offset
)
);
}
String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);
String content = offsetToContent.apply(chunk);
if (content == null) {
throw new IllegalStateException(
String.format(
@@ -143,10 +161,43 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOExc
return new HighlightField(fieldContext.fieldName, snippets);
}

private String extractFieldContent(SearchExecutionContext searchContext, FetchSubPhase.HitContext hitContext, String sourceField)
throws IOException {
var sourceFieldType = searchContext.getMappingLookup().getFieldType(sourceField);
if (sourceFieldType == null) {
return null;
}

var values = HighlightUtils.loadFieldValues(sourceFieldType, searchContext, hitContext)
.stream()
.<Object>map((s) -> DefaultHighlighter.convertFieldValue(sourceFieldType, s))
.toList();
if (values.size() == 0) {
return null;
}
return DefaultHighlighter.mergeFieldValues(values, MULTIVAL_SEP_CHAR);
}

private String getContentFromLegacyNestedSources(String fieldName, OffsetAndScore cand, List<Map<?, ?>> nestedSources) {
if (nestedSources.size() <= cand.index) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"Invalid content detected for field [%s]: the chunks size is [%d], "
+ "but a reference to offset [%d] was found in the result.",
fieldName,
nestedSources.size(),
cand.index
)
);
}
return (String) nestedSources.get(cand.index).get(SemanticTextField.CHUNKED_TEXT_FIELD);
}

private List<OffsetAndScore> extractOffsetAndScores(
SearchExecutionContext context,
LeafReader reader,
SemanticTextFieldMapper.SemanticTextFieldType fieldType,
SemanticTextFieldType fieldType,
int docId,
List<Query> leafQueries
) throws IOException {
@@ -164,10 +215,31 @@ private List<OffsetAndScore> extractOffsetAndScores(
} else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
return List.of();
}

OffsetSourceField.OffsetSourceLoader offsetReader = null;
if (fieldType.useLegacyFormat() == false) {
var terms = reader.terms(fieldType.getOffsetsField().fullPath());
if (terms == null) {
// The field is empty
return List.of();
}
offsetReader = OffsetSourceField.loader(terms);
}

List<OffsetAndScore> results = new ArrayList<>();
int offset = 0;
int index = 0;
while (scorer.docID() < docId) {
results.add(new OffsetAndScore(offset++, scorer.score()));
if (offsetReader != null) {
var offset = offsetReader.advanceTo(scorer.docID());
if (offset == null) {
throw new IllegalStateException(
"Cannot highlight field [" + fieldType.name() + "], missing offsets for doc [" + docId + "]"
);
}
results.add(new OffsetAndScore(index++, offset, scorer.score()));
} else {
results.add(new OffsetAndScore(index++, null, scorer.score()));
}
if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
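
One detail of the highlight method above worth calling out: candidate chunks arrive ranked by score, and when the request does not ask for score-ordered fragments, the top-scoring chunks are kept but re-sorted into document order before snippets are built. A rough sketch of that selection step, using a simplified record rather than the highlighter's own types.

// Sketch (not part of this commit): fragment selection with and without score ordering.
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class FragmentSelectionSketch {
    // Simplified stand-in for the highlighter's OffsetAndScore record.
    record OffsetAndScore(int index, float score) {}

    static List<OffsetAndScore> select(List<OffsetAndScore> chunksByScore, int numberOfFragments, boolean scoreOrdered) {
        List<OffsetAndScore> chunks = new ArrayList<>(chunksByScore);
        int size = Math.min(chunks.size(), numberOfFragments);
        if (scoreOrdered == false) {
            // Keep the best `size` chunks but emit them in their original document order.
            chunks = new ArrayList<>(chunks.subList(0, size));
            chunks.sort(Comparator.comparingInt(OffsetAndScore::index));
        }
        return chunks.subList(0, size);
    }

    public static void main(String[] args) {
        var byScore = List.of(new OffsetAndScore(2, 0.9f), new OffsetAndScore(0, 0.7f), new OffsetAndScore(1, 0.4f));
        // Prints the chunks at indices 0 and 2, in document order.
        System.out.println(select(byScore, 2, false));
    }
}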
@@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.highlight;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriterConfig;
@@ -51,7 +53,6 @@
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.junit.Before;
import org.mockito.Mockito;

import java.io.IOException;
@@ -71,31 +72,35 @@ public class SemanticTextHighlighterTests extends MapperServiceTestCase {
private static final String SEMANTIC_FIELD_E5 = "body-e5";
private static final String SEMANTIC_FIELD_ELSER = "body-elser";

private Map<String, Object> queries;
private final boolean useLegacyFormat;
private final Map<String, Object> queries;

@Override
protected Collection<? extends Plugin> getPlugins() {
return List.of(new InferencePlugin(Settings.EMPTY));
public SemanticTextHighlighterTests(boolean useLegacyFormat) throws IOException {
this.useLegacyFormat = useLegacyFormat;
var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json"));
this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2();
}

@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
return List.of(new Object[] { true }, new Object[] { false });
}

@Override
@Before
public void setUp() throws Exception {
super.setUp();
var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json"));
this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2();
protected Collection<? extends Plugin> getPlugins() {
return List.of(new InferencePlugin(Settings.EMPTY));
}

@SuppressWarnings("unchecked")
public void testDenseVector() throws Exception {
var mapperService = createDefaultMapperService();
var mapperService = createDefaultMapperService(useLegacyFormat);
Map<String, Object> queryMap = (Map<String, Object>) queries.get("dense_vector_1");
float[] vector = readDenseVector(queryMap.get("embeddings"));
var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5);
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null);
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max);
var shardRequest = createShardSearchRequest(nestedQueryBuilder);
var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON);
var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON);

String[] expectedScorePassages = ((List<String>) queryMap.get("expected_by_score")).toArray(String[]::new);
for (int i = 0; i < expectedScorePassages.length; i++) {
@@ -124,7 +129,7 @@

@SuppressWarnings("unchecked")
public void testSparseVector() throws Exception {
var mapperService = createDefaultMapperService();
var mapperService = createDefaultMapperService(useLegacyFormat);
Map<String, Object> queryMap = (Map<String, Object>) queries.get("sparse_vector_1");
List<WeightedToken> tokens = readSparseVector(queryMap.get("embeddings"));
var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_ELSER);
@@ -138,7 +143,7 @@
);
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), sparseQuery, ScoreMode.Max);
var shardRequest = createShardSearchRequest(nestedQueryBuilder);
var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON);
var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON);

String[] expectedScorePassages = ((List<String>) queryMap.get("expected_by_score")).toArray(String[]::new);
for (int i = 0; i < expectedScorePassages.length; i++) {
@@ -165,9 +170,11 @@
);
}

private MapperService createDefaultMapperService() throws IOException {
private MapperService createDefaultMapperService(boolean useLegacyFormat) throws IOException {
var mappings = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("mappings.json"));
var settings = Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), true).build();
var settings = Settings.builder()
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
.build();
return createMapperService(settings, mappings.utf8ToString());
}

@@ -282,7 +289,8 @@ private ShardSearchRequest createShardSearchRequest(QueryBuilder queryBuilder) {
return new ShardSearchRequest(OriginalIndices.NONE, request, new ShardId("index", "index", 0), 0, 1, AliasFilter.EMPTY, 1, 0, null);
}

private BytesReference readSampleDoc(String fileName) throws IOException {
private BytesReference readSampleDoc(boolean useLegacyFormat) throws IOException {
String fileName = useLegacyFormat ? "sample-doc-legacy.json.gz" : "sample-doc.json.gz";
try (var in = new GZIPInputStream(SemanticTextHighlighterTests.class.getResourceAsStream(fileName))) {
return new BytesArray(new BytesRef(in.readAllBytes()));
}
Expand Down
Binary file not shown.