Browse Source

Have top level knn searches tracked in query stats (#132548)

Since dfs kNN searches aren't in the query phase, we don't get their
search stats for free in query stats. 

This adds their stats specifically during knn search in dfs.

closes: https://github.com/elastic/elasticsearch/issues/128098
Benjamin Trent 2 months ago
parent
commit
c19dc0ec6e

+ 5 - 0
docs/changelog/132548.yaml

@@ -0,0 +1,5 @@
+pr: 132548
+summary: Have top level knn searches tracked in query stats
+area: Vector Search
+type: bug
+issues: []

+ 19 - 6
server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

@@ -177,7 +177,7 @@ public class DfsPhase {
         return null;
     };
 
-    private static void executeKnnVectorQuery(SearchContext context) throws IOException {
+    static void executeKnnVectorQuery(SearchContext context) throws IOException {
         SearchSourceBuilder source = context.request().source();
         if (source == null || source.knnSearch().isEmpty()) {
             return;
@@ -195,11 +195,24 @@ public class DfsPhase {
             }
         }
         List<DfsKnnResults> knnResults = new ArrayList<>(knnVectorQueryBuilders.size());
-        for (int i = 0; i < knnSearch.size(); i++) {
-            String knnField = knnVectorQueryBuilders.get(i).getFieldName();
-            String knnNestedPath = searchExecutionContext.nestedLookup().getNestedParent(knnField);
-            Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query();
-            knnResults.add(singleKnnSearch(knnQuery, knnSearch.get(i).k(), context.getProfilers(), context.searcher(), knnNestedPath));
+        final long afterQueryTime;
+        final long beforeQueryTime = System.nanoTime();
+        var opsListener = context.indexShard().getSearchOperationListener();
+        opsListener.onPreQueryPhase(context);
+        try {
+            for (int i = 0; i < knnSearch.size(); i++) {
+                String knnField = knnVectorQueryBuilders.get(i).getFieldName();
+                String knnNestedPath = searchExecutionContext.nestedLookup().getNestedParent(knnField);
+                Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query();
+                knnResults.add(singleKnnSearch(knnQuery, knnSearch.get(i).k(), context.getProfilers(), context.searcher(), knnNestedPath));
+            }
+            afterQueryTime = System.nanoTime();
+            opsListener.onQueryPhase(context, afterQueryTime - beforeQueryTime);
+            opsListener = null;
+        } finally {
+            if (opsListener != null) {
+                opsListener.onFailedQueryPhase(context);
+            }
         }
         context.dfsResult().knnResults(knnResults);
     }

+ 2 - 0
server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java

@@ -5627,6 +5627,7 @@ public class IndexShardTests extends IndexShardTestCase {
             RetentionLeaseSyncer.EMPTY,
             EMPTY_EVENT_LISTENER,
             fakeClock,
+            Collections.emptyList(),
             // Use a listener to advance the fake clock once per indexing operation:
             new IndexingOperationListener() {
                 @Override
@@ -5772,6 +5773,7 @@ public class IndexShardTests extends IndexShardTestCase {
             RetentionLeaseSyncer.EMPTY,
             EMPTY_EVENT_LISTENER,
             fakeClock,
+            Collections.emptyList(),
             // Use a listener to advance the fake clock once per indexing operation:
             new IndexingOperationListener() {
                 @Override

+ 123 - 2
server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java

@@ -17,22 +17,44 @@ import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.util.Accountable;
+import org.elasticsearch.action.search.SearchShardTask;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexSettings;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
+import org.elasticsearch.index.mapper.MapperMetrics;
+import org.elasticsearch.index.mapper.MappingLookup;
+import org.elasticsearch.index.query.ParsedQuery;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.index.shard.IndexShard;
+import org.elasticsearch.index.shard.IndexShardTestCase;
+import org.elasticsearch.index.shard.SearchOperationListener;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.ContextIndexSearcher;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
 import org.elasticsearch.search.profile.query.CollectorResult;
 import org.elasticsearch.search.profile.query.QueryProfileShardResult;
-import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.search.vectors.KnnSearchBuilder;
+import org.elasticsearch.test.TestSearchContext;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.atomic.AtomicLong;
 
-public class DfsPhaseTests extends ESTestCase {
+import static org.elasticsearch.search.dfs.DfsPhase.executeKnnVectorQuery;
+
+public class DfsPhaseTests extends IndexShardTestCase {
 
     ThreadPoolExecutor threadPoolExecutor;
     private TestThreadPool threadPool;
@@ -49,6 +71,105 @@ public class DfsPhaseTests extends ESTestCase {
         terminate(threadPool);
     }
 
+    public void testKnnSearch() throws IOException {
+        AtomicLong queryCount = new AtomicLong();
+        AtomicLong queryTime = new AtomicLong();
+
+        IndexShard indexShard = newShard(true, List.of(new SearchOperationListener() {
+            @Override
+            public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
+                queryCount.incrementAndGet();
+                queryTime.addAndGet(tookInNanos);
+            }
+        }));
+        try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
+            int numDocs = randomIntBetween(900, 1000);
+            for (int i = 0; i < numDocs; i++) {
+                Document d = new Document();
+                d.add(new KnnFloatVectorField("float_vector", new float[] { i, 0, 0 }));
+                w.addDocument(d);
+            }
+            w.flush();
+
+            IndexReader reader = w.getReader();
+            ContextIndexSearcher searcher = new ContextIndexSearcher(
+                reader,
+                IndexSearcher.getDefaultSimilarity(),
+                IndexSearcher.getDefaultQueryCache(),
+                IndexSearcher.getDefaultQueryCachingPolicy(),
+                randomBoolean(),
+                threadPoolExecutor,
+                threadPoolExecutor.getMaximumPoolSize(),
+                1
+            );
+            IndexSettings indexSettings = new IndexSettings(
+                IndexMetadata.builder("index")
+                    .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()))
+                    .numberOfShards(1)
+                    .numberOfReplicas(0)
+                    .creationDate(System.currentTimeMillis())
+                    .build(),
+                Settings.EMPTY
+            );
+            BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() {
+                @Override
+                public void onCache(ShardId shardId, Accountable accountable) {
+
+                }
+
+                @Override
+                public void onRemoval(ShardId shardId, Accountable accountable) {
+
+                }
+            });
+            SearchExecutionContext searchExecutionContext = new SearchExecutionContext(
+                0,
+                0,
+                indexSettings,
+                bitsetFilterCache,
+                null,
+                null,
+                MappingLookup.EMPTY,
+                null,
+                null,
+                null,
+                null,
+                null,
+                null,
+                null,
+                null,
+                null,
+                null,
+                null,
+                Collections.emptyMap(),
+                null,
+                MapperMetrics.NOOP
+            );
+
+            Query query = new KnnFloatVectorQuery("float_vector", new float[] { 0, 0, 0 }, numDocs, null);
+            try (TestSearchContext context = new TestSearchContext(searchExecutionContext, indexShard, searcher) {
+                @Override
+                public DfsSearchResult dfsResult() {
+                    return new DfsSearchResult(null, null, null);
+                }
+            }) {
+                context.request()
+                    .source(
+                        new SearchSourceBuilder().knnSearch(
+                            List.of(new KnnSearchBuilder("float_vector", new float[] { 0, 0, 0 }, numDocs, numDocs, null, null))
+                        )
+                    );
+                context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
+                context.parsedQuery(new ParsedQuery(query));
+                executeKnnVectorQuery(context);
+                assertTrue(queryCount.get() > 0);
+                assertTrue(queryTime.get() > 0);
+                reader.close();
+                closeShards(indexShard);
+            }
+        }
+    }
+
     public void testSingleKnnSearch() throws IOException {
         try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
             int numDocs = randomIntBetween(900, 1000);

+ 126 - 11
test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java

@@ -244,7 +244,16 @@ public abstract class IndexShardTestCase extends ESTestCase {
      *                another shard)
      */
     protected IndexShard newShard(final boolean primary, final Settings settings) throws IOException {
-        return newShard(primary, settings, new InternalEngineFactory());
+        return newShard(primary, settings, new InternalEngineFactory(), Collections.emptyList());
+    }
+
+    protected IndexShard newShard(
+        boolean primary,
+        Settings settings,
+        EngineFactory engineFactory,
+        final IndexingOperationListener... listeners
+    ) throws IOException {
+        return newShard(primary, settings, engineFactory, Collections.emptyList(), listeners);
     }
 
     /**
@@ -260,9 +269,20 @@ public abstract class IndexShardTestCase extends ESTestCase {
         boolean primary,
         Settings settings,
         EngineFactory engineFactory,
+        final List<SearchOperationListener> searchListeners,
         final IndexingOperationListener... listeners
     ) throws IOException {
-        return newShard(primary, new ShardId("index", "_na_", 0), settings, engineFactory, listeners);
+        return newShard(primary, new ShardId("index", "_na_", 0), settings, engineFactory, searchListeners, listeners);
+    }
+
+    protected IndexShard newShard(
+        boolean primary,
+        ShardId shardId,
+        Settings settings,
+        EngineFactory engineFactory,
+        final IndexingOperationListener... listeners
+    ) throws IOException {
+        return newShard(primary, shardId, settings, engineFactory, Collections.emptyList(), listeners);
     }
 
     /**
@@ -280,6 +300,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         ShardId shardId,
         Settings settings,
         EngineFactory engineFactory,
+        List<SearchOperationListener> searchListeners,
         final IndexingOperationListener... listeners
     ) throws IOException {
         final RecoverySource recoverySource = primary
@@ -288,7 +309,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         final ShardRouting shardRouting = shardRoutingBuilder(shardId, randomAlphaOfLength(10), primary, ShardRoutingState.INITIALIZING)
             .withRecoverySource(recoverySource)
             .build();
-        return newShard(shardRouting, settings, engineFactory, listeners);
+        return newShard(shardRouting, settings, engineFactory, searchListeners, listeners);
     }
 
     protected IndexShard newShard(ShardRouting shardRouting, final IndexingOperationListener... listeners) throws IOException {
@@ -297,7 +318,16 @@ public abstract class IndexShardTestCase extends ESTestCase {
 
     protected IndexShard newShard(ShardRouting shardRouting, final Settings settings, final IndexingOperationListener... listeners)
         throws IOException {
-        return newShard(shardRouting, settings, new InternalEngineFactory(), listeners);
+        return newShard(shardRouting, settings, new InternalEngineFactory(), Collections.emptyList(), listeners);
+    }
+
+    protected IndexShard newShard(
+        final ShardRouting shardRouting,
+        final Settings settings,
+        final EngineFactory engineFactory,
+        final IndexingOperationListener... listeners
+    ) throws IOException {
+        return newShard(shardRouting, settings, engineFactory, Collections.emptyList(), listeners);
     }
 
     /**
@@ -312,6 +342,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         final ShardRouting shardRouting,
         final Settings settings,
         final EngineFactory engineFactory,
+        final List<SearchOperationListener> searchListeners,
         final IndexingOperationListener... listeners
     ) throws IOException {
         assert shardRouting.initializing() : shardRouting;
@@ -326,7 +357,16 @@ public abstract class IndexShardTestCase extends ESTestCase {
             .settings(indexSettings)
             .primaryTerm(0, primaryTerm)
             .putMapping("{ \"properties\": {} }");
-        return newShard(shardRouting, metadata.build(), null, engineFactory, NOOP_GCP_SYNCER, RetentionLeaseSyncer.EMPTY, listeners);
+        return newShard(
+            shardRouting,
+            metadata.build(),
+            null,
+            engineFactory,
+            NOOP_GCP_SYNCER,
+            RetentionLeaseSyncer.EMPTY,
+            searchListeners,
+            listeners
+        );
     }
 
     /**
@@ -344,6 +384,10 @@ public abstract class IndexShardTestCase extends ESTestCase {
         return newShard(shardRouting, Settings.EMPTY, new InternalEngineFactory(), listeners);
     }
 
+    protected IndexShard newShard(boolean primary, List<SearchOperationListener> listeners) throws IOException {
+        return newShard(primary, Settings.EMPTY, new InternalEngineFactory(), listeners);
+    }
+
     /**
      * creates a new initializing shard. The shard will will be put in its proper path under the
      * supplied node id.
@@ -363,7 +407,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
     }
 
     /**
-     * creates a new initializing shard. The shard will will be put in its proper path under the
+     * creates a new initializing shard. The shard will be put in its proper path under the
      * supplied node id.
      *
      * @param shardId the shard id to use
@@ -387,7 +431,8 @@ public abstract class IndexShardTestCase extends ESTestCase {
             readerWrapper,
             new InternalEngineFactory(),
             globalCheckpointSyncer,
-            RetentionLeaseSyncer.EMPTY
+            RetentionLeaseSyncer.EMPTY,
+            Collections.emptyList()
         );
     }
 
@@ -406,7 +451,37 @@ public abstract class IndexShardTestCase extends ESTestCase {
         EngineFactory engineFactory,
         IndexingOperationListener... listeners
     ) throws IOException {
-        return newShard(routing, indexMetadata, indexReaderWrapper, engineFactory, NOOP_GCP_SYNCER, RetentionLeaseSyncer.EMPTY, listeners);
+        return newShard(
+            routing,
+            indexMetadata,
+            indexReaderWrapper,
+            engineFactory,
+            NOOP_GCP_SYNCER,
+            RetentionLeaseSyncer.EMPTY,
+            Collections.emptyList(),
+            listeners
+        );
+    }
+
+    protected IndexShard newShard(
+        ShardRouting routing,
+        IndexMetadata indexMetadata,
+        @Nullable CheckedFunction<DirectoryReader, DirectoryReader, IOException> indexReaderWrapper,
+        @Nullable EngineFactory engineFactory,
+        GlobalCheckpointSyncer globalCheckpointSyncer,
+        RetentionLeaseSyncer retentionLeaseSyncer,
+        IndexingOperationListener... listeners
+    ) throws IOException {
+        return newShard(
+            routing,
+            indexMetadata,
+            indexReaderWrapper,
+            engineFactory,
+            globalCheckpointSyncer,
+            retentionLeaseSyncer,
+            Collections.emptyList(),
+            listeners
+        );
     }
 
     /**
@@ -425,6 +500,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         @Nullable EngineFactory engineFactory,
         GlobalCheckpointSyncer globalCheckpointSyncer,
         RetentionLeaseSyncer retentionLeaseSyncer,
+        List<SearchOperationListener> searchListeners,
         IndexingOperationListener... listeners
     ) throws IOException {
         // add node id as name to settings for proper logging
@@ -441,6 +517,35 @@ public abstract class IndexShardTestCase extends ESTestCase {
             globalCheckpointSyncer,
             retentionLeaseSyncer,
             EMPTY_EVENT_LISTENER,
+            searchListeners,
+            listeners
+        );
+    }
+
+    protected IndexShard newShard(
+        ShardRouting routing,
+        ShardPath shardPath,
+        IndexMetadata indexMetadata,
+        @Nullable CheckedFunction<IndexSettings, Store, IOException> storeProvider,
+        @Nullable CheckedFunction<DirectoryReader, DirectoryReader, IOException> indexReaderWrapper,
+        @Nullable EngineFactory engineFactory,
+        GlobalCheckpointSyncer globalCheckpointSyncer,
+        RetentionLeaseSyncer retentionLeaseSyncer,
+        IndexEventListener indexEventListener,
+        IndexingOperationListener... listeners
+    ) throws IOException {
+        return newShard(
+            routing,
+            shardPath,
+            indexMetadata,
+            storeProvider,
+            indexReaderWrapper,
+            engineFactory,
+            globalCheckpointSyncer,
+            retentionLeaseSyncer,
+            indexEventListener,
+            System::nanoTime,
+            Collections.emptyList(),
             listeners
         );
     }
@@ -466,6 +571,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         GlobalCheckpointSyncer globalCheckpointSyncer,
         RetentionLeaseSyncer retentionLeaseSyncer,
         IndexEventListener indexEventListener,
+        List<SearchOperationListener> searchListeners,
         IndexingOperationListener... listeners
     ) throws IOException {
         return newShard(
@@ -479,6 +585,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
             retentionLeaseSyncer,
             indexEventListener,
             System::nanoTime,
+            searchListeners,
             listeners
         );
     }
@@ -506,6 +613,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         RetentionLeaseSyncer retentionLeaseSyncer,
         IndexEventListener indexEventListener,
         LongSupplier relativeTimeSupplier,
+        List<SearchOperationListener> soListener,
         IndexingOperationListener... listeners
     ) throws IOException {
         final Settings nodeSettings = Settings.builder().put("node.name", routing.currentNodeId()).build();
@@ -553,7 +661,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
                 threadPoolMergeExecutorService,
                 BigArrays.NON_RECYCLING_INSTANCE,
                 warmer,
-                Collections.emptyList(),
+                soListener,
                 Arrays.asList(listeners),
                 globalCheckpointSyncer,
                 retentionLeaseSyncer,
@@ -629,6 +737,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
             current.getGlobalCheckpointSyncer(),
             current.getRetentionLeaseSyncer(),
             EMPTY_EVENT_LISTENER,
+            Collections.emptyList(),
             listeners
         );
     }
@@ -683,7 +792,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         final EngineFactory engineFactory,
         final IndexingOperationListener... listeners
     ) throws IOException {
-        return newStartedShard(p -> newShard(p, settings, engineFactory, listeners), primary);
+        return newStartedShard(p -> newShard(p, settings, engineFactory, Collections.emptyList(), listeners), primary);
     }
 
     /**
@@ -785,7 +894,13 @@ public abstract class IndexShardTestCase extends ESTestCase {
         IndexShard primary = null;
         try {
             primary = newStartedShard(
-                p -> newShard(p, replica.routingEntry().shardId(), replica.indexSettings.getSettings(), new InternalEngineFactory()),
+                p -> newShard(
+                    p,
+                    replica.routingEntry().shardId(),
+                    replica.indexSettings.getSettings(),
+                    new InternalEngineFactory(),
+                    Collections.emptyList()
+                ),
                 true
             );
             recoverReplica(replica, primary, startReplica);