Browse Source

Add new query_vector_builder option to knn search clause (#93331)

This adds a new option to the knn search clause called query_vector_builder. This is a pluggable configuration that allows the query_vector created or retrieved.
Benjamin Trent 2 years ago
parent
commit
7f9f3bcd30

+ 5 - 0
docs/changelog/93331.yaml

@@ -0,0 +1,5 @@
+pr: 93331
+summary: Add new `query_vector_builder` option to knn search clause
+area: Search
+type: enhancement
+issues: []

+ 7 - 1
docs/reference/search/search.asciidoc

@@ -506,8 +506,14 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-k]
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates]
 
 `query_vector`::
-(Required, array of floats)
+(Optional, array of floats)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector]
+
+`query_vector_builder`::
+(Optional, object)
+A configuration object indicating how to build a query_vector before executing the request. You must provide
+a `query_vector_builder` or `query_vector`, but not both.
+
 ====
 
 [[search-api-min-score]]

+ 42 - 0
server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java

@@ -39,6 +39,7 @@ import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.search.suggest.Suggest;
 import org.elasticsearch.search.suggest.Suggester;
 import org.elasticsearch.search.suggest.SuggestionBuilder;
+import org.elasticsearch.search.vectors.QueryVectorBuilder;
 import org.elasticsearch.xcontent.ContextParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContent;
@@ -73,6 +74,14 @@ public interface SearchPlugin {
         return emptyList();
     }
 
+    /**
+     * The new {@link QueryVectorBuilder}s defined by this plugin. {@linkplain QueryVectorBuilder}s can be used within a kNN
+     * search to build the query vector instead of having the user provide the vector directly
+     */
+    default List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
+        return emptyList();
+    }
+
     /**
      * The new {@link FetchSubPhase}s defined by this plugin.
      */
@@ -592,4 +601,37 @@ public interface SearchPlugin {
             return highlighters;
         }
     }
+
+    /**
+     * Specification of custom {@link QueryVectorBuilder}.
+     */
+    class QueryVectorBuilderSpec<T extends QueryVectorBuilder> extends SearchExtensionSpec<T, BiFunction<XContentParser, Void, T>> {
+        /**
+         * Specification of custom {@link QueryVectorBuilder}.
+         *
+         * @param name holds the names by which this query vector builder might be parsed.
+         *             The {@link ParseField#getPreferredName()} is special as it
+         *             is the name by under which the reader is registered. So it is the name that the builder should use as its
+         *             {@link NamedWriteable#getWriteableName()} too.
+         * @param reader the reader registered for this query vector builder. Typically a reference to a constructor that takes a
+         *        {@link StreamInput}
+         * @param parser the parser the reads the query vector builder from xcontent
+         */
+        public QueryVectorBuilderSpec(ParseField name, Writeable.Reader<T> reader, BiFunction<XContentParser, Void, T> parser) {
+            super(name, reader, parser);
+        }
+
+        /**
+         * Specification of custom {@link QueryVectorBuilder}.
+         *
+         * @param name the name by which this query vector builder might be parsed or deserialized.
+         *             Make sure that the query builder returns this name for {@link NamedWriteable#getWriteableName()}.
+         * @param reader the reader registered for this query vector builder. Typically a reference to a constructor that takes a
+         *        {@link StreamInput}
+         * @param parser the parser the reads the query vector builder from xcontent
+         */
+        public QueryVectorBuilderSpec(String name, Writeable.Reader<T> reader, BiFunction<XContentParser, Void, T> parser) {
+            super(name, reader, parser);
+        }
+    }
 }

+ 14 - 0
server/src/main/java/org/elasticsearch/search/SearchModule.java

@@ -82,6 +82,7 @@ import org.elasticsearch.plugins.SearchPlugin.AggregationSpec;
 import org.elasticsearch.plugins.SearchPlugin.FetchPhaseConstructionContext;
 import org.elasticsearch.plugins.SearchPlugin.PipelineAggregationSpec;
 import org.elasticsearch.plugins.SearchPlugin.QuerySpec;
+import org.elasticsearch.plugins.SearchPlugin.QueryVectorBuilderSpec;
 import org.elasticsearch.plugins.SearchPlugin.RescorerSpec;
 import org.elasticsearch.plugins.SearchPlugin.ScoreFunctionSpec;
 import org.elasticsearch.plugins.SearchPlugin.SearchExtSpec;
@@ -244,6 +245,7 @@ import org.elasticsearch.search.suggest.term.TermSuggestion;
 import org.elasticsearch.search.suggest.term.TermSuggestionBuilder;
 import org.elasticsearch.search.vectors.KnnScoreDocQueryBuilder;
 import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
+import org.elasticsearch.search.vectors.QueryVectorBuilder;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentParser;
@@ -305,6 +307,7 @@ public class SearchModule {
         registerSorts();
         registerValueFormats();
         registerSignificanceHeuristics(plugins);
+        registerQueryVectorBuilders(plugins);
         this.valuesSourceRegistry = registerAggregations(plugins);
         registerPipelineAggregations(plugins);
         registerFetchSubPhases(plugins);
@@ -980,6 +983,17 @@ public class SearchModule {
         );
     }
 
+    private void registerQueryVectorBuilders(List<SearchPlugin> plugins) {
+        registerFromPlugin(plugins, SearchPlugin::getQueryVectorBuilders, this::registerQueryVectorBuilder);
+    }
+
+    private <T extends QueryVectorBuilder> void registerQueryVectorBuilder(QueryVectorBuilderSpec<?> spec) {
+        namedXContents.add(new NamedXContentRegistry.Entry(QueryVectorBuilder.class, spec.getName(), p -> spec.getParser().apply(p, null)));
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(QueryVectorBuilder.class, spec.getName().getPreferredName(), spec.getReader())
+        );
+    }
+
     private void registerFetchSubPhases(List<SearchPlugin> plugins) {
         registerFetchSubPhase(new ExplainPhase());
         registerFetchSubPhase(new StoredFieldsPhase());

+ 146 - 9
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

@@ -8,6 +8,9 @@
 
 package org.elasticsearch.search.vectors;
 
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -27,8 +30,11 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
+import java.util.function.Supplier;
 
+import static org.elasticsearch.common.Strings.format;
 import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
 /**
  * Defines a kNN search to run in the search request.
@@ -39,6 +45,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     public static final ParseField K_FIELD = new ParseField("k");
     public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
     public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
+    public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
     public static final ParseField FILTER_FIELD = new ParseField("filter");
     public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;
 
@@ -46,18 +53,28 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         @SuppressWarnings("unchecked")
         // TODO optimize parsing for when BYTE values are provided
         List<Float> vector = (List<Float>) args[1];
-        float[] vectorArray = new float[vector.size()];
-        for (int i = 0; i < vector.size(); i++) {
-            vectorArray[i] = vector.get(i);
+        final float[] vectorArray;
+        if (vector != null) {
+            vectorArray = new float[vector.size()];
+            for (int i = 0; i < vector.size(); i++) {
+                vectorArray[i] = vector.get(i);
+            }
+        } else {
+            vectorArray = null;
         }
-        return new KnnSearchBuilder((String) args[0], vectorArray, (int) args[2], (int) args[3]);
+        return new KnnSearchBuilder((String) args[0], vectorArray, (QueryVectorBuilder) args[4], (int) args[2], (int) args[3]);
     });
 
     static {
         PARSER.declareString(constructorArg(), FIELD_FIELD);
-        PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD);
+        PARSER.declareFloatArray(optionalConstructorArg(), QUERY_VECTOR_FIELD);
         PARSER.declareInt(constructorArg(), K_FIELD);
         PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD);
+        PARSER.declareNamedObject(
+            optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
+            QUERY_VECTOR_BUILDER_FIELD
+        );
         PARSER.declareFieldArray(
             KnnSearchBuilder::addFilterQueries,
             (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
@@ -73,6 +90,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
 
     final String field;
     final float[] queryVector;
+    final QueryVectorBuilder queryVectorBuilder;
+    private final Supplier<float[]> querySupplier;
     final int k;
     final int numCands;
     final List<QueryBuilder> filterQueries;
@@ -87,6 +106,27 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
      * @param numCands    the number of nearest neighbor candidates to consider per shard
      */
     public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands) {
+        this(field, Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands);
+    }
+
+    /**
+     * Defines a kNN search where the query vector will be provided by the queryVectorBuilder
+     * @param field              the name of the vector field to search against
+     * @param queryVectorBuilder the query vector builder
+     * @param k                  the final number of nearest neighbors to return as top hits
+     * @param numCands           the number of nearest neighbor candidates to consider per shard
+     */
+    public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
+        this(
+            field,
+            null,
+            Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())),
+            k,
+            numCands
+        );
+    }
+
+    private KnnSearchBuilder(String field, float[] queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
         if (k < 1) {
             throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
         }
@@ -98,11 +138,41 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         if (numCands > NUM_CANDS_LIMIT) {
             throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
         }
+        if (queryVector == null && queryVectorBuilder == null) {
+            throw new IllegalArgumentException(
+                format(
+                    "either [%s] or [%s] must be provided",
+                    QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
+                    QUERY_VECTOR_FIELD.getPreferredName()
+                )
+            );
+        }
+        if (queryVector != null && queryVectorBuilder != null) {
+            throw new IllegalArgumentException(
+                format(
+                    "cannot provide both [%s] and [%s]",
+                    QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
+                    QUERY_VECTOR_FIELD.getPreferredName()
+                )
+            );
+        }
         this.field = field;
-        this.queryVector = queryVector;
+        this.queryVector = queryVector == null ? new float[0] : queryVector;
+        this.queryVectorBuilder = queryVectorBuilder;
         this.k = k;
         this.numCands = numCands;
         this.filterQueries = new ArrayList<>();
+        this.querySupplier = null;
+    }
+
+    private KnnSearchBuilder(String field, Supplier<float[]> querySupplier, int k, int numCands, List<QueryBuilder> filterQueries) {
+        this.field = field;
+        this.queryVector = new float[0];
+        this.queryVectorBuilder = null;
+        this.k = k;
+        this.numCands = numCands;
+        this.filterQueries = filterQueries;
+        this.querySupplier = querySupplier;
     }
 
     public KnnSearchBuilder(StreamInput in) throws IOException {
@@ -112,6 +182,12 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         this.queryVector = in.readFloatArray();
         this.filterQueries = in.readNamedWriteableList(QueryBuilder.class);
         this.boost = in.readFloat();
+        if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
+            this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class);
+        } else {
+            this.queryVectorBuilder = null;
+        }
+        this.querySupplier = null;
     }
 
     public int k() {
@@ -140,6 +216,32 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
 
     @Override
     public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
+        if (querySupplier != null) {
+            if (querySupplier.get() == null) {
+                return this;
+            }
+            return new KnnSearchBuilder(field, querySupplier.get(), k, numCands).boost(boost).addFilterQueries(filterQueries);
+        }
+        if (queryVectorBuilder != null) {
+            SetOnce<float[]> toSet = new SetOnce<>();
+            ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, ActionListener.wrap(v -> {
+                toSet.set(v);
+                if (v == null) {
+                    l.onFailure(
+                        new IllegalArgumentException(
+                            format(
+                                "[%s] with name [%s] returned null query_vector",
+                                QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
+                                queryVectorBuilder.getWriteableName()
+                            )
+                        )
+                    );
+                    return;
+                }
+                l.onResponse(null);
+            }, l::onFailure)));
+            return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries).boost(boost);
+        }
         boolean changed = false;
         List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
         for (QueryBuilder query : filterQueries) {
@@ -156,6 +258,9 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     }
 
     public KnnVectorQueryBuilder toQueryBuilder() {
+        if (queryVectorBuilder != null) {
+            throw new IllegalArgumentException("missing rewrite");
+        }
         return new KnnVectorQueryBuilder(field, queryVector, numCands).boost(boost).addFilterQueries(filterQueries);
     }
 
@@ -168,21 +273,38 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             && numCands == that.numCands
             && Objects.equals(field, that.field)
             && Arrays.equals(queryVector, that.queryVector)
+            && Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
+            && Objects.equals(querySupplier, that.querySupplier)
             && Objects.equals(filterQueries, that.filterQueries)
             && boost == that.boost;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(field, k, numCands, Arrays.hashCode(queryVector), Objects.hashCode(filterQueries), boost);
+        return Objects.hash(
+            field,
+            k,
+            numCands,
+            querySupplier,
+            queryVectorBuilder,
+            Arrays.hashCode(queryVector),
+            Objects.hashCode(filterQueries),
+            boost
+        );
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.field(FIELD_FIELD.getPreferredName(), field)
             .field(K_FIELD.getPreferredName(), k)
-            .field(NUM_CANDS_FIELD.getPreferredName(), numCands)
-            .array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
+            .field(NUM_CANDS_FIELD.getPreferredName(), numCands);
+        if (queryVectorBuilder != null) {
+            builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName());
+            builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder);
+            builder.endObject();
+        } else {
+            builder.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
+        }
 
         if (filterQueries.isEmpty() == false) {
             builder.startArray(FILTER_FIELD.getPreferredName());
@@ -201,11 +323,26 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        if (querySupplier != null) {
+            throw new IllegalStateException("missing a rewriteAndFetch?");
+        }
         out.writeString(field);
         out.writeVInt(k);
         out.writeVInt(numCands);
         out.writeFloatArray(queryVector);
         out.writeNamedWriteableList(filterQueries);
         out.writeFloat(boost);
+        if (out.getTransportVersion().before(TransportVersion.V_8_7_0) && queryVectorBuilder != null) {
+            throw new IllegalArgumentException(
+                format(
+                    "cannot serialize [%s] to older node of version [%s]",
+                    QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
+                    out.getTransportVersion()
+                )
+            );
+        }
+        if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
+            out.writeOptionalNamedWriteable(queryVectorBuilder);
+        }
     }
 }

+ 34 - 0
server/src/main/java/org/elasticsearch/search/vectors/QueryVectorBuilder.java

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
+import org.elasticsearch.xcontent.ToXContentObject;
+
+/**
+ * Provides a mechanism for building a KNN query vector in an asynchronous manner during the rewrite phase
+ */
+public interface QueryVectorBuilder extends VersionedNamedWriteable, ToXContentObject {
+
+    /**
+     * Method for building a vector via the client. This method is called during RerwiteAndFetch.
+     * Typical implementation for this method will:
+     *  1. call some asynchronous client action
+     *  2. Handle failure/success for that action (usually passing failure to the provided listener)
+     *  3. Parse the success case and extract the query vector
+     *  4. Pass the extracted query vector to the provided listener
+     *
+     * @param client for performing asynchronous actions against the cluster
+     * @param listener listener to accept the created vector
+     */
+    void buildVector(Client client, ActionListener<float[]> listener);
+
+}

+ 164 - 0
server/src/test/java/org/elasticsearch/search/vectors/AbstractQueryVectorBuilderTestCase.java

@@ -0,0 +1,164 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractXContentSerializingTestCase;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+/**
+ * Tests a query vector builder
+ * @param <T> the query vector builder type to test
+ */
+public abstract class AbstractQueryVectorBuilderTestCase<T extends QueryVectorBuilder> extends AbstractXContentSerializingTestCase<T> {
+
+    private NamedWriteableRegistry namedWriteableRegistry;
+    private NamedXContentRegistry namedXContentRegistry;
+
+    protected List<SearchPlugin> additionalPlugins() {
+        return List.of();
+    }
+
+    @Before
+    public void registerNamedXContents() {
+        SearchModule searchModule = new SearchModule(Settings.EMPTY, additionalPlugins());
+        namedXContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
+        namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables());
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return namedXContentRegistry;
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return namedWriteableRegistry;
+    }
+
+    // Just in case the vector builder needs to know the expected value when testing
+    protected T createTestInstance(float[] expected) {
+        return createTestInstance();
+    }
+
+    public final void testKnnSearchBuilderXContent() throws Exception {
+        AbstractXContentTestCase.XContentTester<KnnSearchBuilder> tester = AbstractXContentTestCase.xContentTester(
+            this::createParser,
+            () -> new KnnSearchBuilder(randomAlphaOfLength(10), createTestInstance(), 5, 10),
+            getToXContentParams(),
+            KnnSearchBuilder::fromXContent
+        );
+        tester.test();
+    }
+
+    public final void testKnnSearchBuilderWireSerialization() throws IOException {
+        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
+            KnnSearchBuilder searchBuilder = new KnnSearchBuilder(randomAlphaOfLength(10), createTestInstance(), 5, 10);
+            KnnSearchBuilder serialized = copyWriteable(
+                searchBuilder,
+                getNamedWriteableRegistry(),
+                KnnSearchBuilder::new,
+                TransportVersion.CURRENT
+            );
+            assertThat(serialized, equalTo(searchBuilder));
+            assertNotSame(serialized, searchBuilder);
+        }
+    }
+
+    public final void testKnnSearchRewrite() throws Exception {
+        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
+            float[] expected = randomVector(randomIntBetween(10, 1024));
+            T queryVectorBuilder = createTestInstance(expected);
+            KnnSearchBuilder searchBuilder = new KnnSearchBuilder(randomAlphaOfLength(10), queryVectorBuilder, 5, 10);
+            KnnSearchBuilder serialized = copyWriteable(
+                searchBuilder,
+                getNamedWriteableRegistry(),
+                KnnSearchBuilder::new,
+                TransportVersion.CURRENT
+            );
+            try (NoOpClient client = new AssertingClient(expected, queryVectorBuilder)) {
+                QueryRewriteContext context = new QueryRewriteContext(null, null, client, null);
+                PlainActionFuture<KnnSearchBuilder> future = new PlainActionFuture<>();
+                Rewriteable.rewriteAndFetch(randomFrom(serialized, searchBuilder), context, future);
+                KnnSearchBuilder rewritten = future.get();
+                assertThat(rewritten.queryVector, equalTo(expected));
+                assertThat(rewritten.queryVectorBuilder, nullValue());
+            }
+        }
+    }
+
+    public final void testVectorFetch() throws Exception {
+        float[] expected = randomVector(randomIntBetween(10, 1024));
+        T queryVectorBuilder = createTestInstance(expected);
+        try (NoOpClient client = new AssertingClient(expected, queryVectorBuilder)) {
+            PlainActionFuture<float[]> future = new PlainActionFuture<>();
+            queryVectorBuilder.buildVector(client, future);
+            assertThat(future.get(), equalTo(expected));
+        }
+    }
+
+    /**
+     * Assert that the client action request is correct given this provided random builder
+     * @param request The built request to be executed by the client
+     * @param builder The builder used when generating this request
+     */
+    abstract void doAssertClientRequest(ActionRequest request, T builder);
+
+    /**
+     * Create a response given this expected array that is acceptable to the query builder
+     * @param array The expected final array
+     * @param builder The original randomly built query vector builder
+     * @return An action response to be handled by the query vector builder
+     */
+    abstract ActionResponse createResponse(float[] array, T builder);
+
+    private class AssertingClient extends NoOpClient {
+
+        private final float[] array;
+        private final T queryVectorBuilder;
+
+        AssertingClient(float[] array, T queryVectorBuilder) {
+            super("query_vector_builder_tests");
+            this.array = array;
+            this.queryVectorBuilder = queryVectorBuilder;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
+            ActionType<Response> action,
+            Request request,
+            ActionListener<Response> listener
+        ) {
+            doAssertClientRequest(request, queryVectorBuilder);
+            listener.onResponse((Response) createResponse(array, queryVectorBuilder));
+        }
+    }
+}

+ 82 - 1
server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

@@ -8,25 +8,37 @@
 
 package org.elasticsearch.search.vectors;
 
+import org.apache.lucene.search.Query;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
+import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractXContentSerializingTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.junit.Before;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Objects;
 
 import static java.util.Collections.emptyList;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.nullValue;
 
 public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<KnnSearchBuilder> {
     private NamedWriteableRegistry namedWriteableRegistry;
@@ -161,11 +173,80 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
         assertThat(e.getMessage(), containsString("[k] must be greater than 0"));
     }
 
-    private static float[] randomVector(int dim) {
+    public void testRewrite() throws Exception {
+        float[] expectedArray = randomVector(randomIntBetween(10, 1024));
+        KnnSearchBuilder searchBuilder = new KnnSearchBuilder(
+            "field",
+            new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray),
+            5,
+            10
+        );
+        searchBuilder.boost(randomFloat());
+        searchBuilder.addFilterQueries(List.of(new RewriteableQuery()));
+
+        QueryRewriteContext context = new QueryRewriteContext(null, null, null, null);
+        PlainActionFuture<KnnSearchBuilder> future = new PlainActionFuture<>();
+        Rewriteable.rewriteAndFetch(searchBuilder, context, future);
+        KnnSearchBuilder rewritten = future.get();
+
+        assertThat(rewritten.field, equalTo(searchBuilder.field));
+        assertThat(rewritten.boost, equalTo(searchBuilder.boost));
+        assertThat(rewritten.queryVector, equalTo(expectedArray));
+        assertThat(rewritten.queryVectorBuilder, nullValue());
+        assertThat(rewritten.filterQueries, hasSize(1));
+        assertThat(((RewriteableQuery) rewritten.filterQueries.get(0)).rewrites, equalTo(1));
+    }
+
+    static float[] randomVector(int dim) {
         float[] vector = new float[dim];
         for (int i = 0; i < vector.length; i++) {
             vector[i] = randomFloat();
         }
         return vector;
     }
+
+    private static class RewriteableQuery extends AbstractQueryBuilder<RewriteableQuery> {
+        private int rewrites;
+
+        @Override
+        public String getWriteableName() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        protected void doWriteTo(StreamOutput out) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        protected Query doToQuery(SearchExecutionContext context) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        protected boolean doEquals(RewriteableQuery other) {
+            return true;
+        }
+
+        @Override
+        protected int doHashCode() {
+            return Objects.hashCode(RewriteableQuery.class);
+        }
+
+        @Override
+        protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
+            rewrites++;
+            return this;
+        }
+    }
 }

+ 67 - 0
server/src/test/java/org/elasticsearch/search/vectors/QueryVectorBuilderTests.java

@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Test the query vector builder logic with a test plugin
+ */
+public class QueryVectorBuilderTests extends AbstractQueryVectorBuilderTestCase<TestQueryVectorBuilderPlugin.TestQueryVectorBuilder> {
+
+    @Override
+    protected List<SearchPlugin> additionalPlugins() {
+        return List.of(new TestQueryVectorBuilderPlugin());
+    }
+
+    @Override
+    protected Writeable.Reader<TestQueryVectorBuilderPlugin.TestQueryVectorBuilder> instanceReader() {
+        return TestQueryVectorBuilderPlugin.TestQueryVectorBuilder::new;
+    }
+
+    @Override
+    protected TestQueryVectorBuilderPlugin.TestQueryVectorBuilder createTestInstance() {
+        return new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(randomList(2, 1024, ESTestCase::randomFloat));
+    }
+
+    @Override
+    protected TestQueryVectorBuilderPlugin.TestQueryVectorBuilder createTestInstance(float[] expected) {
+        return new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expected);
+    }
+
+    @Override
+    protected TestQueryVectorBuilderPlugin.TestQueryVectorBuilder mutateInstance(
+        TestQueryVectorBuilderPlugin.TestQueryVectorBuilder instance
+    ) throws IOException {
+        return createTestInstance();
+    }
+
+    @Override
+    protected TestQueryVectorBuilderPlugin.TestQueryVectorBuilder doParseInstance(XContentParser parser) throws IOException {
+        return TestQueryVectorBuilderPlugin.TestQueryVectorBuilder.PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected void doAssertClientRequest(ActionRequest request, TestQueryVectorBuilderPlugin.TestQueryVectorBuilder builder) {
+        // Nothing to assert here as this object does not make client calls
+    }
+
+    @Override
+    protected ActionResponse createResponse(float[] array, TestQueryVectorBuilderPlugin.TestQueryVectorBuilder builder) {
+        return new ActionResponse.Empty();
+    }
+}

+ 114 - 0
server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java

@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * A SearchPlugin to exercise query vector builder
+ */
+class TestQueryVectorBuilderPlugin implements SearchPlugin {
+
+    static class TestQueryVectorBuilder implements QueryVectorBuilder {
+        private static final String NAME = "test_query_vector_builder";
+
+        private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
+
+        @SuppressWarnings("unchecked")
+        static final ConstructingObjectParser<TestQueryVectorBuilder, Void> PARSER = new ConstructingObjectParser<>(
+            NAME + "_parser",
+            true,
+            a -> new TestQueryVectorBuilder((List<Float>) a[0])
+        );
+
+        static {
+            PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), QUERY_VECTOR);
+        }
+
+        private List<Float> vectorToBuild;
+
+        TestQueryVectorBuilder(List<Float> vectorToBuild) {
+            this.vectorToBuild = vectorToBuild;
+        }
+
+        TestQueryVectorBuilder(float[] expected) {
+            this.vectorToBuild = new ArrayList<>(expected.length);
+            for (float f : expected) {
+                vectorToBuild.add(f);
+            }
+        }
+
+        TestQueryVectorBuilder(StreamInput in) throws IOException {
+            this.vectorToBuild = in.readList(StreamInput::readFloat);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder.startObject().field(QUERY_VECTOR.getPreferredName(), vectorToBuild).endObject();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersion.CURRENT;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeCollection(vectorToBuild, StreamOutput::writeFloat);
+        }
+
+        @Override
+        public void buildVector(Client client, ActionListener<float[]> listener) {
+            float[] response = new float[vectorToBuild.size()];
+            int i = 0;
+            for (Float f : vectorToBuild) {
+                response[i++] = f;
+            }
+            listener.onResponse(response);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TestQueryVectorBuilder that = (TestQueryVectorBuilder) o;
+            return Objects.equals(vectorToBuild, that.vectorToBuild);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(vectorToBuild);
+        }
+    }
+
+    @Override
+    public List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
+        return List.of(
+            new QueryVectorBuilderSpec<>(TestQueryVectorBuilder.NAME, TestQueryVectorBuilder::new, TestQueryVectorBuilder.PARSER::apply)
+        );
+    }
+}