|
@@ -19,7 +19,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
|
|
|
static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
|
|
|
|
|
|
- static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) {
|
|
|
+ static float[] deQuantize(int[] quantized, byte bits, float[] interval, float[] centroid) {
|
|
|
float[] dequantized = new float[quantized.length];
|
|
|
float a = interval[0];
|
|
|
float b = interval[1];
|
|
@@ -52,10 +52,12 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
float[] scratch = new float[dims];
|
|
|
for (byte bit : ALL_BITS) {
|
|
|
float eps = (1f / (float) (1 << (bit)));
|
|
|
- byte[] destination = new byte[dims];
|
|
|
+ byte[] legacyDestination = new byte[dims];
|
|
|
+ int[] destination = new int[dims];
|
|
|
for (int i = 0; i < numVectors; ++i) {
|
|
|
System.arraycopy(vectors[i], 0, scratch, 0, dims);
|
|
|
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid);
|
|
|
+
|
|
|
assertValidResults(result);
|
|
|
assertValidQuantizedRange(destination, bit);
|
|
|
|
|
@@ -71,6 +73,19 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
}
|
|
|
mae /= dims;
|
|
|
assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps);
|
|
|
+
|
|
|
+ // check we get the same result from the int version
|
|
|
+ System.arraycopy(vectors[i], 0, scratch, 0, dims);
|
|
|
+ OptimizedScalarQuantizer.QuantizationResult intResults = osq.legacyScalarQuantize(
|
|
|
+ scratch,
|
|
|
+ legacyDestination,
|
|
|
+ bit,
|
|
|
+ centroid
|
|
|
+ );
|
|
|
+ assertEquals(result, intResults);
|
|
|
+ for (int h = 0; h < dims; ++h) {
|
|
|
+ assertEquals((byte) destination[h], legacyDestination[h]);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -84,18 +99,18 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
float[] vector = new float[4096];
|
|
|
float[] centroid = new float[4096];
|
|
|
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
|
|
|
- byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096];
|
|
|
+ int[][] destinations = new int[MINIMUM_MSE_GRID.length][4096];
|
|
|
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
|
|
|
assertEquals(MINIMUM_MSE_GRID.length, results.length);
|
|
|
assertValidResults(results);
|
|
|
- for (byte[] destination : destinations) {
|
|
|
- assertArrayEquals(new byte[4096], destination);
|
|
|
+ for (int[] destination : destinations) {
|
|
|
+ assertArrayEquals(new int[4096], destination);
|
|
|
}
|
|
|
- byte[] destination = new byte[4096];
|
|
|
+ int[] destination = new int[4096];
|
|
|
for (byte bit : ALL_BITS) {
|
|
|
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
|
|
|
assertValidResults(result);
|
|
|
- assertArrayEquals(new byte[4096], destination);
|
|
|
+ assertArrayEquals(new int[4096], destination);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -108,7 +123,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
VectorUtil.l2normalize(centroid);
|
|
|
}
|
|
|
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
|
|
|
- byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1];
|
|
|
+ int[][] destinations = new int[MINIMUM_MSE_GRID.length][1];
|
|
|
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
|
|
|
assertEquals(MINIMUM_MSE_GRID.length, results.length);
|
|
|
assertValidResults(results);
|
|
@@ -122,7 +137,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
VectorUtil.l2normalize(vector);
|
|
|
VectorUtil.l2normalize(centroid);
|
|
|
}
|
|
|
- byte[] destination = new byte[1];
|
|
|
+ int[] destination = new int[1];
|
|
|
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
|
|
|
assertValidResults(result);
|
|
|
assertValidQuantizedRange(destination, bit);
|
|
@@ -150,7 +165,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
VectorUtil.l2normalize(centroid);
|
|
|
}
|
|
|
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
|
|
|
- byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims];
|
|
|
+ int[][] destinations = new int[MINIMUM_MSE_GRID.length][dims];
|
|
|
OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid);
|
|
|
assertEquals(MINIMUM_MSE_GRID.length, results.length);
|
|
|
assertValidResults(results);
|
|
@@ -158,7 +173,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
|
|
|
}
|
|
|
for (byte bit : ALL_BITS) {
|
|
|
- byte[] destination = new byte[dims];
|
|
|
+ int[] destination = new int[dims];
|
|
|
System.arraycopy(vector, 0, copy, 0, dims);
|
|
|
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
|
|
VectorUtil.l2normalize(copy);
|
|
@@ -171,8 +186,8 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- static void assertValidQuantizedRange(byte[] quantized, byte bits) {
|
|
|
- for (byte b : quantized) {
|
|
|
+ static void assertValidQuantizedRange(int[] quantized, byte bits) {
|
|
|
+ for (int b : quantized) {
|
|
|
if (bits < 8) {
|
|
|
assertTrue(b >= 0);
|
|
|
}
|