@@ -14,9 +14,11 @@ import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
 import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
 import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
 import org.elasticsearch.logging.LogManager;
@@ -49,32 +51,58 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException {
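+        // group vector ordinals by centroid in two passes: first count the vectors
+        // assigned to each centroid (a vector can appear twice when it also has a
+        // SOAR overspill assignment), then fill the per-cluster ordinal arrays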
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                }
+            }
+        }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
-        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
         DocIdsWriter docIdsWriter = new DocIdsWriter();
-        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
-            ES91OSQVectorsScorer.BULK_SIZE,
-            quantizer,
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
-            postingsOutput
+            fieldInfo.getVectorDimension(),
+            new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
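+        // per-cluster layout: vint vector count, the centroid's self dot product,
+        // the cluster's doc ids, then the bit-quantized vectors and their corrections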
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
-            // TODO: add back in sorting vectors by distance to centroid
             int[] cluster = assignmentsByCluster[c];
             // TODO align???
             offsets[c] = postingsOutput.getFilePointer();
             int size = cluster.length;
             postingsOutput.writeVInt(size);
             postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            bulkWriter.writeVectors(onHeapQuantizedVectors);
         }

         if (logger.isDebugEnabled()) {
@@ -84,6 +112,124 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         return offsets;
     }

+    @Override
+    long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
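+        // merge path: quantize every vector (plus its overspill copy) into a
+        // temporary file up front, then stream the records back cluster by cluster
+        // below, instead of re-quantizing a vector for every posting list it is in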
+        // first, quantize all the vectors into a temporary file
+        String quantizedVectorsTempName = null;
+        IndexOutput quantizedVectorsTemp = null;
+        boolean success = false;
+        try {
+            quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
+            quantizedVectorsTempName = quantizedVectorsTemp.getName();
+            OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+            int[] quantized = new int[fieldInfo.getVectorDimension()];
+            byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
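+            // each vector ordinal gets two fixed-size records in the temp file:
+            // the primary quantized vector followed by the overspill copy (all
+            // zeros when there is none), so readback can seek by pure arithmetic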
+            for (int i = 0; i < assignments.length; i++) {
+                int c = assignments[i];
+                float[] centroid = centroidSupplier.centroid(c);
+                float[] vector = floatVectorValues.vectorValue(i);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                // if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector
+                // so, make a copy of the vector to avoid mutating it
+                if (overspill) {
+                    System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
+                }
+
+                OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
+                BQVectorUtils.packAsBinary(quantized, binary);
+                writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                if (overspill) {
+                    int s = overspillAssignments[i];
+                    // write the overspill vector as well
+                    result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    BQVectorUtils.packAsBinary(quantized, binary);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                } else {
+                    // write a zero vector for the overspill
+                    Arrays.fill(binary, (byte) 0);
+                    OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
+                }
+            }
+            // close the temporary file so we can read it later
+            quantizedVectorsTemp.close();
+            success = true;
+        } finally {
+            if (success == false && quantizedVectorsTemp != null) {
+                mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
+            }
+        }
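+        // the same two-pass grouping as the flush path above, but additionally
+        // track which cluster entries are overspill assignments so the matching
+        // temp-file record can be read back for them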
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]] = i;
+                    isOverspillByCluster[s][centroidVectorCount[s]++] = true;
+                }
+            }
+        }
+        // now we can read the quantized vectors from the temporary file
+        try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
+            final long[] offsets = new long[centroidSupplier.size()];
+            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+                quantizedVectorsInput,
+                fieldInfo.getVectorDimension()
+            );
+            DocIdsWriter docIdsWriter = new DocIdsWriter();
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+            for (int c = 0; c < centroidSupplier.size(); c++) {
+                float[] centroid = centroidSupplier.centroid(c);
+                int[] cluster = assignmentsByCluster[c];
+                boolean[] isOverspill = isOverspillByCluster[c];
+                // TODO align???
+                offsets[c] = postingsOutput.getFilePointer();
+                int size = cluster.length;
+                postingsOutput.writeVInt(size);
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+                // TODO we might want to consider putting the docIds in a separate file
+                // to aid with only having to fetch vectors from slower storage when they are required
+                // keeping them in the same file indicates we pull the entire file into cache
+                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+                bulkWriter.writeVectors(offHeapQuantizedVectors);
+            }
+
+            if (logger.isDebugEnabled()) {
+                printClusterQualityStatistics(assignmentsByCluster);
+            }
+            return offsets;
+        }
+    }
+
     private static void printClusterQualityStatistics(int[][] clusters) {
         float min = Float.MAX_VALUE;
         float max = Float.MIN_VALUE;
@@ -210,33 +356,7 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        int[] centroidVectorCount = new int[centroids.length];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (soarAssignments.length > i && soarAssignments[i] != -1) {
-                centroidVectorCount[soarAssignments[i]]++;
-            }
-        }
-
-        int[][] assignmentsByCluster = new int[centroids.length][];
-        for (int c = 0; c < centroids.length; c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (soarAssignments.length > i) {
-                int s = soarAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
-                }
-            }
-        }
-        return new CentroidAssignments(centroids, assignmentsByCluster);
+        return new CentroidAssignments(centroids, assignments, soarAssignments);
     }

     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
@@ -281,4 +401,132 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
             return scratch;
         }
     }
+
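+    // a sequential source of quantized vectors consumed by
+    // DiskBBQBulkWriter#writeVectors: OnHeapQuantizedVectors quantizes on the fly
+    // at flush time, OffHeapQuantizedVectors streams pre-quantized merge records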
+    interface QuantizedVectorValues {
+        int count();
+
+        byte[] next() throws IOException;
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
+    }
+
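+    // primitive int -> boolean function, by analogy with Lucene's IntToIntFunction;
+    // used to ask whether a cluster entry is an overspill assignment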
+    interface IntToBooleanFunction {
+        boolean apply(int ord);
+    }
+
+    static class OnHeapQuantizedVectors implements QuantizedVectorValues {
+        private final FloatVectorValues vectorValues;
+        private final OptimizedScalarQuantizer quantizer;
+        private final byte[] quantizedVector;
+        private final int[] quantizedVectorScratch;
+        private OptimizedScalarQuantizer.QuantizationResult corrections;
+        private float[] currentCentroid;
+        private IntToIntFunction ordTransformer = null;
+        private int currOrd = -1;
+        private int count;
+
+        OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) {
+            this.vectorValues = vectorValues;
+            this.quantizer = quantizer;
+            this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.quantizedVectorScratch = new int[dimension];
+            this.corrections = null;
+        }
+
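+        // rebind this view to a single cluster; vectors are quantized lazily
+        // against the given centroid as next() walks the cluster's ordinals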
+        private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
+            this.currentCentroid = centroid;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+            this.count = count;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count() - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count());
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            float[] vector = vectorValues.vectorValue(ord);
+            corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
+            BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
+            return quantizedVector;
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call next first");
+            }
+            return corrections;
+        }
+    }
+
+    static class OffHeapQuantizedVectors implements QuantizedVectorValues {
+        private final IndexInput quantizedVectorsInput;
+        private final byte[] binaryScratch;
+        private final float[] corrections = new float[3];
+
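+        // a stored record is the packed binary vector followed by three
+        // correction floats and a short component sum, as written by
+        // writeQuantizedValue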
+        private final int vectorByteSize;
+        private short bitSum;
+        private int currOrd = -1;
+        private int count;
+        private IntToBooleanFunction isOverspill = null;
+        private IntToIntFunction ordTransformer = null;
+
+        OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
+            this.quantizedVectorsInput = quantizedVectorsInput;
+            this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
+        }
+
+        private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
+            this.count = count;
+            this.isOverspill = isOverspill;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count);
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            boolean isOverspill = this.isOverspill.apply(currOrd);
+            return getVector(ord, isOverspill);
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
+            }
+            return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
+        }
+
+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
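+        // primary and overspill records are interleaved per ordinal, so a record
+        // starts at ord * 2 * vectorByteSize, plus one extra vectorByteSize to
+        // skip to the overspill copy; e.g. with vectorByteSize = 62 (384 dims),
+        // ord 3's overspill record starts at 3 * 124 + 62 = 434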
+        public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
+            long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
+            quantizedVectorsInput.seek(offset);
+            quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
+            quantizedVectorsInput.readFloats(corrections, 0, 3);
+            bitSum = quantizedVectorsInput.readShort();
+        }
+    }
 }
|