Browse Source

Add vector distance scoring to micro benchmarks (#92340)

This change adds the foundation for micro benchmarks of the vector distance methods used to score
vector relevance, specifically for scripting (though some of these are proxies to Lucene methods used
directly by HNSW). The parameters included are element (byte, float), type (knn, binary), dims (96), and
function (dot, cosine, l1, l2). The design also attempts to be easy to extend for individual local tests.
Jack Conradson 2 years ago
parent
commit
8ebf29ba74

+ 467 - 0
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java

@@ -0,0 +1,467 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+import org.elasticsearch.script.field.vectors.BinaryDenseVector;
+import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
+import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
+import org.elasticsearch.script.field.vectors.KnnDenseVector;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OperationsPerInvocation;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+/**
+ * Various benchmarks for the distance functions
+ * used by indexed and non-indexed vectors.
+ * Parameters include element, dims, function, and type.
+ * For individual local tests it may be useful to increase
+ * fork, measurement, and operations per invocation. (Note:
+ * if operations per invocation is changed, the loop count in
+ * the benchmark method must be updated to match.)
+ */
+@Fork(1)
+@Warmup(iterations = 1)
+@Measurement(iterations = 2)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@OperationsPerInvocation(25000)
+@State(Scope.Benchmark)
+public class DistanceFunctionBenchmark {
+
+    @Param({ "float", "byte" })
+    private String element;
+
+    @Param({ "96" })
+    private int dims;
+
+    @Param({ "dot", "cosine", "l1", "l2" })
+    private String function;
+
+    @Param({ "knn", "binary" })
+    private String type;
+
+    /**
+     * Common base of every benchmarked distance function. It records the
+     * vector dimension count and declares the single operation to be timed.
+     */
+    private abstract static class BenchmarkFunction {
+
+        /** Number of elements in the vectors this function operates on. */
+        final int dims;
+
+        /**
+         * @param dimensions number of elements per vector
+         */
+        private BenchmarkFunction(int dimensions) {
+            dims = dimensions;
+        }
+
+        /**
+         * Runs one invocation of the distance function.
+         *
+         * @param consumer sink offered to the implementation for its result
+         */
+        abstract void execute(Consumer<Object> consumer);
+    }
+
+    /**
+     * Base for the "knn" float benchmarks: document and query vectors are both
+     * held as plain {@code float[]} arrays.
+     */
+    private abstract static class KnnFloatBenchmarkFunction extends BenchmarkFunction {
+
+        final float[] docVector;
+        final float[] queryVector;
+
+        /**
+         * Fills the document vector with {@code dims, dims - 1, ..., 1} and the
+         * query vector with {@code 0, 1, ..., dims - 1}.
+         *
+         * @param dims      number of elements in each vector
+         * @param normalize when {@code true}, each element is divided by its vector's
+         *                  scaling factor (sum of the elements divided by {@code dims})
+         */
+        private KnnFloatBenchmarkFunction(int dims, boolean normalize) {
+            super(dims);
+
+            final float[] doc = new float[dims];
+            final float[] query = new float[dims];
+            float docSum = 0f;
+            float querySum = 0f;
+
+            for (int idx = 0; idx < dims; ++idx) {
+                final float docValue = (float) (dims - idx);
+                final float queryValue = (float) idx;
+
+                doc[idx] = docValue;
+                query[idx] = queryValue;
+
+                docSum += docValue;
+                querySum += queryValue;
+            }
+
+            if (normalize) {
+                // NOTE(review): the scaling factor is the mean of the elements, not a
+                // Euclidean norm - confirm this is the intended "magnitude".
+                final float docScale = docSum / dims;
+                final float queryScale = querySum / dims;
+
+                for (int idx = 0; idx < dims; ++idx) {
+                    doc[idx] /= docScale;
+                    query[idx] /= queryScale;
+                }
+            }
+
+            docVector = doc;
+            queryVector = query;
+        }
+    }
+
+    /**
+     * Base for the "binary" float benchmarks: the document vector is serialized
+     * into a {@link BytesRef} while the query vector stays a {@code float[]}.
+     */
+    private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {
+
+        final BytesRef docVector;
+        final float[] queryVector;
+
+        /**
+         * Builds document values {@code dims, dims - 1, ..., 1} and query values
+         * {@code 0, 1, ..., dims - 1}, then writes the document values as
+         * consecutive floats followed by one trailing float holding the document
+         * scaling factor (sum of the elements divided by {@code dims}).
+         *
+         * @param dims      number of elements in each vector
+         * @param normalize when {@code true}, each element is divided by its vector's
+         *                  scaling factor before the document vector is serialized
+         */
+        private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
+            super(dims);
+
+            final float[] doc = new float[dims];
+            queryVector = new float[dims];
+
+            float docSum = 0f;
+            float querySum = 0f;
+
+            for (int idx = 0; idx < dims; ++idx) {
+                final float docValue = (float) (dims - idx);
+                final float queryValue = (float) idx;
+
+                doc[idx] = docValue;
+                queryVector[idx] = queryValue;
+
+                docSum += docValue;
+                querySum += queryValue;
+            }
+
+            // NOTE(review): these are the means of the elements, not Euclidean
+            // norms - confirm this is the intended "magnitude".
+            final float docScale = docSum / dims;
+            final float queryScale = querySum / dims;
+
+            // one float per element plus one trailing float
+            final ByteBuffer encoded = ByteBuffer.allocate((dims + 1) * Float.BYTES);
+
+            if (normalize) {
+                for (int idx = 0; idx < dims; ++idx) {
+                    doc[idx] /= docScale;
+                    queryVector[idx] /= queryScale;
+                }
+            }
+
+            for (float value : doc) {
+                encoded.putFloat(value);
+            }
+
+            // the trailing float is the pre-normalization factor even when normalize is true
+            encoded.putFloat(docScale);
+            docVector = new BytesRef(encoded.array());
+        }
+    }
+
+    /**
+     * Base for the "knn" byte benchmarks: the document vector is a
+     * {@link BytesRef} with one byte per element, the query vector a {@code byte[]}.
+     */
+    private abstract static class KnnByteBenchmarkFunction extends BenchmarkFunction {
+
+        final BytesRef docVector;
+        final byte[] queryVector;
+
+        /** Sum of the query elements divided by {@code dims}. */
+        final float queryMagnitude;
+
+        /**
+         * Writes document bytes {@code dims, dims - 1, ..., 1} and query bytes
+         * {@code 0, 1, ..., dims - 1} (each narrowed to {@code byte}).
+         *
+         * @param dims number of elements in each vector
+         */
+        private KnnByteBenchmarkFunction(int dims) {
+            super(dims);
+
+            final ByteBuffer docBytes = ByteBuffer.allocate(dims);
+            final byte[] query = new byte[dims];
+            float querySum = 0f;
+
+            for (int idx = 0; idx < dims; ++idx) {
+                docBytes.put((byte) (dims - idx));
+                query[idx] = (byte) idx;
+                querySum += idx;
+            }
+
+            docVector = new BytesRef(docBytes.array());
+            queryVector = query;
+            queryMagnitude = querySum / dims;
+        }
+    }
+
+    /**
+     * Base for the "binary" byte benchmarks: the document vector is a
+     * {@link BytesRef} of one byte per element followed by a trailing float,
+     * the query vector a {@code byte[]}.
+     */
+    private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {
+
+        final BytesRef docVector;
+        final byte[] queryVector;
+
+        /** Sum of the query elements divided by {@code dims}. */
+        final float queryMagnitude;
+
+        /**
+         * Writes document bytes {@code dims, dims - 1, ..., 1}, appends the
+         * document element sum divided by {@code dims} as a float, and fills the
+         * query with {@code 0, 1, ..., dims - 1} (each narrowed to {@code byte}).
+         *
+         * @param dims number of elements in each vector
+         */
+        private BinaryByteBenchmarkFunction(int dims) {
+            super(dims);
+
+            // one byte per element plus one trailing float
+            final ByteBuffer encoded = ByteBuffer.allocate(dims + Float.BYTES);
+            final byte[] query = new byte[dims];
+
+            float docSum = 0f;
+            float querySum = 0f;
+
+            for (int idx = 0; idx < dims; ++idx) {
+                final int docValue = dims - idx;
+
+                encoded.put((byte) docValue);
+                query[idx] = (byte) idx;
+
+                docSum += docValue;
+                querySum += idx;
+            }
+
+            encoded.putFloat(docSum / dims);
+
+            docVector = new BytesRef(encoded.array());
+            queryVector = query;
+            queryMagnitude = querySum / dims;
+        }
+    }
+
+    /**
+     * Dot product of a float query against a knn float document vector,
+     * using un-normalized input vectors.
+     */
+    private static class DotKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
+
+        private DotKnnFloatBenchmarkFunction(int dims) {
+            super(dims, false);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link KnnDenseVector}, computes the
+         * dot product, and passes the result to {@code consumer} so that the
+         * computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new KnnDenseVector(docVector).dotProduct(queryVector));
+        }
+    }
+
+    /**
+     * Dot product of a byte query against a knn byte document vector.
+     */
+    private static class DotKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
+
+        private DotKnnByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link ByteKnnDenseVector}, computes
+         * the dot product, and passes the result to {@code consumer} so that the
+         * computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteKnnDenseVector(docVector).dotProduct(queryVector));
+        }
+    }
+
+    /**
+     * Dot product of a float query against a binary-encoded float document
+     * vector, using un-normalized input vectors.
+     */
+    private static class DotBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
+
+        private DotBinaryFloatBenchmarkFunction(int dims) {
+            super(dims, false);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link BinaryDenseVector},
+         * computes the dot product, and passes the result to {@code consumer} so
+         * that the computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new BinaryDenseVector(docVector, dims, Version.CURRENT).dotProduct(queryVector));
+        }
+    }
+
+    /**
+     * Dot product of a byte query against a binary-encoded byte document vector.
+     */
+    private static class DotBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
+
+        private DotBinaryByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link ByteBinaryDenseVector},
+         * computes the dot product, and passes the result to {@code consumer} so
+         * that the computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteBinaryDenseVector(docVector, dims).dotProduct(queryVector));
+        }
+    }
+
+    /**
+     * Cosine similarity of a float query against a knn float document vector,
+     * using input vectors that were scaled by the base class.
+     */
+    private static class CosineKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
+
+        private CosineKnnFloatBenchmarkFunction(int dims) {
+            super(dims, true);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link KnnDenseVector}, computes the
+         * cosine similarity (second argument {@code false}), and passes the result
+         * to {@code consumer} so that the computed value is used rather than
+         * silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new KnnDenseVector(docVector).cosineSimilarity(queryVector, false));
+        }
+    }
+
+    /**
+     * Cosine similarity of a byte query against a knn byte document vector.
+     */
+    private static class CosineKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
+
+        private CosineKnnByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link ByteKnnDenseVector}, computes
+         * the cosine similarity using the precomputed {@code queryMagnitude}, and
+         * passes the result to {@code consumer} so that the computed value is used
+         * rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteKnnDenseVector(docVector).cosineSimilarity(queryVector, queryMagnitude));
+        }
+    }
+
+    /**
+     * Cosine similarity of a float query against a binary-encoded float
+     * document vector, using input vectors that were scaled by the base class.
+     */
+    private static class CosineBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
+
+        private CosineBinaryFloatBenchmarkFunction(int dims) {
+            super(dims, true);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link BinaryDenseVector},
+         * computes the cosine similarity (second argument {@code false}), and
+         * passes the result to {@code consumer} so that the computed value is used
+         * rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new BinaryDenseVector(docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false));
+        }
+    }
+
+    /**
+     * Cosine similarity of a byte query against a binary-encoded byte document vector.
+     */
+    private static class CosineBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
+
+        private CosineBinaryByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link ByteBinaryDenseVector},
+         * computes the cosine similarity using the precomputed
+         * {@code queryMagnitude}, and passes the result to {@code consumer} so that
+         * the computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteBinaryDenseVector(docVector, dims).cosineSimilarity(queryVector, queryMagnitude));
+        }
+    }
+
+    /**
+     * L1 norm between a float query and a knn float document vector, using
+     * un-normalized input vectors.
+     */
+    private static class L1KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
+
+        private L1KnnFloatBenchmarkFunction(int dims) {
+            super(dims, false);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link KnnDenseVector}, computes the
+         * L1 norm, and passes the result to {@code consumer} so that the computed
+         * value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new KnnDenseVector(docVector).l1Norm(queryVector));
+        }
+    }
+
+    /**
+     * L1 norm between a byte query and a knn byte document vector.
+     */
+    private static class L1KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
+
+        private L1KnnByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link ByteKnnDenseVector}, computes
+         * the L1 norm, and passes the result to {@code consumer} so that the
+         * computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteKnnDenseVector(docVector).l1Norm(queryVector));
+        }
+    }
+
+    /**
+     * L1 norm between a float query and a binary-encoded float document vector.
+     */
+    private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
+
+        private L1BinaryFloatBenchmarkFunction(int dims) {
+            // un-normalized input, the same choice L1KnnFloatBenchmarkFunction makes,
+            // so that the knn and binary variants run over the same values
+            super(dims, false);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link BinaryDenseVector},
+         * computes the L1 norm, and passes the result to {@code consumer} so that
+         * the computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector));
+        }
+    }
+
+    /**
+     * L1 norm between a byte query and a binary-encoded byte document vector.
+     */
+    private static class L1BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
+
+        private L1BinaryByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link ByteBinaryDenseVector},
+         * computes the L1 norm, and passes the result to {@code consumer} so that
+         * the computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteBinaryDenseVector(docVector, dims).l1Norm(queryVector));
+        }
+    }
+
+    /**
+     * L2 norm between a float query and a knn float document vector, using
+     * un-normalized input vectors.
+     */
+    private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
+
+        private L2KnnFloatBenchmarkFunction(int dims) {
+            super(dims, false);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link KnnDenseVector}, computes the
+         * L2 norm, and passes the result to {@code consumer} so that the computed
+         * value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new KnnDenseVector(docVector).l2Norm(queryVector));
+        }
+    }
+
+    /**
+     * L2 norm between a byte query and a knn byte document vector.
+     */
+    private static class L2KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
+
+        private L2KnnByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the document vector in a new {@link ByteKnnDenseVector}, computes
+         * the L2 norm, and passes the result to {@code consumer} so that the
+         * computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            consumer.accept(new ByteKnnDenseVector(docVector).l2Norm(queryVector));
+        }
+    }
+
+    /**
+     * L2 norm between a float query and a binary-encoded float document vector.
+     */
+    private static class L2BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
+
+        private L2BinaryFloatBenchmarkFunction(int dims) {
+            // un-normalized input, the same choice L2KnnFloatBenchmarkFunction makes,
+            // so that the knn and binary variants run over the same values
+            super(dims, false);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link BinaryDenseVector},
+         * computes the L2 norm, and passes the result to {@code consumer} so that
+         * the computed value is used rather than silently dropped.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            // this is the l2 benchmark, so it must time l2Norm (not l1Norm)
+            consumer.accept(new BinaryDenseVector(docVector, dims, Version.CURRENT).l2Norm(queryVector));
+        }
+    }
+
+    /**
+     * L2 norm between a byte query and a binary-encoded byte document vector.
+     */
+    private static class L2BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
+
+        private L2BinaryByteBenchmarkFunction(int dims) {
+            super(dims);
+        }
+
+        /**
+         * Wraps the encoded document vector in a new {@link ByteBinaryDenseVector},
+         * computes the L2 norm, and passes the result to {@code consumer}.
+         *
+         * @param consumer sink that receives the computed result
+         */
+        @Override
+        public void execute(Consumer<Object> consumer) {
+            final ByteBinaryDenseVector vector = new ByteBinaryDenseVector(docVector, dims);
+            consumer.accept(vector.l2Norm(queryVector));
+        }
+    }
+
+    private BenchmarkFunction benchmarkFunction;
+
+    /**
+     * Picks the {@link BenchmarkFunction} implementation that matches the
+     * {@code element}, {@code function}, and {@code type} parameters (checked
+     * in that order) and stores it in {@code benchmarkFunction}.
+     *
+     * @throws UnsupportedOperationException if any of the three parameters holds
+     *         a value that is not handled; {@code benchmarkFunction} is left
+     *         untouched in that case
+     */
+    @Setup
+    public void setBenchmarkFunction() {
+        benchmarkFunction = switch (element) {
+            case "float" -> switch (function) {
+                case "dot" -> switch (type) {
+                    case "knn" -> new DotKnnFloatBenchmarkFunction(dims);
+                    case "binary" -> new DotBinaryFloatBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                case "cosine" -> switch (type) {
+                    case "knn" -> new CosineKnnFloatBenchmarkFunction(dims);
+                    case "binary" -> new CosineBinaryFloatBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                case "l1" -> switch (type) {
+                    case "knn" -> new L1KnnFloatBenchmarkFunction(dims);
+                    case "binary" -> new L1BinaryFloatBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                case "l2" -> switch (type) {
+                    case "knn" -> new L2KnnFloatBenchmarkFunction(dims);
+                    case "binary" -> new L2BinaryFloatBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
+            };
+            case "byte" -> switch (function) {
+                case "dot" -> switch (type) {
+                    case "knn" -> new DotKnnByteBenchmarkFunction(dims);
+                    case "binary" -> new DotBinaryByteBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                case "cosine" -> switch (type) {
+                    case "knn" -> new CosineKnnByteBenchmarkFunction(dims);
+                    case "binary" -> new CosineBinaryByteBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                case "l1" -> switch (type) {
+                    case "knn" -> new L1KnnByteBenchmarkFunction(dims);
+                    case "binary" -> new L1BinaryByteBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                case "l2" -> switch (type) {
+                    case "knn" -> new L2KnnByteBenchmarkFunction(dims);
+                    case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
+                    default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
+                };
+                default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
+            };
+            default -> throw new UnsupportedOperationException("unexpected element [" + element + "]");
+        };
+    }
+
+    /**
+     * Timed entry point: runs the selected benchmark function 25000 times per
+     * invocation, handing it {@code Object::toString} as the result sink.
+     * The iteration count has to stay equal to the value given to
+     * {@code @OperationsPerInvocation} on this class.
+     */
+    @Benchmark
+    public void benchmark() throws IOException {
+        // NOTE(review): toString() on a received result may be comparatively expensive;
+        // consider a JMH Blackhole as the sink - confirm against measurements.
+        final Consumer<Object> sink = Object::toString;
+        for (int remaining = 25000; remaining > 0; --remaining) {
+            benchmarkFunction.execute(sink);
+        }
+    }
+}

+ 5 - 0
docs/changelog/92340.yaml

@@ -0,0 +1,5 @@
+pr: 92340
+summary: Add vector distance scoring to micro benchmarks
+area: Performance
+type: enhancement
+issues: []