Sfoglia il codice sorgente

Add nProbe to `:qa:vector:checkVec` and allow multiple nProbes (#130316)

This change adds the n_probe value to the output which will be 0 in the
case of non-ivf runs. In addition it separates index and search data, so
a normal output looks like:

```
index_type  num_docs  index_time(ms)  force_merge_time(ms)  num_segments
----------  --------  --------------  --------------------  ------------  
ivf          1000000           50382                132819             0

index_type  n_probe  latency(ms)  net_cpu_time(ms)  avg_cpu_count     QPS  recall   visited
----------  -------  -----------  ----------------  -------------  ------  ------  --------  
ivf             100         3.69              0.00           0.00  271.00    0.97  58917.00
```

In addition, this change allows to define an array of n_probe in the
configuration file so we can test different values in the same run, so
for example defining an n_probe like:

```
  "n_probe" : [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
```

will produce the following output:

```
index_type  num_docs  index_time(ms)  force_merge_time(ms)  num_segments
----------  --------  --------------  --------------------  ------------  
ivf          1000000           50382                132819             0

index_type  n_probe  latency(ms)  net_cpu_time(ms)  avg_cpu_count     QPS  recall   visited
----------  -------  -----------  ----------------  -------------  ------  ------  --------  
ivf              10         1.18              0.00           0.00  847.46    0.82   7244.59
ivf              20         1.36              0.00           0.00  735.29    0.89  13288.69
ivf              30         1.66              0.00           0.00  602.41    0.92  19266.67
ivf              40         1.93              0.00           0.00  518.13    0.94  24995.41
ivf              50         2.21              0.00           0.00  452.49    0.94  30739.60
ivf              60         2.51              0.00           0.00  398.41    0.95  36428.00
ivf              70         2.76              0.00           0.00  362.32    0.96  41952.59
ivf              80         2.99              0.00           0.00  334.45    0.96  47599.64
ivf              90         3.31              0.00           0.00  302.11    0.96  53254.45
ivf             100         3.69              0.00           0.00  271.00    0.97  58917.00
```

This makes easier to plot the n_probe curve while doing changes.
Ignacio Vera 3 mesi fa
parent
commit
7bc215aa49

+ 8 - 7
qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

@@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.nio.file.Path;
+import java.util.List;
 import java.util.Locale;
 
 /**
@@ -35,7 +36,7 @@ record CmdLineArgs(
     KnnIndexTester.IndexType indexType,
     int numCandidates,
     int k,
-    int nProbe,
+    int[] nProbes,
     int ivfClusterSize,
     int overSamplingFactor,
     int hnswM,
@@ -86,7 +87,7 @@ record CmdLineArgs(
         PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
         PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
         PARSER.declareInt(Builder::setK, K_FIELD);
-        PARSER.declareInt(Builder::setNProbe, N_PROBE_FIELD);
+        PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
         PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
         PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
         PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
@@ -115,7 +116,7 @@ record CmdLineArgs(
         builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
         builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
         builder.field(K_FIELD.getPreferredName(), k);
-        builder.field(N_PROBE_FIELD.getPreferredName(), nProbe);
+        builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
         builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
         builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
         builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
@@ -144,7 +145,7 @@ record CmdLineArgs(
         private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
         private int numCandidates = 1000;
         private int k = 10;
-        private int nProbe = 10;
+        private int[] nProbes = new int[] { 10 };
         private int ivfClusterSize = 1000;
         private int overSamplingFactor = 1;
         private int hnswM = 16;
@@ -193,8 +194,8 @@ record CmdLineArgs(
             return this;
         }
 
-        public Builder setNProbe(int nProbe) {
-            this.nProbe = nProbe;
+        public Builder setNProbe(List<Integer> nProbes) {
+            this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
             return this;
         }
 
@@ -275,7 +276,7 @@ record CmdLineArgs(
                 indexType,
                 numCandidates,
                 k,
-                nProbe,
+                nProbes,
                 ivfClusterSize,
                 overSamplingFactor,
                 hnswM,

+ 58 - 44
qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

@@ -172,8 +172,15 @@ public class KnnIndexTester {
             }
         }
         FormattedResults formattedResults = new FormattedResults();
+
         for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
-            Results result = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
+            int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
+                ? cmdLineArgs.nProbes()
+                : new int[] { 0 };
+            Results[] results = new Results[nProbes.length];
+            for (int i = 0; i < nProbes.length; i++) {
+                results[i] = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
+            }
             logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
             Codec codec = createCodec(cmdLineArgs);
             Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
@@ -192,19 +199,22 @@ public class KnnIndexTester {
                     throw new IllegalArgumentException("Index path does not exist: " + indexPath);
                 }
                 if (cmdLineArgs.reindex()) {
-                    knnIndexer.createIndex(result);
+                    knnIndexer.createIndex(results[0]);
                 }
                 if (cmdLineArgs.forceMerge()) {
-                    knnIndexer.forceMerge(result);
+                    knnIndexer.forceMerge(results[0]);
                 } else {
-                    knnIndexer.numSegments(result);
+                    knnIndexer.numSegments(results[0]);
                 }
             }
             if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
-                KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
-                knnSearcher.runSearch(result);
+                for (int i = 0; i < results.length; i++) {
+                    int nProbe = nProbes[i];
+                    KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
+                    knnSearcher.runSearch(results[i]);
+                }
             }
-            formattedResults.results.add(result);
+            formattedResults.results.addAll(List.of(results));
         }
         logger.info("Results: \n" + formattedResults);
     }
@@ -218,13 +228,12 @@ public class KnnIndexTester {
                 return "No results available.";
             }
 
+            String[] indexingHeaders = { "index_type", "num_docs", "index_time(ms)", "force_merge_time(ms)", "num_segments" };
+
             // Define column headers
-            String[] headers = {
+            String[] searchHeaders = {
                 "index_type",
-                "num_docs",
-                "index_time(ms)",
-                "force_merge_time(ms)",
-                "num_segments",
+                "n_probe",
                 "latency(ms)",
                 "net_cpu_time(ms)",
                 "avg_cpu_count",
@@ -233,41 +242,58 @@ public class KnnIndexTester {
                 "visited" };
 
             // Calculate appropriate column widths based on headers and data
-            int[] widths = calculateColumnWidths(headers);
 
             StringBuilder sb = new StringBuilder();
 
-            // Format and append header
-            sb.append(formatRow(headers, widths));
-            sb.append("\n");
+            Results indexResult = results.get(0); // Assuming all results have the same index type and numDocs
+            String[] indexData = {
+                indexResult.indexType,
+                Integer.toString(indexResult.numDocs),
+                Long.toString(indexResult.indexTimeMS),
+                Long.toString(indexResult.forceMergeTimeMS),
+                Integer.toString(indexResult.numSegments) };
 
-            // Add separator line
-            for (int width : widths) {
-                sb.append("-".repeat(width)).append("  ");
-            }
-            sb.append("\n");
+            printBlock(sb, indexingHeaders, new String[][] { indexData });
 
+            String[][] searchData = new String[results.size()][];
             // Format and append each row of data
-            for (Results result : results) {
-                String[] rowData = {
+            for (int i = 0; i < results.size(); i++) {
+                Results result = results.get(i);
+                searchData[i] = new String[] {
                     result.indexType,
-                    Integer.toString(result.numDocs),
-                    Long.toString(result.indexTimeMS),
-                    Long.toString(result.forceMergeTimeMS),
-                    Integer.toString(result.numSegments),
+                    Integer.toString(result.nProbe),
                     String.format(Locale.ROOT, "%.2f", result.avgLatency),
                     String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
                     String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
                     String.format(Locale.ROOT, "%.2f", result.qps),
                     String.format(Locale.ROOT, "%.2f", result.avgRecall),
                     String.format(Locale.ROOT, "%.2f", result.averageVisited) };
-                sb.append(formatRow(rowData, widths));
-                sb.append("\n");
+
             }
 
+            printBlock(sb, searchHeaders, searchData);
+
             return sb.toString();
         }
 
+        private void printBlock(StringBuilder sb, String[] headers, String[][] rows) {
+            int[] widths = calculateColumnWidths(headers, rows);
+            sb.append("\n");
+            sb.append(formatRow(headers, widths));
+            sb.append("\n");
+
+            // Add separator line
+            for (int width : widths) {
+                sb.append("-".repeat(width)).append("  ");
+            }
+            sb.append("\n");
+
+            for (String[] row : rows) {
+                sb.append(formatRow(row, widths));
+                sb.append("\n");
+            }
+        }
+
         // Helper method to format a single row with proper column widths
         private String formatRow(String[] values, int[] widths) {
             StringBuilder row = new StringBuilder();
@@ -285,7 +311,7 @@ public class KnnIndexTester {
         }
 
         // Calculate appropriate column widths based on headers and data
-        private int[] calculateColumnWidths(String[] headers) {
+        private int[] calculateColumnWidths(String[] headers, String[]... data) {
             int[] widths = new int[headers.length];
 
             // Initialize widths with header lengths
@@ -294,20 +320,7 @@ public class KnnIndexTester {
             }
 
             // Update widths based on data
-            for (Results result : results) {
-                String[] values = {
-                    result.indexType,
-                    Integer.toString(result.numDocs),
-                    Long.toString(result.indexTimeMS),
-                    Long.toString(result.forceMergeTimeMS),
-                    Integer.toString(result.numSegments),
-                    String.format(Locale.ROOT, "%.2f", result.avgLatency),
-                    String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
-                    String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
-                    String.format(Locale.ROOT, "%.2f", result.qps),
-                    String.format(Locale.ROOT, "%.2f", result.avgRecall),
-                    String.format(Locale.ROOT, "%.2f", result.averageVisited) };
-
+            for (String[] values : data) {
                 for (int i = 0; i < values.length; i++) {
                     widths[i] = Math.max(widths[i], values[i].length());
                 }
@@ -323,6 +336,7 @@ public class KnnIndexTester {
         long indexTimeMS;
         long forceMergeTimeMS;
         int numSegments;
+        int nProbe;
         double avgLatency;
         double qps;
         double avgRecall;

+ 3 - 2
qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

@@ -94,7 +94,7 @@ class KnnSearcher {
     private final float overSamplingFactor;
     private final int searchThreads;
 
-    KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs) {
+    KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
         this.docPath = cmdLineArgs.docVectors();
         this.indexPath = indexPath;
         this.queryPath = cmdLineArgs.queryVectors();
@@ -109,7 +109,7 @@ class KnnSearcher {
             throw new IllegalArgumentException("numQueryVectors must be > 0");
         }
         this.efSearch = cmdLineArgs.numCandidates();
-        this.nProbe = cmdLineArgs.nProbe();
+        this.nProbe = nProbe;
         this.indexType = cmdLineArgs.indexType();
         this.searchThreads = cmdLineArgs.searchThreads();
     }
@@ -206,6 +206,7 @@ class KnnSearcher {
         }
         logger.info("checking results");
         int[][] nn = getOrCalculateExactNN(offsetByteSize);
+        finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
         finalResults.avgRecall = checkResults(resultIds, nn, topK);
         finalResults.qps = (1000f * numQueryVectors) / elapsed;
         finalResults.avgLatency = (float) elapsed / numQueryVectors;