Adjust IVF postings list building (#130843)

* Make postings list building more IO friendly

* Fix assertion

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Benjamin Trent, 3 months ago · commit 44497b7b05

+ 5 - 4
server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

@@ -9,10 +9,11 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
+record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
 
-    CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
-        this(centroids.length, centroids, assignmentsByCluster);
-        assert centroids.length == assignmentsByCluster.length;
+    CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
+        this(centroids.length, centroids, assignments, overspillAssignments);
+        assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
+            : "overspillAssignments must be empty or match assignments in length";
     }
 }
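
For context, the record above now carries two flat arrays indexed by vector ordinal instead of precomputed per-cluster lists: assignments[i] is vector i's primary centroid, and overspillAssignments[i] is its optional SOAR overspill centroid (-1 when absent; the whole array may be empty when overspill is disabled). A minimal sketch of how such flat arrays regroup into per-cluster ordinal lists, mirroring the two-pass counting code in DefaultIVFVectorsWriter below (the standalone method name is illustrative):

    // Two passes: count per-cluster sizes, then fill; -1 marks "no overspill".
    static int[][] groupByCluster(int numCentroids, int[] assignments, int[] overspill) {
        int[] sizes = new int[numCentroids];
        for (int i = 0; i < assignments.length; i++) {
            sizes[assignments[i]]++;
            if (overspill.length > i && overspill[i] != -1) {
                sizes[overspill[i]]++;
            }
        }
        int[][] byCluster = new int[numCentroids][];
        for (int c = 0; c < numCentroids; c++) {
            byCluster[c] = new int[sizes[c]];
        }
        java.util.Arrays.fill(sizes, 0); // reuse counts as per-cluster write cursors
        for (int i = 0; i < assignments.length; i++) {
            byCluster[assignments[i]][sizes[assignments[i]]++] = i;
            if (overspill.length > i && overspill[i] != -1) {
                byCluster[overspill[i]][sizes[overspill[i]]++] = i;
            }
        }
        return byCluster;
    }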

+ 283 - 35
server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

@@ -14,9 +14,11 @@ import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
 import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
 import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
 import org.elasticsearch.logging.LogManager;
@@ -49,32 +51,58 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException {
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                }
+            }
+        }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
-        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
         DocIdsWriter docIdsWriter = new DocIdsWriter();
-        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
-            ES91OSQVectorsScorer.BULK_SIZE,
-            quantizer,
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
-            postingsOutput
+            fieldInfo.getVectorDimension(),
+            new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
-            // TODO: add back in sorting vectors by distance to centroid
             int[] cluster = assignmentsByCluster[c];
             // TODO align???
             offsets[c] = postingsOutput.getFilePointer();
             int size = cluster.length;
             postingsOutput.writeVInt(size);
             postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
 
         if (logger.isDebugEnabled()) {
@@ -84,6 +112,124 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         return offsets;
     }
 
+    @Override
+    long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
+        // first, quantize all the vectors into a temporary file
+        String quantizedVectorsTempName = null;
+        IndexOutput quantizedVectorsTemp = null;
+        boolean success = false;
+        try {
+            quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
+            quantizedVectorsTempName = quantizedVectorsTemp.getName();
+            OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+            int[] quantized = new int[fieldInfo.getVectorDimension()];
+            byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
+            for (int i = 0; i < assignments.length; i++) {
+                int c = assignments[i];
+                float[] centroid = centroidSupplier.centroid(c);
+                float[] vector = floatVectorValues.vectorValue(i);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                // if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector
+                // so, make a copy of the vector to avoid mutating it
+                if (overspill) {
+                    System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
+                }
+
+                OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
+                BQVectorUtils.packAsBinary(quantized, binary);
+                writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                if (overspill) {
+                    int s = overspillAssignments[i];
+                    // write the overspill vector as well
+                    result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    BQVectorUtils.packAsBinary(quantized, binary);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                } else {
+                    // write a zero vector for the overspill
+                    Arrays.fill(binary, (byte) 0);
+                    OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
+                }
+            }
+            // close the temporary file so we can read it later
+            quantizedVectorsTemp.close();
+            success = true;
+        } finally {
+            if (success == false && quantizedVectorsTemp != null) {
+                mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
+            }
+        }
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]] = i;
+                    isOverspillByCluster[s][centroidVectorCount[s]++] = true;
+                }
+            }
+        }
+        // now we can read the quantized vectors from the temporary file
+        try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
+            final long[] offsets = new long[centroidSupplier.size()];
+            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+                quantizedVectorsInput,
+                fieldInfo.getVectorDimension()
+            );
+            DocIdsWriter docIdsWriter = new DocIdsWriter();
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+            for (int c = 0; c < centroidSupplier.size(); c++) {
+                float[] centroid = centroidSupplier.centroid(c);
+                int[] cluster = assignmentsByCluster[c];
+                boolean[] isOverspill = isOverspillByCluster[c];
+                // TODO align???
+                offsets[c] = postingsOutput.getFilePointer();
+                int size = cluster.length;
+                postingsOutput.writeVInt(size);
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+                // TODO we might want to consider putting the docIds in a separate file
+                // to aid with only having to fetch vectors from slower storage when they are required
+                // keeping them in the same file indicates we pull the entire file into cache
+                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+                bulkWriter.writeVectors(offHeapQuantizedVectors);
+            }
+
+            if (logger.isDebugEnabled()) {
+                printClusterQualityStatistics(assignmentsByCluster);
+            }
+            return offsets;
+        }
+    }
+
     private static void printClusterQualityStatistics(int[][] clusters) {
         float min = Float.MAX_VALUE;
         float max = Float.MIN_VALUE;
@@ -210,33 +356,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        int[] centroidVectorCount = new int[centroids.length];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (soarAssignments.length > i && soarAssignments[i] != -1) {
-                centroidVectorCount[soarAssignments[i]]++;
-            }
-        }
-
-        int[][] assignmentsByCluster = new int[centroids.length][];
-        for (int c = 0; c < centroids.length; c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (soarAssignments.length > i) {
-                int s = soarAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
-                }
-            }
-        }
-        return new CentroidAssignments(centroids, assignmentsByCluster);
+        return new CentroidAssignments(centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
@@ -281,4 +401,132 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
             return scratch;
         }
     }
+
+    interface QuantizedVectorValues {
+        int count();
+
+        byte[] next() throws IOException;
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
+    }
+
+    interface IntToBooleanFunction {
+        boolean apply(int ord);
+    }
+
+    static class OnHeapQuantizedVectors implements QuantizedVectorValues {
+        private final FloatVectorValues vectorValues;
+        private final OptimizedScalarQuantizer quantizer;
+        private final byte[] quantizedVector;
+        private final int[] quantizedVectorScratch;
+        private OptimizedScalarQuantizer.QuantizationResult corrections;
+        private float[] currentCentroid;
+        private IntToIntFunction ordTransformer = null;
+        private int currOrd = -1;
+        private int count;
+
+        OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) {
+            this.vectorValues = vectorValues;
+            this.quantizer = quantizer;
+            this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.quantizedVectorScratch = new int[dimension];
+            this.corrections = null;
+        }
+
+        private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
+            this.currentCentroid = centroid;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+            this.count = count;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count() - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count());
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            float[] vector = vectorValues.vectorValue(ord);
+            corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
+            BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
+            return quantizedVector;
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call next first");
+            }
+            return corrections;
+        }
+    }
+
+    static class OffHeapQuantizedVectors implements QuantizedVectorValues {
+        private final IndexInput quantizedVectorsInput;
+        private final byte[] binaryScratch;
+        private final float[] corrections = new float[3];
+
+        private final int vectorByteSize;
+        private short bitSum;
+        private int currOrd = -1;
+        private int count;
+        private IntToBooleanFunction isOverspill = null;
+        private IntToIntFunction ordTransformer = null;
+
+        OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
+            this.quantizedVectorsInput = quantizedVectorsInput;
+            this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
+        }
+
+        private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
+            this.count = count;
+            this.isOverspill = isOverspill;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count);
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            boolean isOverspill = this.isOverspill.apply(currOrd);
+            return getVector(ord, isOverspill);
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call next first");
+            }
+            return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
+        }
+
+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
+        public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
+            long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
+            quantizedVectorsInput.seek(offset);
+            quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
+            quantizedVectorsInput.readFloats(corrections, 0, 3);
+            bitSum = quantizedVectorsInput.readShort();
+        }
+    }
 }
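
A note on the off-heap addressing: the temp file written in buildAndWritePostingsLists stores, for every source ordinal, two fixed-size records back to back, the primary quantized vector and then the overspill slot (zero-filled when the vector has no overspill assignment). Each record is the packed 1-bit vector plus three correction floats and a short. A hedged sketch of the seek arithmetic OffHeapQuantizedVectors relies on, assuming BQVectorUtils.discretize rounds the dimension up to a multiple of 64:

    // Byte offset of one quantized record in the "qvec_" temp file.
    static long recordOffset(int ord, boolean isOverspill, int dimension) {
        int packedBytes = ((dimension + 63) / 64) * 64 / 8;            // discretize(dimension, 64) / 8
        int recordSize = packedBytes + 3 * Float.BYTES + Short.BYTES;  // vector + corrections + bit sum
        return (long) ord * (2L * recordSize) + (isOverspill ? recordSize : 0);
    }

Because posting lists visit ordinals cluster by cluster, these reads are seeks within one flat file of small fixed-size records rather than repeated quantization of the raw float vectors.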

+ 15 - 32
server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java

@@ -9,34 +9,25 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.util.hnsw.IntToIntFunction;
 
 import java.io.IOException;
 
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
-
 /**
  * Base class for bulk writers that write vectors to disk using the BBQ encoding.
  * This class provides the structure for writing vectors in bulk, with specific
  * implementations for different bit sizes strategies.
  */
-public abstract class DiskBBQBulkWriter {
+abstract class DiskBBQBulkWriter {
     protected final int bulkSize;
-    protected final OptimizedScalarQuantizer quantizer;
     protected final IndexOutput out;
-    protected final FloatVectorValues fvv;
 
-    protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+    protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
         this.bulkSize = bulkSize;
-        this.quantizer = quantizer;
         this.out = out;
-        this.fvv = fvv;
     }
 
-    public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
+    abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
 
     private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
         for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
@@ -64,39 +55,31 @@ public abstract class DiskBBQBulkWriter {
         out.writeShort((short) targetComponentSum);
     }
 
-    public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
-        private final byte[] binarized;
-        private final int[] initQuantized;
+    static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
 
-        public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
-            super(bulkSize, quantizer, fvv, out);
-            this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
-            this.initQuantized = new int[fvv.dimension()];
+        OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
+            super(bulkSize, out);
             this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
         }
 
         @Override
-        public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
-            int limit = count - bulkSize + 1;
+        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
+            int limit = qvv.count() - bulkSize + 1;
             int i = 0;
             for (; i < limit; i += bulkSize) {
                 for (int j = 0; j < bulkSize; j++) {
-                    int ord = ords.apply(i + j);
-                    float[] fv = fvv.vectorValue(ord);
-                    corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
-                    packAsBinary(initQuantized, binarized);
-                    out.writeBytes(binarized, binarized.length);
+                    byte[] qv = qvv.next();
+                    corrections[j] = qvv.getCorrections();
+                    out.writeBytes(qv, qv.length);
                 }
                 writeCorrections(corrections, out);
             }
             // write tail
-            for (; i < count; ++i) {
-                int ord = ords.apply(i);
-                float[] fv = fvv.vectorValue(ord);
-                OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
-                packAsBinary(initQuantized, binarized);
-                out.writeBytes(binarized, binarized.length);
+            for (; i < qvv.count(); ++i) {
+                byte[] qv = qvv.next();
+                OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
+                out.writeBytes(qv, qv.length);
                 writeCorrection(correction, out);
             }
         }
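
The bulk writer no longer quantizes anything itself; it just drains a QuantizedVectorValues producer, so the same loop serves the flush path (OnHeapQuantizedVectors, quantizing on the fly) and the merge path (OffHeapQuantizedVectors, streaming the pre-quantized temp file). A sketch of the per-cluster posting-list layout that writeVectors and the surrounding writer code produce, reconstructed from the diff:

    // Per-cluster posting list (sketch):
    //   vint  size                                   // vectors in this cluster
    //   int   floatToIntBits(dot(centroid, centroid))
    //   doc ids                                      // DocIdsWriter encoding
    //   for each full block of BULK_SIZE vectors:
    //     BULK_SIZE packed 1-bit vectors, then the block's BULK_SIZE corrections
    //   tail: each remaining vector followed immediately by its own correction

Grouping corrections after each block keeps a block's packed vectors contiguous, which the ES91OSQVectorsScorer.BULK_SIZE constant suggests is the unit the bulk scorer consumes.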

+ 21 - 5
server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

@@ -139,7 +139,18 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException;
+
+    abstract long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException;
 
     abstract CentroidSupplier createCentroidSupplier(
@@ -174,7 +185,8 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
                 centroidSupplier,
                 floatVectorValues,
                 ivfClusters,
-                centroidAssignments.assignmentsByCluster()
+                centroidAssignments.assignments(),
+                centroidAssignments.overspillAssignments()
             );
             // write posting lists
             writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
@@ -284,7 +296,8 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
             final long centroidOffset;
             final long centroidLength;
             final int numCentroids;
-            final int[][] assignmentsByCluster;
+            final int[] assignments;
+            final int[] overspillAssignments;
             final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
             String centroidTempName = null;
             IndexOutput centroidTemp = null;
@@ -300,7 +313,8 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
                     calculatedGlobalCentroid
                 );
                 numCentroids = centroidAssignments.numCentroids();
-                assignmentsByCluster = centroidAssignments.assignmentsByCluster();
+                assignments = centroidAssignments.assignments();
+                overspillAssignments = centroidAssignments.overspillAssignments();
                 success = true;
             } finally {
                 if (success == false && centroidTempName != null) {
@@ -337,7 +351,9 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
                         centroidSupplier,
                         floatVectorValues,
                         ivfClusters,
-                        assignmentsByCluster
+                        mergeState,
+                        assignments,
+                        overspillAssignments
                     );
                     assert offsets.length == centroidSupplier.size();
                     writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
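
Taken together, the two overloads split cleanly by call site (a hedged sketch; variable names as in the diff):

    // Flush: quantize lazily from on-heap float vectors while writing posting lists.
    long[] offsets = buildAndWritePostingsLists(
        fieldInfo, centroidSupplier, floatVectorValues, ivfClusters,
        assignments, overspillAssignments);

    // Merge: quantize every vector exactly once into a "qvec_" temp file,
    // then stream fixed-size records back while writing posting lists.
    long[] mergedOffsets = buildAndWritePostingsLists(
        fieldInfo, centroidSupplier, floatVectorValues, ivfClusters,
        mergeState, assignments, overspillAssignments);

On merge, each raw vector is now read and quantized once instead of once per posting-list appearance, trading a temporary file for far fewer random reads of the much larger float vectors.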