|
@@ -78,15 +78,20 @@ public class OptimizedScalarQuantizer {
|
|
|
int points = (1 << bits[i]);
|
|
|
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
|
|
initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
|
|
|
- optimizeIntervals(intervalScratch, vector, norm2, points);
|
|
|
+ boolean hasQuantization = optimizeIntervals(intervalScratch, destinations[i], vector, norm2, points);
|
|
|
// Now we have the optimized intervals, quantize the vector
|
|
|
- int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
|
|
|
- vector,
|
|
|
- destinations[i],
|
|
|
- intervalScratch[0],
|
|
|
- intervalScratch[1],
|
|
|
- bits[i]
|
|
|
- );
|
|
|
+ int sumQuery;
|
|
|
+ if (hasQuantization) {
|
|
|
+ sumQuery = getSumQuery(destinations[i]);
|
|
|
+ } else {
|
|
|
+ sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
|
|
|
+ vector,
|
|
|
+ destinations[i],
|
|
|
+ intervalScratch[0],
|
|
|
+ intervalScratch[1],
|
|
|
+ bits[i]
|
|
|
+ );
|
|
|
+ }
|
|
|
results[i] = new QuantizationResult(
|
|
|
intervalScratch[0],
|
|
|
intervalScratch[1],
|
|
@@ -97,8 +102,7 @@ public class OptimizedScalarQuantizer {
|
|
|
return results;
|
|
|
}
|
|
|
|
|
|
- // This method is only used for benchmarking purposes, it is not used in production
|
|
|
- public QuantizationResult legacyScalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
|
|
|
+ public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
|
|
|
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
|
|
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
|
|
assert vector.length <= destination.length;
|
|
@@ -117,49 +121,14 @@ public class OptimizedScalarQuantizer {
|
|
|
float vecStd = (float) Math.sqrt(vecVar);
|
|
|
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
|
|
initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
|
|
|
- optimizeIntervals(intervalScratch, vector, norm2, points);
|
|
|
- float nSteps = ((1 << bits) - 1);
|
|
|
+ boolean hasQuantization = optimizeIntervals(intervalScratch, destination, vector, norm2, points);
|
|
|
// Now we have the optimized intervals, quantize the vector
|
|
|
- float a = intervalScratch[0];
|
|
|
- float b = intervalScratch[1];
|
|
|
- float step = (b - a) / nSteps;
|
|
|
- int sumQuery = 0;
|
|
|
- for (int h = 0; h < vector.length; h++) {
|
|
|
- float xi = (float) clamp(vector[h], a, b);
|
|
|
- int assignment = Math.round((xi - a) / step);
|
|
|
- sumQuery += assignment;
|
|
|
- destination[h] = (byte) assignment;
|
|
|
- }
|
|
|
- return new QuantizationResult(
|
|
|
- intervalScratch[0],
|
|
|
- intervalScratch[1],
|
|
|
- similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
|
|
|
- sumQuery
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
- public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
|
|
|
- assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
|
|
- assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
|
|
- assert vector.length <= destination.length;
|
|
|
- assert bits > 0 && bits <= 8;
|
|
|
- int points = 1 << bits;
|
|
|
- if (similarityFunction == EUCLIDEAN) {
|
|
|
- ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
|
|
|
+ int sumQuery;
|
|
|
+ if (hasQuantization) {
|
|
|
+ sumQuery = getSumQuery(destination);
|
|
|
} else {
|
|
|
- ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
|
|
|
+ sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
|
|
|
}
|
|
|
- float vecMean = statsScratch[0];
|
|
|
- float vecVar = statsScratch[1];
|
|
|
- float norm2 = statsScratch[2];
|
|
|
- float min = statsScratch[3];
|
|
|
- float max = statsScratch[4];
|
|
|
- float vecStd = (float) Math.sqrt(vecVar);
|
|
|
- // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
|
|
- initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
|
|
|
- optimizeIntervals(intervalScratch, vector, norm2, points);
|
|
|
- // Now we have the optimized intervals, quantize the vector
|
|
|
- int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
|
|
|
return new QuantizationResult(
|
|
|
intervalScratch[0],
|
|
|
intervalScratch[1],
|
|
@@ -176,16 +145,18 @@ public class OptimizedScalarQuantizer {
|
|
|
* @param vector raw vector
|
|
|
* @param norm2 squared norm of the vector
|
|
|
* @param points number of quantization points
|
|
|
+ *
|
|
|
+ * @return true if {@param destination} contains the quantize vector and we can skip the quantization.
|
|
|
*/
|
|
|
- private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
|
|
|
- double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval, points, norm2, lambda);
|
|
|
+ private boolean optimizeIntervals(float[] initInterval, int[] destination, float[] vector, float norm2, int points) {
|
|
|
+ double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval[0], initInterval[1], points, norm2, lambda, destination);
|
|
|
final float scale = (1.0f - lambda) / norm2;
|
|
|
if (Float.isFinite(scale) == false) {
|
|
|
- return;
|
|
|
+ return true;
|
|
|
}
|
|
|
for (int i = 0; i < iters; ++i) {
|
|
|
// calculate the grid points for coordinate descent
|
|
|
- ESVectorUtil.calculateOSQGridPoints(vector, initInterval, points, gridScratch);
|
|
|
+ ESVectorUtil.calculateOSQGridPoints(vector, destination, points, gridScratch);
|
|
|
float daa = gridScratch[0];
|
|
|
float dab = gridScratch[1];
|
|
|
float dbb = gridScratch[2];
|
|
@@ -197,26 +168,35 @@ public class OptimizedScalarQuantizer {
|
|
|
// its possible that the determinant is 0, in which case we can't update the interval
|
|
|
double det = m0 * m2 - m1 * m1;
|
|
|
if (det == 0) {
|
|
|
- return;
|
|
|
+ return true;
|
|
|
}
|
|
|
float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
|
|
|
float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
|
|
|
// If there is no change in the interval, we can stop
|
|
|
if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
|
|
|
- return;
|
|
|
+ return true;
|
|
|
}
|
|
|
- double newLoss = ESVectorUtil.calculateOSQLoss(vector, new float[] { aOpt, bOpt }, points, norm2, lambda);
|
|
|
+ double newLoss = ESVectorUtil.calculateOSQLoss(vector, aOpt, bOpt, points, norm2, lambda, destination);
|
|
|
// If the new loss is worse, don't update the interval and exit
|
|
|
// This optimization, unlike kMeans, does not always converge to better loss
|
|
|
// So exit if we are getting worse
|
|
|
if (newLoss > initialLoss) {
|
|
|
- return;
|
|
|
+ return false;
|
|
|
}
|
|
|
// Update the interval and go again
|
|
|
initInterval[0] = aOpt;
|
|
|
initInterval[1] = bOpt;
|
|
|
initialLoss = newLoss;
|
|
|
}
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ private static int getSumQuery(int[] quantize) {
|
|
|
+ int sum = 0;
|
|
|
+ for (int q : quantize) {
|
|
|
+ sum += q;
|
|
|
+ }
|
|
|
+ return sum;
|
|
|
}
|
|
|
|
|
|
private static double clamp(double x, double a, double b) {
|