1
0
Эх сурвалжийг харах

[DISKBBQ] Don't spill vectors that are numerically equivalent to the centroid (#132706)

This commit changes the degenerate case: when a vector is numerically equivalent to its centroid, the vector does not
get a SOAR assignment, which is recorded as -1 in the SOAR assignments array.
Ignacio Vera 2 сар өмнө
parent
commit
4d1b7a69cd

+ 21 - 39
server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

@@ -28,8 +28,9 @@ import java.util.Random;
 class KMeansLocal {
 
     // the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance.
-    // For vectors that are closer than this distance to the centroid, we use the squared distance to find the
-    // second closest centroid.
+    // Vectors that are closer than this distance to the centroid don't get spilled because they are well represented
+    // by the centroid itself. In many cases, this indicates a degenerate distribution, e.g. the cluster is composed of
+    // many equal vectors.
     private static final float SOAR_MIN_DISTANCE = 1e-16f;
 
     final int sampleSize;
@@ -281,19 +282,18 @@ class KMeansLocal {
         final float[] distances = new float[4];
         for (int i = 0; i < vectors.size(); i++) {
             float[] vector = vectors.vectorValue(i);
-
             int currAssignment = assignments[i];
             float[] currentCentroid = centroids[currAssignment];
-
             // TODO: cache these?
             float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
-
-            if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
-                for (int j = 0; j < vectors.dimension(); j++) {
-                    diffs[j] = vector[j] - currentCentroid[j];
-                }
+            if (vectorCentroidDist <= SOAR_MIN_DISTANCE) {
+                spilledAssignments[i] = -1; // no SOAR assignment
+                continue;
             }
 
+            for (int j = 0; j < vectors.dimension(); j++) {
+                diffs[j] = vector[j] - currentCentroid[j];
+            }
             final int centroidCount;
             final IntToIntFunction centroidOrds;
             if (neighborhoods != null) {
@@ -310,29 +310,17 @@ class KMeansLocal {
             float minSoar = Float.MAX_VALUE;
             int j = 0;
             for (; j < limit; j += 4) {
-                if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
-                    ESVectorUtil.soarDistanceBulk(
-                        vector,
-                        centroids[centroidOrds.apply(j)],
-                        centroids[centroidOrds.apply(j + 1)],
-                        centroids[centroidOrds.apply(j + 2)],
-                        centroids[centroidOrds.apply(j + 3)],
-                        diffs,
-                        soarLambda,
-                        vectorCentroidDist,
-                        distances
-                    );
-                } else {
-                    // if the vector is very close to the centroid, we look for the second-nearest centroid
-                    ESVectorUtil.squareDistanceBulk(
-                        vector,
-                        centroids[centroidOrds.apply(j)],
-                        centroids[centroidOrds.apply(j + 1)],
-                        centroids[centroidOrds.apply(j + 2)],
-                        centroids[centroidOrds.apply(j + 3)],
-                        distances
-                    );
-                }
+                ESVectorUtil.soarDistanceBulk(
+                    vector,
+                    centroids[centroidOrds.apply(j)],
+                    centroids[centroidOrds.apply(j + 1)],
+                    centroids[centroidOrds.apply(j + 2)],
+                    centroids[centroidOrds.apply(j + 3)],
+                    diffs,
+                    soarLambda,
+                    vectorCentroidDist,
+                    distances
+                );
                 for (int k = 0; k < distances.length; k++) {
                     float soar = distances[k];
                     if (soar < minSoar) {
@@ -344,13 +332,7 @@ class KMeansLocal {
 
             for (; j < centroidCount; j++) {
                 int centroidOrd = centroidOrds.apply(j);
-                float soar;
-                if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
-                    soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist);
-                } else {
-                    // if the vector is very close to the centroid, we look for the second-nearest centroid
-                    soar = VectorUtil.squareDistance(vector, centroids[centroidOrd]);
-                }
+                float soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist);
                 if (soar < minSoar) {
                     minSoar = soar;
                     bestAssignment = centroidOrd;

+ 62 - 0
server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java

@@ -22,8 +22,10 @@ import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -145,6 +147,66 @@ public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
         }
     }
 
+    public void testFewVectorManyTimes() throws IOException {
+        int numDifferentVectors = random().nextInt(1, 20);
+        float[][] vectors = new float[numDifferentVectors][];
+        int dimensions = random().nextInt(12, 500);
+        for (int i = 0; i < numDifferentVectors; i++) {
+            vectors[i] = randomVector(dimensions);
+        }
+        int numDocs = random().nextInt(100, 10_000);
+        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+            for (int i = 0; i < numDocs; i++) {
+                float[] vector = vectors[random().nextInt(numDifferentVectors)];
+                Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
+                w.addDocument(doc);
+            }
+            w.commit();
+            if (rarely()) {
+                w.forceMerge(1);
+            }
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                List<LeafReaderContext> subReaders = reader.leaves();
+                for (LeafReaderContext r : subReaders) {
+                    LeafReader leafReader = r.reader();
+                    float[] vector = randomVector(dimensions);
+                    TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
+                    assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
+                }
+
+            }
+        }
+    }
+
+    public void testOneRepeatedVector() throws IOException {
+        int dimensions = random().nextInt(12, 500);
+        float[] repeatedVector = randomVector(dimensions);
+        int numDocs = random().nextInt(100, 10_000);
+        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+            for (int i = 0; i < numDocs; i++) {
+                float[] vector = random().nextInt(3) == 0 ? repeatedVector : randomVector(dimensions);
+                Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
+                w.addDocument(doc);
+            }
+            w.commit();
+            if (rarely()) {
+                w.forceMerge(1);
+            }
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                List<LeafReaderContext> subReaders = reader.leaves();
+                for (LeafReaderContext r : subReaders) {
+                    LeafReader leafReader = r.reader();
+                    float[] vector = randomVector(dimensions);
+                    TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
+                    assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
+                }
+
+            }
+        }
+    }
+
     // this is a modified version of lucene's TestSearchWithThreads test case
     public void testWithThreads() throws Exception {
         final int numThreads = random().nextInt(2, 5);