Browse Source

Add support for sparse_vector queries against semantic_text fields (#118617) (#118951)

(cherry picked from commit 15bec3cefa48c958dada0fba42f452fd278c1ab6)

# Conflicts:
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
Kathleen DeRusso 10 months ago
parent
commit
19fd296315
13 changed files with 887 additions and 76 deletions
  1. 5 0
      docs/changelog/118617.yaml
  2. 30 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
  3. 3 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java
  4. 5 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
  5. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  6. 29 59
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java
  7. 8 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
  8. 148 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java
  9. 124 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java
  10. 110 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java
  11. 137 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java
  12. 249 0
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml
  13. 37 3
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/sparse_vector_search.yml

+ 5 - 0
docs/changelog/118617.yaml

@@ -0,0 +1,5 @@
+pr: 118617
+summary: Add support for `sparse_vector` queries against `semantic_text` fields
+area: "Search"
+type: enhancement
+issues: []

+ 30 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

@@ -90,7 +90,8 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
             : (this.shouldPruneTokens ? new TokenPruningConfig() : null));
         this.weightedTokensSupplier = null;
 
-        if (queryVectors == null ^ inferenceId == null == false) {
+        // Preserve BWC error messaging
+        if (queryVectors != null && inferenceId != null) {
             throw new IllegalArgumentException(
                 "["
                     + NAME
@@ -98,18 +99,24 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
                     + QUERY_VECTOR_FIELD.getPreferredName()
                     + "] or ["
                     + INFERENCE_ID_FIELD.getPreferredName()
-                    + "]"
+                    + "] for "
+                    + ALLOWED_FIELD_TYPE
+                    + " fields"
             );
         }
-        if (inferenceId != null && query == null) {
+
+        // Preserve BWC error messaging
+        if ((queryVectors == null) == (query == null)) {
             throw new IllegalArgumentException(
                 "["
                     + NAME
-                    + "] requires ["
-                    + QUERY_FIELD.getPreferredName()
-                    + "] when ["
+                    + "] requires one of ["
+                    + QUERY_VECTOR_FIELD.getPreferredName()
+                    + "] or ["
                     + INFERENCE_ID_FIELD.getPreferredName()
-                    + "] is specified"
+                    + "] for "
+                    + ALLOWED_FIELD_TYPE
+                    + " fields"
             );
         }
     }
@@ -143,6 +150,14 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
         return queryVectors;
     }
 
+    public String getInferenceId() {
+        return inferenceId;
+    }
+
+    public String getQuery() {
+        return query;
+    }
+
     public boolean shouldPruneTokens() {
         return shouldPruneTokens;
     }
@@ -176,7 +191,9 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
             }
             builder.endObject();
         } else {
-            builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
+            if (inferenceId != null) {
+                builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
+            }
             builder.field(QUERY_FIELD.getPreferredName(), query);
         }
         builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);
@@ -228,6 +245,11 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
                 shouldPruneTokens,
                 tokenPruningConfig
             );
+        } else if (inferenceId == null) {
+            // Edge case, where inference_id was not specified in the request,
+            // but we did not intercept this and rewrite to a query on a field with
+            // pre-configured inference. So we trap here and output a nicer error message.
+            throw new IllegalArgumentException("inference_id required to perform vector search on query string");
         }
 
         // TODO move this to xpack core and use inference APIs

+ 3 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java

@@ -260,16 +260,16 @@ public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseV
         {
             IllegalArgumentException e = expectThrows(
                 IllegalArgumentException.class,
-                () -> new SparseVectorQueryBuilder("field name", null, "model id")
+                () -> new SparseVectorQueryBuilder("field name", null, null)
             );
-            assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
+            assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
         }
         {
             IllegalArgumentException e = expectThrows(
                 IllegalArgumentException.class,
                 () -> new SparseVectorQueryBuilder("field name", "model text", null)
             );
-            assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage());
+            assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
         }
     }
 

+ 5 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

@@ -10,12 +10,14 @@ package org.elasticsearch.xpack.inference;
 import org.elasticsearch.features.FeatureSpecification;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
-import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
 
 import java.util.Set;
 
+import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
+import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
+
 /**
  * Provides inference features.
  */
@@ -43,7 +45,8 @@ public class InferenceFeatures implements FeatureSpecification {
             SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
             SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
             SEMANTIC_TEXT_HIGHLIGHTER,
-            SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
+            SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
+            SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
         );
     }
 }

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -75,6 +75,7 @@ import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
+import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
@@ -404,7 +405,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
 
     @Override
     public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
-        return List.of(new SemanticMatchQueryRewriteInterceptor());
+        return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
     }
 
     @Override

+ 29 - 59
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java

@@ -7,24 +7,12 @@
 
 package org.elasticsearch.xpack.inference.queries;
 
-import org.elasticsearch.action.ResolvedIndices;
-import org.elasticsearch.cluster.metadata.IndexMetadata;
-import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.features.NodeFeature;
-import org.elasticsearch.index.mapper.IndexFieldMapper;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.MatchQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryRewriteContext;
-import org.elasticsearch.index.query.TermQueryBuilder;
-import org.elasticsearch.index.query.TermsQueryBuilder;
-import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-
-public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
+public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
 
     public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
         "search.semantic_match_query_rewrite_interception_supported"
@@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce
     public SemanticMatchQueryRewriteInterceptor() {}
 
     @Override
-    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
+    protected String getFieldName(QueryBuilder queryBuilder) {
         assert (queryBuilder instanceof MatchQueryBuilder);
         MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
-        QueryBuilder rewritten = queryBuilder;
-        ResolvedIndices resolvedIndices = context.getResolvedIndices();
-        if (resolvedIndices != null) {
-            Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
-            List<String> inferenceIndices = new ArrayList<>();
-            List<String> nonInferenceIndices = new ArrayList<>();
-            for (IndexMetadata indexMetadata : indexMetadataCollection) {
-                String indexName = indexMetadata.getIndex().getName();
-                InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
-                if (inferenceFieldMetadata != null) {
-                    inferenceIndices.add(indexName);
-                } else {
-                    nonInferenceIndices.add(indexName);
-                }
-            }
-
-            if (inferenceIndices.isEmpty()) {
-                return rewritten;
-            } else if (nonInferenceIndices.isEmpty() == false) {
-                BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-                for (String inferenceIndexName : inferenceIndices) {
-                    // Add a separate clause for each semantic query, because they may be using different inference endpoints
-                    // TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
-                    boolQueryBuilder.should(
-                        createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
-                    );
-                }
-                boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
-                rewritten = boolQueryBuilder;
-            } else {
-                rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
-            }
-        }
-
-        return rewritten;
+        return matchQueryBuilder.fieldName();
+    }
 
+    @Override
+    protected String getQuery(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof MatchQueryBuilder);
+        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
+        return (String) matchQueryBuilder.value();
     }
 
     @Override
-    public String getQueryName() {
-        return MatchQueryBuilder.NAME;
+    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
+        return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
     }
 
-    private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
+    @Override
+    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation
+    ) {
+        assert (queryBuilder instanceof MatchQueryBuilder);
+        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-        boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
-        boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
+        boolQueryBuilder.should(
+            createSemanticSubQuery(
+                indexInformation.getInferenceIndices(),
+                matchQueryBuilder.fieldName(),
+                (String) matchQueryBuilder.value()
+            )
+        );
+        boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
         return boolQueryBuilder;
     }
 
-    private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
-        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-        boolQueryBuilder.must(matchQueryBuilder);
-        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
-        return boolQueryBuilder;
+    @Override
+    public String getQueryName() {
+        return MatchQueryBuilder.NAME;
     }
 }

+ 8 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

@@ -144,6 +144,14 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
         return NAME;
     }
 
+    public String getFieldName() {
+        return fieldName;
+    }
+
+    public String getQuery() {
+        return query;
+    }
+
     @Override
     public TransportVersion getMinimalSupportedVersion() {
         return TransportVersions.V_8_15_0;

+ 148 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java

@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.elasticsearch.action.ResolvedIndices;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.index.mapper.IndexFieldMapper;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.TermsQueryBuilder;
+import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field.
+ */
+public abstract class SemanticQueryRewriteInterceptor implements QueryRewriteInterceptor {
+
+    public SemanticQueryRewriteInterceptor() {}
+
+    @Override
+    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
+        String fieldName = getFieldName(queryBuilder);
+        ResolvedIndices resolvedIndices = context.getResolvedIndices();
+
+        if (resolvedIndices == null) {
+            // No resolved indices, so return the original query.
+            return queryBuilder;
+        }
+
+        InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
+        if (indexInformation.getInferenceIndices().isEmpty()) {
+            // No inference fields were identified, so return the original query.
+            return queryBuilder;
+        } else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
+            // Combined case where the field name requested by this query maps to a
+            // semantic_text field in some indices and a non-inference field in others,
+            // so we have to combine queries per index for each field type.
+            return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
+        } else {
+            // The only fields we've identified are inference fields (e.g. semantic_text),
+            // so rewrite the entire query to work on a semantic_text field.
+            return buildInferenceQuery(queryBuilder, indexInformation);
+        }
+    }
+
+    /**
+     * @param queryBuilder {@link QueryBuilder}
+     * @return The singular field name requested by the provided query builder.
+     */
+    protected abstract String getFieldName(QueryBuilder queryBuilder);
+
+    /**
+     * @param queryBuilder {@link QueryBuilder}
+     * @return The text/query string requested by the provided query builder.
+     */
+    protected abstract String getQuery(QueryBuilder queryBuilder);
+
+    /**
+     * Builds the query used when every resolved index maps the field as an inference field (e.g. semantic_text).
+     *
+     * @param queryBuilder {@link QueryBuilder}
+     * @param indexInformation {@link InferenceIndexInformationForField}
+     * @return {@link QueryBuilder}
+     */
+    protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation);
+
+    /**
+     * Builds a combined inference and non-inference query,
+     * which separates the different queries into appropriate indices based on field type.
+     * @param queryBuilder {@link QueryBuilder}
+     * @param indexInformation {@link InferenceIndexInformationForField}
+     * @return {@link QueryBuilder}
+     */
+    protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation
+    );
+
+    private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
+        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
+        Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
+        List<String> nonInferenceIndices = new ArrayList<>();
+        for (IndexMetadata indexMetadata : indexMetadataCollection) {
+            String indexName = indexMetadata.getIndex().getName();
+            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
+            if (inferenceFieldMetadata != null) {
+                inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata);
+            } else {
+                nonInferenceIndices.add(indexName);
+            }
+        }
+
+        return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
+    }
+
+    protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.must(queryBuilder);
+        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
+        return boolQueryBuilder;
+    }
+
+    protected QueryBuilder createSemanticSubQuery(Collection<String> indices, String fieldName, String value) {
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
+        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
+        return boolQueryBuilder;
+    }
+
+    /**
+     * Represents the indices and associated inference information for a field.
+     */
+    public record InferenceIndexInformationForField(
+        String fieldName,
+        Map<String, InferenceFieldMetadata> inferenceIndicesMetadata,
+        List<String> nonInferenceIndices
+    ) {
+
+        public Collection<String> getInferenceIndices() {
+            return inferenceIndicesMetadata.keySet();
+        }
+
+        public Map<String, List<String>> getInferenceIdsIndices() {
+            return inferenceIndicesMetadata.entrySet()
+                .stream()
+                .collect(
+                    Collectors.groupingBy(
+                        entry -> entry.getValue().getSearchInferenceId(),
+                        Collectors.mapping(Map.Entry::getKey, Collectors.toList())
+                    )
+                );
+        }
+    }
+}

+ 124 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java

@@ -0,0 +1,124 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+
+import java.util.List;
+import java.util.Map;
+
+public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
+
+    public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
+        "search.semantic_sparse_vector_query_rewrite_interception_supported"
+    );
+
+    public SemanticSparseVectorQueryRewriteInterceptor() {}
+
+    @Override
+    protected String getFieldName(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof SparseVectorQueryBuilder);
+        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
+        return sparseVectorQueryBuilder.getFieldName();
+    }
+
+    @Override
+    protected String getQuery(QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof SparseVectorQueryBuilder);
+        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
+        return sparseVectorQueryBuilder.getQuery();
+    }
+
+    @Override
+    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
+        Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
+        if (inferenceIdsIndices.size() == 1) {
+            // Simple case, everything uses the same inference ID
+            String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
+            return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
+        } else {
+            // Multiple inference IDs, construct a boolean query
+            return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
+        }
+    }
+
+    private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
+        QueryBuilder queryBuilder,
+        Map<String, List<String>> inferenceIdsIndices
+    ) {
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        for (String inferenceId : inferenceIdsIndices.keySet()) {
+            boolQueryBuilder.should(
+                createSubQueryForIndices(
+                    inferenceIdsIndices.get(inferenceId),
+                    buildNestedQueryFromSparseVectorQuery(queryBuilder, inferenceId)
+                )
+            );
+        }
+        return boolQueryBuilder;
+    }
+
+    @Override
+    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
+        QueryBuilder queryBuilder,
+        InferenceIndexInformationForField indexInformation
+    ) {
+        assert (queryBuilder instanceof SparseVectorQueryBuilder);
+        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
+        Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
+
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.should(
+            createSubQueryForIndices(
+                indexInformation.nonInferenceIndices(),
+                createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder)
+            )
+        );
+        // We always perform nested subqueries on semantic_text fields, to support
+        // sparse_vector queries using query vectors.
+        for (String inferenceId : inferenceIdsIndices.keySet()) {
+            boolQueryBuilder.should(
+                createSubQueryForIndices(
+                    inferenceIdsIndices.get(inferenceId),
+                    buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, inferenceId)
+                )
+            );
+        }
+        return boolQueryBuilder;
+    }
+
+    private QueryBuilder buildNestedQueryFromSparseVectorQuery(QueryBuilder queryBuilder, String searchInferenceId) {
+        assert (queryBuilder instanceof SparseVectorQueryBuilder);
+        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
+        return QueryBuilders.nestedQuery(
+            SemanticTextField.getChunksFieldName(sparseVectorQueryBuilder.getFieldName()),
+            new SparseVectorQueryBuilder(
+                SemanticTextField.getEmbeddingsFieldName(sparseVectorQueryBuilder.getFieldName()),
+                sparseVectorQueryBuilder.getQueryVectors(),
+                (sparseVectorQueryBuilder.getInferenceId() == null && sparseVectorQueryBuilder.getQuery() != null)
+                    ? searchInferenceId
+                    : sparseVectorQueryBuilder.getInferenceId(),
+                sparseVectorQueryBuilder.getQuery(),
+                sparseVectorQueryBuilder.shouldPruneTokens(),
+                sparseVectorQueryBuilder.getTokenPruningConfig()
+            ),
+            ScoreMode.Max
+        );
+    }
+
+    @Override
+    public String getQueryName() {
+        return SparseVectorQueryBuilder.NAME;
+    }
+}

+ 110 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java

@@ -0,0 +1,110 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.index.query;
+
+import org.elasticsearch.action.MockResolvedIndices;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.ResolvedIndices;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
+import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {
+
+    private TestThreadPool threadPool;
+    private NoOpClient client;
+    private Index index;
+
+    private static final String FIELD_NAME = "fieldName";
+    private static final String VALUE = "value";
+
+    @Before
+    public void setup() {
+        threadPool = createThreadPool();
+        client = new NoOpClient(threadPool);
+        index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10));
+    }
+
+    @After
+    public void cleanup() {
+        threadPool.close();
+    }
+
+    public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException {
+        Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
+            FIELD_NAME,
+            new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
+        );
+        QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
+        QueryBuilder original = createTestQueryBuilder();
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof InterceptedQueryBuilderWrapper
+        );
+        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder);
+        SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder;
+        assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName());
+        assertEquals(VALUE, semanticQueryBuilder.getQuery());
+    }
+
+    public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOException {
+        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
+        QueryBuilder original = createTestQueryBuilder();
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to remain match but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof MatchQueryBuilder
+        );
+        assertEquals(original, rewritten);
+    }
+
+    private MatchQueryBuilder createTestQueryBuilder() {
+        return new MatchQueryBuilder(FIELD_NAME, VALUE);
+    }
+
+    private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
+        IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
+            .settings(
+                Settings.builder()
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
+                    .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
+            )
+            .numberOfShards(1)
+            .numberOfReplicas(0)
+            .putInferenceFields(inferenceFields)
+            .build();
+
+        ResolvedIndices resolvedIndices = new MockResolvedIndices(
+            Map.of(),
+            new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT),
+            Map.of(index, indexMetadata)
+        );
+
+        return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor());
+    }
+
+    private QueryRewriteInterceptor createRewriteInterceptor() {
+        return new SemanticMatchQueryRewriteInterceptor();
+    }
+}

+ 137 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.index.query;
+
+import org.elasticsearch.action.MockResolvedIndices;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.ResolvedIndices;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Tests that {@link SemanticSparseVectorQueryRewriteInterceptor} rewrites sparse_vector queries on inference fields into
+ * nested queries over the field's chunk embeddings, and leaves queries on other fields unchanged.
+ */
+public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase {
+
+    private TestThreadPool threadPool;
+    private NoOpClient client;
+    private Index index;
+
+    private static final String FIELD_NAME = "fieldName";
+    private static final String INFERENCE_ID = "inferenceId";
+    private static final String QUERY = "query";
+
+    @Before
+    public void setup() {
+        threadPool = createThreadPool();
+        client = new NoOpClient(threadPool);
+        index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10));
+    }
+
+    @After
+    public void cleanup() {
+        threadPool.close();
+    }
+
+    /** An explicit inference ID on the query is kept when the query is rewritten against an inference field. */
+    public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException {
+        QueryRewriteContext context = createQueryRewriteContext(inferenceFieldsForTestField());
+        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
+        assertRewrittenToNestedSparseVectorQuery(original.rewrite(context));
+    }
+
+    /** With no inference ID on the query, the rewritten query is expected to carry the ID held by the field metadata. */
+    public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
+        QueryRewriteContext context = createQueryRewriteContext(inferenceFieldsForTestField());
+        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
+        assertRewrittenToNestedSparseVectorQuery(original.rewrite(context));
+    }
+
+    /** Without inference metadata for the field, the query must come back from rewriting unchanged. */
+    public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
+        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
+        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof SparseVectorQueryBuilder
+        );
+        assertEquals(original, rewritten);
+    }
+
+    /** Inference field metadata keyed by {@link #FIELD_NAME} and bound to {@link #INFERENCE_ID}. */
+    private Map<String, InferenceFieldMetadata> inferenceFieldsForTestField() {
+        // Reuse INFERENCE_ID rather than a duplicate literal so the metadata and the assertions cannot drift apart.
+        return Map.of(FIELD_NAME, new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }));
+    }
+
+    /**
+     * Asserts that {@code rewritten} is an intercepted nested query on the chunks path of {@link #FIELD_NAME} whose inner
+     * sparse_vector query targets the embeddings field and carries {@link #INFERENCE_ID} and {@link #QUERY}.
+     */
+    private static void assertRewrittenToNestedSparseVectorQuery(QueryBuilder rewritten) {
+        assertTrue(
+            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof InterceptedQueryBuilderWrapper
+        );
+        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
+        NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
+        assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
+        QueryBuilder innerQuery = nestedQueryBuilder.query();
+        assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
+        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
+        assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
+        assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
+        assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
+    }
+
+    /** Builds a rewrite context whose resolved indices hold only the test index, carrying the given inference fields. */
+    private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
+        IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
+            .settings(
+                Settings.builder()
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
+                    .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
+            )
+            .numberOfShards(1)
+            .numberOfReplicas(0)
+            .putInferenceFields(inferenceFields)
+            .build();
+
+        ResolvedIndices resolvedIndices = new MockResolvedIndices(
+            Map.of(),
+            new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT),
+            Map.of(index, indexMetadata)
+        );
+
+        return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor());
+    }
+
+    /** Supplies the interceptor under test to the rewrite context. */
+    private QueryRewriteInterceptor createRewriteInterceptor() {
+        return new SemanticSparseVectorQueryRewriteInterceptor();
+    }
+}

+ 249 - 0
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml

@@ -0,0 +1,249 @@
+setup:  # Shared fixtures: two inference endpoints, three indices, one document per index.
+  - requires:
+      cluster_features: "search.semantic_sparse_vector_query_rewrite_interception_supported"
+      reason: semantic_text sparse_vector support introduced in 8.18.0
+
+  - do:  # Register the first sparse_embedding endpoint on test_service.
+      inference.put:
+        task_type: sparse_embedding
+        inference_id: sparse-inference-id
+        body: >
+          {
+            "service": "test_service",
+            "service_settings": {
+              "model": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
+  - do:  # Register a second endpoint with the same settings under a different inference ID.
+      inference.put:
+        task_type: sparse_embedding
+        inference_id: sparse-inference-id-2
+        body: >
+          {
+            "service": "test_service",
+            "service_settings": {
+              "model": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
+  - do:  # semantic_text index bound to sparse-inference-id.
+      indices.create:
+        index: test-semantic-text-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: sparse-inference-id
+
+  - do:  # semantic_text index bound to sparse-inference-id-2.
+      indices.create:
+        index: test-semantic-text-index-2
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: sparse-inference-id-2
+
+  - do:  # Plain sparse_vector index that reuses the field name inference_field.
+      indices.create:
+        index: test-sparse-vector-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: sparse_vector
+
+  - do:  # doc_1: two text values in the first semantic_text index.
+      index:
+        index: test-semantic-text-index
+        id: doc_1
+        body:
+          inference_field: [ "inference test", "another inference test" ]
+        refresh: true
+
+  - do:  # doc_3: the same two text values in the second semantic_text index.
+      index:
+        index: test-semantic-text-index-2
+        id: doc_3
+        body:
+          inference_field: [ "inference test", "another inference test" ]
+        refresh: true
+
+  - do:  # doc_2: explicit feature weights in the sparse_vector index.
+      index:
+        index: test-sparse-vector-index
+        id: doc_2
+        body:
+          inference_field: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
+        refresh: true
+
+---
+"Nested sparse_vector queries using the old format on semantic_text embeddings and inference still work":
+  - skip:
+      features: [ "headers" ]
+
+  - do:  # Explicit nested query that addresses the chunk embeddings sub-field directly, using an inference ID.
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            nested:
+              path: inference_field.inference.chunks
+              query:
+                sparse_vector:
+                  field: inference_field.inference.chunks.embeddings
+                  inference_id: sparse-inference-id
+                  query: test
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+
+---
+"Nested sparse_vector queries using the old format on semantic_text embeddings and query vectors still work":
+  - skip:
+      features: [ "headers" ]
+
+  - do:  # Same explicit nested form, but with a precomputed query_vector instead of an inference ID.
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            nested:
+              path: inference_field.inference.chunks
+              query:
+                sparse_vector:
+                  field: inference_field.inference.chunks.embeddings
+                  query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+
+---
+"sparse_vector query against semantic_text field using a specified inference ID":
+
+  - do:  # Top-level sparse_vector query naming the semantic_text field itself, with an explicit inference ID.
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              inference_id: sparse-inference-id
+              query: "inference test"
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+
+---
+"sparse_vector query against semantic_text field using inference ID configured in semantic_text field":
+
+  - do:  # No inference_id in the query; the field mapping in setup declares sparse-inference-id.
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+
+---
+"sparse_vector query against semantic_text field using query vectors":
+
+  - do:  # Precomputed query_vector against the semantic_text field; no inference ID or query text.
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+
+---
+"sparse_vector query against combined sparse_vector and semantic_text fields using inference":
+
+  - do:  # One search over a semantic_text index and a sparse_vector index that share the field name.
+      search:
+        index:
+          - test-semantic-text-index
+          - test-sparse-vector-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              inference_id: sparse-inference-id
+              query: "inference test"
+
+  - match: { hits.total.value: 2 }  # each of the two indices holds exactly one document
+
+---
+"sparse_vector query against combined sparse_vector and semantic_text fields still requires inference ID":
+
+  - do:  # Same two-index search with query text but no inference_id; the request is expected to be rejected.
+      catch: bad_request
+      search:
+        index:
+          - test-semantic-text-index
+          - test-sparse-vector-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+
+  - match: { error.type: "illegal_argument_exception" }
+  - match: { error.reason: "inference_id required to perform vector search on query string" }
+
+---
+"sparse_vector query against combined sparse_vector and semantic_text fields using query vectors":
+
+  - do:  # Two-index search with a precomputed query_vector, so no inference ID is supplied.
+      search:
+        index:
+          - test-semantic-text-index
+          - test-sparse-vector-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
+
+  - match: { hits.total.value: 2 }  # each of the two indices holds exactly one document
+
+
+---
+"sparse_vector query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields":
+
+  - do:  # Two semantic_text indices mapped to different inference IDs; the query supplies none.
+      search:
+        index:
+          - test-semantic-text-index
+          - test-semantic-text-index-2
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+
+  - match: { hits.total.value: 2 }  # each of the two indices holds exactly one document
+

+ 37 - 3
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/sparse_vector_search.yml

@@ -268,7 +268,7 @@ setup:
   - match: { hits.hits.0._score: 0.25 }
 
 ---
-"Test sparse_vector requires one of inference_id or query_vector":
+"Test sparse_vector requires one of query or query_vector":
   - do:
       catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/
       search:
@@ -281,7 +281,41 @@ setup:
   - match: { status: 400 }
 
 ---
-"Test sparse_vector only allows one of inference_id or query_vector":
+"Test sparse_vector returns an error if inference ID not specified with query":
+  - do:  # No requires clause here: only the bad_request / 400 outcome is asserted, not the message text.
+      catch: bad_request # This is for BWC, the actual error message is tested in a subsequent test
+      search:
+        index: index-with-sparse-vector
+        body:
+          query:
+            sparse_vector:
+              field: text
+              query: "octopus comforter smells"
+
+  - match: { status: 400 }
+
+---
+"Test sparse_vector requires an inference ID to be specified on sparse_vector fields":
+  - requires:
+      cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
+      reason: "Error message changed in 8.18"
+  - do:  # Same request as the preceding test, here matching the error message text behind the feature gate.
+      catch: /inference_id required to perform vector search on query string/
+      search:
+        index: index-with-sparse-vector
+        body:
+          query:
+            sparse_vector:
+              field: text
+              query: "octopus comforter smells"
+
+  - match: { status: 400 }
+
+---
+"Test sparse_vector only allows one of query or query_vector (note the error message is misleading)":
+  - requires:
+      cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
+      reason: "sparse vector inference checks updated in 8.18 to support sparse_vector on semantic_text fields"
   - do:
       catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/
       search:
@@ -290,7 +324,7 @@ setup:
           query:
             sparse_vector:
               field: text
-              inference_id: text_expansion_model
+              query: "octopus comforter smells"
               query_vector:
                 the: 1.0
                 comforter: 1.0