|
@@ -29,7 +29,6 @@ import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.Arrays;
|
|
|
import java.util.HexFormat;
|
|
|
import java.util.List;
|
|
|
|
|
@@ -43,8 +42,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
String fieldName = "vector";
|
|
|
int dims = 5;
|
|
|
float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f };
|
|
|
- List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
|
|
- List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
|
|
|
+ List<Number> queryVector = List.of(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
|
|
+ List<Number> invalidQueryVector = List.of(0.5, 111.3);
|
|
|
|
|
|
List<DenseVectorDocValuesField> fields = List.of(
|
|
|
new BinaryDenseVectorDocValuesField(
|
|
@@ -141,8 +140,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
String fieldName = "vector";
|
|
|
int dims = 5;
|
|
|
float[] docVector = new float[] { 1, 127, -128, 5, -10 };
|
|
|
- List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
|
|
- List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
|
|
|
+ List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
|
|
+ List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
|
|
|
String hexidecimalString = HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 });
|
|
|
|
|
|
List<DenseVectorDocValuesField> fields = List.of(
|
|
@@ -183,11 +182,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
for (int i = 0; i < queryVectorArray.length; i++) {
|
|
|
queryVectorArray[i] = queryVector.get(i).floatValue();
|
|
|
}
|
|
|
- UnsupportedOperationException uoe = expectThrows(
|
|
|
- UnsupportedOperationException.class,
|
|
|
- () -> field.getInternal().cosineSimilarity(queryVectorArray, true)
|
|
|
+ assertEquals(
|
|
|
+ "cosineSimilarity result is not equal to the expected value!",
|
|
|
+ cosineSimilarityExpected,
|
|
|
+ field.getInternal().cosineSimilarity(queryVectorArray, true),
|
|
|
+ 0.001
|
|
|
);
|
|
|
- assertThat(uoe.getMessage(), containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead"));
|
|
|
|
|
|
// Check each function rejects query vectors with the wrong dimension
|
|
|
IllegalArgumentException e = expectThrows(
|
|
@@ -240,9 +240,9 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
int dims = 8;
|
|
|
float[] docVector = new float[] { 124 };
|
|
|
// 124 in binary is b01111100
|
|
|
- List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
|
|
|
- List<Number> floatQueryVector = Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
|
|
|
- List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
|
|
|
+ List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
|
|
|
+ List<Number> floatQueryVector = List.of(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
|
|
|
+ List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
|
|
|
String hexidecimalString = HexFormat.of().formatHex(new byte[] { 124 });
|
|
|
|
|
|
List<DenseVectorDocValuesField> fields = List.of(
|
|
@@ -293,8 +293,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
public void testByteVsFloatSimilarity() throws IOException {
|
|
|
int dims = 5;
|
|
|
float[] docVector = new float[] { 1f, 127f, -128f, 5f, -10f };
|
|
|
- List<Number> listFloatVector = Arrays.asList(1f, 125f, -12f, 2f, 4f);
|
|
|
- List<Number> listByteVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
|
|
+ List<Number> listFloatVector = List.of(1f, 125f, -12f, 2f, 4f);
|
|
|
+ List<Number> listByteVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
|
|
float[] floatVector = new float[] { 1f, 125f, -12f, 2f, 4f };
|
|
|
byte[] byteVector = new byte[] { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 };
|
|
|
|
|
@@ -342,11 +342,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
switch (field.getElementType()) {
|
|
|
case BYTE -> {
|
|
|
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(byteVector));
|
|
|
- UnsupportedOperationException e = expectThrows(
|
|
|
- UnsupportedOperationException.class,
|
|
|
- () -> field.get().dotProduct(floatVector)
|
|
|
- );
|
|
|
- assertThat(e.getMessage(), containsString("use [int dotProduct(byte[] queryVector)] instead"));
|
|
|
+ assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
|
|
|
}
|
|
|
case FLOAT -> {
|
|
|
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
|
|
@@ -423,14 +419,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
switch (field.getElementType()) {
|
|
|
case BYTE -> {
|
|
|
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(byteVector), 0.001);
|
|
|
- UnsupportedOperationException e = expectThrows(
|
|
|
- UnsupportedOperationException.class,
|
|
|
- () -> field.get().cosineSimilarity(floatVector)
|
|
|
- );
|
|
|
- assertThat(
|
|
|
- e.getMessage(),
|
|
|
- containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead")
|
|
|
- );
|
|
|
+ assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
|
|
|
}
|
|
|
case FLOAT -> {
|
|
|
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
|
|
@@ -471,81 +460,55 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
|
|
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
|
|
|
|
|
|
- IllegalArgumentException e;
|
|
|
-
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, greaterThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [128.0]"
|
|
|
- );
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, greaterThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [128.0]"
|
|
|
- );
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, greaterThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [128.0]"
|
|
|
+ expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ containsString(
|
|
|
+ "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
|
|
+ + "Preview of invalid vector: [128.0]"
|
|
|
+ ),
|
|
|
+ () -> new L1Norm(scoreScript, greaterThanVector, fieldName)
|
|
|
);
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, greaterThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [128.0]"
|
|
|
+ expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ containsString(
|
|
|
+ "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
|
|
+ + "Preview of invalid vector: [128.0]"
|
|
|
+ ),
|
|
|
+ () -> new L2Norm(scoreScript, greaterThanVector, fieldName)
|
|
|
);
|
|
|
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, lessThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [-129.0]"
|
|
|
- );
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, lessThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [-129.0]"
|
|
|
- );
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, lessThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [-129.0]"
|
|
|
+ expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ containsString(
|
|
|
+ "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
|
|
+ + "Preview of invalid vector: [-129.0]"
|
|
|
+ ),
|
|
|
+ () -> new L1Norm(scoreScript, lessThanVector, fieldName)
|
|
|
);
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, lessThanVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [-129.0]"
|
|
|
+ expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ containsString(
|
|
|
+ "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
|
|
+ + "Preview of invalid vector: [-129.0]"
|
|
|
+ ),
|
|
|
+ () -> new L2Norm(scoreScript, lessThanVector, fieldName)
|
|
|
);
|
|
|
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, decimalVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [0.5]"
|
|
|
- );
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, decimalVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [0.5]"
|
|
|
- );
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, decimalVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [0.5]"
|
|
|
+ expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ containsString(
|
|
|
+ "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
|
|
+ + "Preview of invalid vector: [0.5]"
|
|
|
+ ),
|
|
|
+ () -> new L1Norm(scoreScript, decimalVector, fieldName)
|
|
|
);
|
|
|
- e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, decimalVector, fieldName));
|
|
|
- assertEquals(
|
|
|
- e.getMessage(),
|
|
|
- "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
|
|
- + "Preview of invalid vector: [0.5]"
|
|
|
+ expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ containsString(
|
|
|
+ "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
|
|
+ + "Preview of invalid vector: [0.5]"
|
|
|
+ ),
|
|
|
+ () -> new L2Norm(scoreScript, decimalVector, fieldName)
|
|
|
);
|
|
|
}
|
|
|
}
|