1
0
Эх сурвалжийг харах

Small optimization in OptimizedScalarQuantizer by using mul instead of div (#132397)

Ignacio Vera 2 сар өмнө
parent
commit
a4045d8f53

+ 2 - 1
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -80,10 +80,11 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         float dbb = 0;
         float dax = 0;
         float dbx = 0;
+        float invPmOnes = 1f / (points - 1f);
         for (int i = 0; i < target.length; ++i) {
             float v = target[i];
             float k = quantize[i];
-            float s = k / (points - 1);
+            float s = k * invPmOnes;
             float ms = 1f - s;
             daa = fma(ms, ms, daa);
             dab = fma(ms, s, dab);

+ 11 - 9
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -132,7 +132,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
                 FloatVector centeredVec = v.sub(c);
                 FloatVector deltaVec = centeredVec.sub(vecMeanVec);
                 norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
-                vecMeanVec = vecMeanVec.add(deltaVec.div(count));
+                vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count));
                 FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
                 m2Vec = fma(deltaVec, delta2Vec, m2Vec);
                 minVec = minVec.min(centeredVec);
@@ -214,7 +214,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
                 FloatVector centeredVec = v.sub(c);
                 FloatVector deltaVec = centeredVec.sub(vecMeanVec);
                 norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
-                vecMeanVec = vecMeanVec.add(deltaVec.div(count));
+                vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count));
                 FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
                 m2Vec = fma(deltaVec, delta2Vec, m2Vec);
                 minVec = minVec.min(centeredVec);
@@ -278,6 +278,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         float dbb = 0;
         float dax = 0;
         float dbx = 0;
+        float invPmOnes = 1f / (points - 1f);
         // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
         if (target.length > 2 * FLOAT_SPECIES.length()) {
             FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
@@ -286,11 +287,11 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
             FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
             FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
             FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
-            FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
+            FloatVector invPmOnesVec = FloatVector.broadcast(FLOAT_SPECIES, invPmOnes);
             for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
                 FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
                 FloatVector oVec = IntVector.fromArray(INTEGER_SPECIES, quantize, i).convert(VectorOperators.I2F, 0).reinterpretAsFloats();
-                FloatVector sVec = oVec.div(pmOnes);
+                FloatVector sVec = oVec.mul(invPmOnesVec);
                 FloatVector smVec = ones.sub(sVec);
                 daaVec = fma(smVec, smVec, daaVec);
                 dabVec = fma(smVec, sVec, dabVec);
@@ -307,7 +308,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
 
         for (; i < target.length; i++) {
             float k = quantize[i];
-            float s = k / (points - 1);
+            float s = k * invPmOnes;
             float ms = 1f - s;
             daa = fma(ms, ms, daa);
             dab = fma(ms, s, dab);
@@ -798,25 +799,26 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
     @Override
     public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
         float nSteps = ((1 << bits) - 1);
-        float step = (upperInterval - lowInterval) / nSteps;
+        float invStep = nSteps / (upperInterval - lowInterval);
         int sumQuery = 0;
         int i = 0;
         if (vector.length > 2 * FLOAT_SPECIES.length()) {
             int limit = FLOAT_SPECIES.loopBound(vector.length);
             FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
             FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
-            FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
+            FloatVector invStepVec = FloatVector.broadcast(FLOAT_SPECIES, invStep);
             for (; i < limit; i += FLOAT_SPECIES.length()) {
                 FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
                 FloatVector xi = v.max(lowVec).min(upperVec); // clamp
-                IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
+                // round
+                IntVector assignment = xi.sub(lowVec).mul(invStepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts();
                 sumQuery += assignment.reduceLanes(ADD);
                 assignment.intoArray(destination, i);
             }
         }
         for (; i < vector.length; i++) {
             float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
-            int assignment = Math.round((xi - lowInterval) / step);
+            int assignment = Math.round((xi - lowInterval) * invStep);
             sumQuery += assignment;
             destination[i] = assignment;
         }