|
@@ -44,11 +44,14 @@ public class DenseVectorFunctionTests extends ESTestCase {
|
|
|
public void testVectorFunctions() {
|
|
|
for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
|
|
|
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
|
|
|
+ float magnitude = VectorEncoderDecoder.getMagnitude(indexVersion, encodedDocVector);
|
|
|
+
|
|
|
DenseVectorScriptDocValues docValues = mock(DenseVectorScriptDocValues.class);
|
|
|
when(docValues.getEncodedValue()).thenReturn(encodedDocVector);
|
|
|
+ when(docValues.getMagnitude()).thenReturn(magnitude);
|
|
|
+ when(docValues.dims()).thenReturn(docVector.length);
|
|
|
|
|
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
|
|
- when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
|
|
|
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));
|
|
|
|
|
|
testDotProduct(scoreScript);
|
|
@@ -63,8 +66,8 @@ public class DenseVectorFunctionTests extends ESTestCase {
|
|
|
double result = function.dotProduct();
|
|
|
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
|
|
|
|
|
|
- DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field);
|
|
|
- IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct);
|
|
|
+ IllegalArgumentException e =
|
|
|
+ expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, invalidQueryVector, field));
|
|
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
|
|
}
|
|
|
|
|
@@ -73,8 +76,8 @@ public class DenseVectorFunctionTests extends ESTestCase {
|
|
|
double result = function.cosineSimilarity();
|
|
|
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result, 0.001);
|
|
|
|
|
|
- CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field);
|
|
|
- IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity);
|
|
|
+ IllegalArgumentException e =
|
|
|
+ expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, invalidQueryVector, field));
|
|
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
|
|
}
|
|
|
|
|
@@ -83,8 +86,8 @@ public class DenseVectorFunctionTests extends ESTestCase {
|
|
|
double result = function.l1norm();
|
|
|
assertEquals("l1norm result is not equal to the expected value!", 485.184, result, 0.001);
|
|
|
|
|
|
- L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field);
|
|
|
- IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm);
|
|
|
+ IllegalArgumentException e =
|
|
|
+ expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, invalidQueryVector, field));
|
|
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
|
|
}
|
|
|
|
|
@@ -93,12 +96,11 @@ public class DenseVectorFunctionTests extends ESTestCase {
|
|
|
double result = function.l2norm();
|
|
|
assertEquals("l2norm result is not equal to the expected value!", 301.361, result, 0.001);
|
|
|
|
|
|
- L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field);
|
|
|
- IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm);
|
|
|
+ IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, invalidQueryVector, field));
|
|
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
|
|
}
|
|
|
|
|
|
- private static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
|
|
|
+ static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
|
|
|
byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0)
|
|
|
? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]
|
|
|
: new byte[VectorEncoderDecoder.INT_BYTES * values.length];
|