Jelajahi Sumber

Fix iterating for best centroid when algorithm is neighbour aware and decrease SAMPLES_PER_CLUSTER_DEFAULT (#130069)


* KMeansIntermediate shares assignments
Ignacio Vera 3 bulan lalu
induk
melakukan
ce74df5c0c

+ 5 - 35
server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

@@ -10,7 +10,6 @@
 package org.elasticsearch.index.codec.vectors.cluster;
 
 import org.apache.lucene.index.FloatVectorValues;
-import org.apache.lucene.util.VectorUtil;
 
 import java.io.IOException;
 
@@ -21,7 +20,7 @@ public class HierarchicalKMeans {
 
     static final int MAXK = 128;
     static final int MAX_ITERATIONS_DEFAULT = 6;
-    static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
+    static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
     static final float DEFAULT_SOAR_LAMBDA = 1.0f;
 
     final int dimension;
@@ -67,8 +66,7 @@ public class HierarchicalKMeans {
         // partition the space
         KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
         if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
-            float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
-            int localSampleSize = (int) (f * vectors.size());
+            int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster, vectors.size());
             KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
             kMeansLocal.cluster(vectors, kMeansIntermediate, true);
         }
@@ -86,42 +84,16 @@ public class HierarchicalKMeans {
 
         // TODO: instead of creating a sub-cluster assignments reuse the parent array each time
         int[] assignments = new int[vectors.size()];
-
         KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
         float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
-        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
+        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
         kmeans.cluster(vectors, kMeansIntermediate);
 
         // TODO: consider adding cluster size counts to the kmeans algo
         // handle assignment here so we can track distance and cluster size
         int[] centroidVectorCount = new int[centroids.length];
-        float[][] nextCentroids = new float[centroids.length][dimension];
-        for (int i = 0; i < vectors.size(); i++) {
-            float smallest = Float.MAX_VALUE;
-            int centroidIdx = -1;
-            float[] vector = vectors.vectorValue(i);
-            for (int j = 0; j < centroids.length; j++) {
-                float[] centroid = centroids[j];
-                float d = VectorUtil.squareDistance(vector, centroid);
-                if (d < smallest) {
-                    smallest = d;
-                    centroidIdx = j;
-                }
-            }
-            centroidVectorCount[centroidIdx]++;
-            for (int j = 0; j < dimension; j++) {
-                nextCentroids[centroidIdx][j] += vector[j];
-            }
-            assignments[i] = centroidIdx;
-        }
-
-        // update centroids based on assignments of all vectors
-        for (int i = 0; i < centroids.length; i++) {
-            if (centroidVectorCount[i] > 0) {
-                for (int j = 0; j < dimension; j++) {
-                    centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
-                }
-            }
+        for (int assigment : assignments) {
+            centroidVectorCount[assigment]++;
         }
 
         int effectiveK = 0;
@@ -131,8 +103,6 @@ public class HierarchicalKMeans {
             }
         }
 
-        kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
-
         if (effectiveK == 1) {
             return kMeansIntermediate;
         }

+ 0 - 4
server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java

@@ -31,10 +31,6 @@ class KMeansIntermediate extends KMeansResult {
         this(new float[0][0], new int[0], i -> i, new int[0]);
     }
 
-    KMeansIntermediate(float[][] centroids) {
-        this(centroids, new int[0], i -> i, new int[0]);
-    }
-
     KMeansIntermediate(float[][] centroids, int[] assignments) {
         this(centroids, assignments, i -> i, new int[0]);
     }

+ 29 - 23
server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

@@ -87,17 +87,17 @@ class KMeansLocal {
 
         for (int i = 0; i < sampleSize; i++) {
             float[] vector = vectors.vectorValue(i);
-            int[] neighborOffsets = null;
-            int centroidIdx = -1;
+            final int assignment = assignments[i];
+            final int bestCentroidOffset;
             if (neighborhoods != null) {
-                neighborOffsets = neighborhoods.get(assignments[i]);
-                centroidIdx = assignments[i];
+                bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
+            } else {
+                bestCentroidOffset = getBestCentroid(centroids, vector);
             }
-            int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
-            if (assignments[i] != bestCentroidOffset) {
+            if (assignment != bestCentroidOffset) {
+                assignments[i] = bestCentroidOffset;
                 changed = true;
             }
-            assignments[i] = bestCentroidOffset;
             centroidCounts[bestCentroidOffset]++;
             for (int d = 0; d < dim; d++) {
                 nextCentroids[bestCentroidOffset][d] += vector[d];
@@ -116,23 +116,28 @@ class KMeansLocal {
         return changed;
     }
 
-    int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
+    int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
         int bestCentroidOffset = centroidIdx;
-        float minDsq;
-        if (centroidIdx > 0 && centroidIdx < centroids.length) {
-            minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
-        } else {
-            minDsq = Float.MAX_VALUE;
+        assert centroidIdx >= 0 && centroidIdx < centroids.length;
+        float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
+        for (int offset : centroidOffsets) {
+            float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
+            if (dsq < minDsq) {
+                minDsq = dsq;
+                bestCentroidOffset = offset;
+            }
         }
+        return bestCentroidOffset;
+    }
 
-        int k = 0;
-        for (int j = 0; j < centroids.length; j++) {
-            if (centroidOffsets == null || j == centroidOffsets[k]) {
-                float dsq = VectorUtil.squareDistance(vector, centroids[j]);
-                if (dsq < minDsq) {
-                    minDsq = dsq;
-                    bestCentroidOffset = j;
-                }
+    int getBestCentroid(float[][] centroids, float[] vector) {
+        int bestCentroidOffset = 0;
+        float minDsq = Float.MAX_VALUE;
+        for (int i = 0; i < centroids.length; i++) {
+            float dsq = VectorUtil.squareDistance(vector, centroids[i]);
+            if (dsq < minDsq) {
+                minDsq = dsq;
+                bestCentroidOffset = i;
             }
         }
         return bestCentroidOffset;
@@ -271,7 +276,8 @@ class KMeansLocal {
             return;
         }
 
-        int[] assignments = new int[n];
+        int[] assignments = kMeansIntermediate.assignments();
+        assert assignments.length == n;
         float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
         for (int i = 0; i < maxIterations; i++) {
             if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
@@ -291,7 +297,7 @@ class KMeansLocal {
      * @param maxIterations the max iterations to shift centroids
      */
     public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
-        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
+        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
         KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
         kMeans.cluster(vectors, kMeansIntermediate);
     }