Skip to content

Commit

Permalink
Correct bit * byte and bit * float script comparisons (#117404) (#117507
Browse files Browse the repository at this point in the history
)

I goofed on the bit * byte and bit * float comparisons. Naturally, these
should be big-endian, so that each byte or float dimension is compared with
the corresponding bit of the binary vector.

Additionally, I added a test to ensure that this is handled correctly.

(cherry picked from commit 374c88a)
  • Loading branch information
benwtrent authored Nov 25, 2024
1 parent dae91cd commit 8d16529
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 32 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117404.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117404
summary: Correct bit * byte and bit * float script comparisons
area: Vector Search
type: bug
issues: []
4 changes: 4 additions & 0 deletions docs/reference/vectors/vector-functions.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ When using `bit` vectors, not all the vector functions are available. The suppor
this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, which has `dims` number of elements, as a query vector, the `dotProduct` is
the sum of the floating point values using the stored `bit` vector as a mask.

NOTE: When comparing `floats` and `bytes` with `bit` vectors, the `bit` vector is treated as a mask in big-endian order.
For example, if the `bit` vector is `10100001` (e.g. the single byte value `161`) and it is compared
with the array of values `[1, 2, 3, 4, 5, 6, 7, 8]`, the `dotProduct` will be `1 + 3 + 8 = 12`.

Here is an example of using dot-product with bit vectors.

[source,console]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
/**
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* When comparing the bits with the bytes, they are done in "big endian" order. For example, if the byte vector
* is [1, 2, 3, 4, 5, 6, 7, 8] and the bit vector is [0b10000000], the inner product will be 1.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
Expand All @@ -63,9 +65,9 @@ public static int ipByteBit(byte[] q, byte[] d) {
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
for (int j = Byte.SIZE - 1; j >= 0; j--) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
}
}
}
Expand All @@ -75,6 +77,8 @@ public static int ipByteBit(byte[] q, byte[] d) {
/**
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* When comparing the bits with the floats, they are done in "big endian" order. For example, if the float vector
* is [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] and the bit vector is [0b10000000], the inner product will be 1.0.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
Expand All @@ -86,9 +90,9 @@ public static float ipFloatBit(float[] q, byte[] d) {
float result = 0;
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
for (int j = Byte.SIZE - 1; j >= 0; j--) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();

/**
 * Checks that {@code ipByteBit} sums the query bytes selected by the bit vector when the bits are read
 * most-significant-bit first: {@code 01100010} selects dimensions 1, 2 and 6, and {@code 10100111}
 * selects dimensions 8, 10, 13, 14 and 15.
 */
public void testIpByteBit() {
    final byte[] query = new byte[16];
    random().nextBytes(query);
    final byte[] bits = { (byte) 0b01100010, (byte) 0b10100111 };

    // Dimensions whose mask bit is set, counting bits from the most significant end of each byte.
    final int[] selected = { 1, 2, 6, 8, 10, 13, 14, 15 };
    int expectedSum = 0;
    for (int dim : selected) {
        expectedSum += query[dim];
    }

    assertEquals(expectedSum, ESVectorUtil.ipByteBit(query, bits));
}

/**
 * Checks that {@code ipFloatBit} sums the query floats selected by the bit vector when the bits are read
 * most-significant-bit first: {@code 01100010} selects dimensions 1, 2 and 6, and {@code 10100111}
 * selects dimensions 8, 10, 13, 14 and 15.
 */
public void testIpFloatBit() {
    float[] q = new float[16];
    byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
    // Fill every dimension with a random value. A bare random().nextFloat() call discards its result and
    // leaves q all zeros, which would make the expected sum 0 and the assertion pass for any bit ordering.
    for (int i = 0; i < q.length; i++) {
        q[i] = random().nextFloat();
    }
    float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
    assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
}

/**
 * Runs the shared {@code testBasicBitAndImpl} helper with {@code ESVectorUtil::andBitCountLong} as the
 * implementation under test. The helper is defined outside this view.
 */
public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ setup:
capabilities:
- method: POST
path: /_search
capabilities: [ multi_dense_vector_script_max_sim ]
capabilities: [ multi_dense_vector_script_max_sim_with_bugfix ]
test_runner_features: capabilities
reason: "Support for multi dense vector max-sim functions capability required"
- skip:
Expand Down Expand Up @@ -136,10 +136,10 @@ setup:
- match: {hits.total: 2}

- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 190, error: 0.01}}
- close_to: {hits.hits.0._score: {value: 220, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 125, error: 0.01}}
- close_to: {hits.hits.1._score: {value: 147, error: 0.01}}
---
"Test max-sim inv hamming scoring":
- skip:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ setup:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
capabilities: [ byte_float_bit_dot_product_with_bugfix ]
reason: Capability required to run test
- do:
catch: bad_request
Expand Down Expand Up @@ -399,7 +399,7 @@ setup:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
capabilities: [ byte_float_bit_dot_product_with_bugfix ]
test_runner_features: [capabilities, close_to]
reason: Capability required to run test
- do:
Expand All @@ -419,13 +419,13 @@ setup:
- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
- close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
- close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
- close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}}

- do:
headers:
Expand All @@ -444,20 +444,20 @@ setup:
- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
- close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
- close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
- close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}}
---
"Dot product with byte":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
capabilities: [ byte_float_bit_dot_product_with_bugfix ]
test_runner_features: capabilities
reason: Capability required to run test
- do:
Expand All @@ -476,14 +476,14 @@ setup:

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0._score: 415}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 168}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}
- match: {hits.hits.2._id: "2"}
- match: {hits.hits.2._score: 126}

- do:
headers:
Expand All @@ -501,11 +501,11 @@ setup:

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0._score: 415}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 168}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}
- match: {hits.hits.2._id: "2"}
- match: {hits.hits.2._score: 126}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private SearchCapabilities() {}
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
/** Support docvalue_fields parameter for `dense_vector` field. */
private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
/** Support kql query. */
Expand All @@ -39,7 +39,7 @@ private SearchCapabilities() {}
/** Support multi-dense-vector script field access. */
private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
/** Initial support for multi-dense-vector maxSim functions access. */
private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim";
private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix";

private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void testBitMultiVectorClassBindingsDotProduct() throws IOException {
function = new MaxSimDotProduct(scoreScript, floatQueryVector, fieldName);
assertEquals(
"maxSimDotProduct result is not equal to the expected value!",
0.42f + 0f + 1f - 1f - 0.42f,
-1.4f + 0.42f + 0f + 1f - 1f,
function.maxSimDotProduct(),
0.001
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ public void testBitVectorClassBindingsDotProduct() throws IOException {
function = new DotProduct(scoreScript, floatQueryVector, fieldName);
assertEquals(
"dotProduct result is not equal to the expected value!",
0.42f + 0f + 1f - 1f - 0.42f,
-1.4f + 0.42f + 0f + 1f - 1f,
function.dotProduct(),
0.001
);
Expand Down

0 comments on commit 8d16529

Please sign in to comment.