Browse Source

Fix Semantic Query Rewrite Interception Drops Boosts (#129282) (#131469)

* fix boosting for knn

* Fixing for match query

* fixing for match subquery

* fix for sparse vector query boost

* fix linting issues

* Update docs/changelog/129282.yaml

* update changelog

* Copy constructor with match query

* util function to create sparseVectorBuilder for sparse query

* util function for knn query to support boost

* adding unit tests for all intercepted query terms

* Adding yaml test for match, sparse, and knn

* Adding queryname support for nested query

* fix code styles

* Fix failed yaml tests

* Update docs/changelog/129282.yaml

* update yaml tests to expand test scenarios

* Updating knn to copy constructor

* adding yaml tests for multiple indices

* refactoring match query to adjust boost and queryname and move to copy constructor

* refactoring sparse query to adjust boost and queryname and move to copy constructor

* [CI] Auto commit changes from spotless

* Refactor sparse vector to adjust boost and queryname in the top level

* Refactor knn vector to adjust boost and queryname in the top level

* fix knn combined query

* fix unit tests

* fix lint issues

* remove unused code

* Update inference feature name

* Remove double boosting issue from match

* Fix double boosting in match test yaml file

* move to bool level for match semantic boost

* fix double boosting for sparse vector

* fix double boosting for sparse vector in yaml test

* fix knn combined query

* fix knn combined query

* fix sparse combined query

* fix knn yaml test for combined query

* refactoring unit tests

* linting

* fix match query unit test

* adding copy constructor for match query

* refactor copy match builder to interceptor

* [CI] Auto commit changes from spotless

* fix unit tests

* update yaml tests

* fix match yaml test

* fix yaml tests with 4 digits error margin

* unit tests are now more randomized

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
(cherry picked from commit e2bb47c3bb5be3ed77ac5026db40d9297b458f36)

# Conflicts:
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
Samiul Monir 3 months ago
parent
commit
ce38cda03b

+ 6 - 0
docs/changelog/129282.yaml

@@ -0,0 +1,6 @@
+pr: 129282
+summary: "Fix query rewrite logic to preserve `boosts` and `queryName` for `match`,\
+  \ `knn`, and `sparse_vector` queries on semantic_text fields"
+area: Search
+type: bug
+issues: []

+ 5 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

@@ -36,6 +36,9 @@ public class InferenceFeatures implements FeatureSpecification {
     private static final NodeFeature TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS = new NodeFeature(
         "test_rule_retriever.with_indices_that_dont_return_rank_docs"
     );
+    private static final NodeFeature SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX = new NodeFeature(
+        "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+    );
     private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
     private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");
 
@@ -66,7 +69,8 @@ public class InferenceFeatures implements FeatureSpecification {
             SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
             SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
             SEMANTIC_TEXT_INDEX_OPTIONS,
-            COHERE_V2_API
+            COHERE_V2_API,
+            SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX
         );
     }
 }

+ 8 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java

@@ -52,16 +52,20 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
         assert (queryBuilder instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
+        QueryBuilder finalQueryBuilder;
         if (inferenceIdsIndices.size() == 1) {
             // Simple case, everything uses the same inference ID
             Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
             String searchInferenceId = inferenceIdIndex.getKey();
             List<String> indices = inferenceIdIndex.getValue();
-            return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
+            finalQueryBuilder = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
         } else {
             // Multiple inference IDs, construct a boolean query
-            return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
+            finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
         }
+        finalQueryBuilder.boost(queryBuilder.boost());
+        finalQueryBuilder.queryName(queryBuilder.queryName());
+        return finalQueryBuilder;
     }
 
     private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@@ -102,6 +106,8 @@ public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewri
                 )
             );
         }
+        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }
 

+ 30 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java

@@ -36,7 +36,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
 
     @Override
     protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
-        return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
+        SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
+        semanticQueryBuilder.boost(queryBuilder.boost());
+        semanticQueryBuilder.queryName(queryBuilder.queryName());
+        return semanticQueryBuilder;
     }
 
     @Override
@@ -45,7 +48,10 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
         InferenceIndexInformationForField indexInformation
     ) {
         assert (queryBuilder instanceof MatchQueryBuilder);
-        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
+        MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
+        // Create a copy for non-inference fields without boost and _name
+        MatchQueryBuilder matchQueryBuilder = copyMatchQueryBuilder(originalMatchQueryBuilder);
+
         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
         boolQueryBuilder.should(
             createSemanticSubQuery(
@@ -55,6 +61,8 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
             )
         );
         boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
+        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }
 
@@ -62,4 +70,24 @@ public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteIn
     public String getQueryName() {
         return MatchQueryBuilder.NAME;
     }
+
+    private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) {
+        MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value());
+        matchQueryBuilder.operator(queryBuilder.operator());
+        matchQueryBuilder.prefixLength(queryBuilder.prefixLength());
+        matchQueryBuilder.maxExpansions(queryBuilder.maxExpansions());
+        matchQueryBuilder.fuzzyTranspositions(queryBuilder.fuzzyTranspositions());
+        matchQueryBuilder.lenient(queryBuilder.lenient());
+        matchQueryBuilder.zeroTermsQuery(queryBuilder.zeroTermsQuery());
+        matchQueryBuilder.analyzer(queryBuilder.analyzer());
+        matchQueryBuilder.minimumShouldMatch(queryBuilder.minimumShouldMatch());
+        matchQueryBuilder.fuzzyRewrite(queryBuilder.fuzzyRewrite());
+
+        if (queryBuilder.fuzziness() != null) {
+            matchQueryBuilder.fuzziness(queryBuilder.fuzziness());
+        }
+
+        matchQueryBuilder.autoGenerateSynonymsPhraseQuery(queryBuilder.autoGenerateSynonymsPhraseQuery());
+        return matchQueryBuilder;
+    }
 }

+ 21 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java

@@ -43,14 +43,18 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe
     @Override
     protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
+        QueryBuilder finalQueryBuilder;
         if (inferenceIdsIndices.size() == 1) {
             // Simple case, everything uses the same inference ID
             String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
-            return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
+            finalQueryBuilder = buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
         } else {
             // Multiple inference IDs, construct a boolean query
-            return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
+            finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
         }
+        finalQueryBuilder.queryName(queryBuilder.queryName());
+        finalQueryBuilder.boost(queryBuilder.boost());
+        return finalQueryBuilder;
     }
 
     private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@@ -79,7 +83,19 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
 
         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-        boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder));
+        boolQueryBuilder.should(
+            createSubQueryForIndices(
+                indexInformation.nonInferenceIndices(),
+                new SparseVectorQueryBuilder(
+                    sparseVectorQueryBuilder.getFieldName(),
+                    sparseVectorQueryBuilder.getQueryVectors(),
+                    sparseVectorQueryBuilder.getInferenceId(),
+                    sparseVectorQueryBuilder.getQuery(),
+                    sparseVectorQueryBuilder.shouldPruneTokens(),
+                    sparseVectorQueryBuilder.getTokenPruningConfig()
+                )
+            )
+        );
         // We always perform nested subqueries on semantic_text fields, to support
         // sparse_vector queries using query vectors.
         for (String inferenceId : inferenceIdsIndices.keySet()) {
@@ -90,6 +106,8 @@ public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRe
                 )
             );
         }
+        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }
 

+ 25 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java

@@ -61,6 +61,14 @@ public class SemanticKnnVectorQueryRewriteInterceptorTests extends ESTestCase {
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
         KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
+        if (randomBoolean()) {
+            float boost = randomFloatBetween(1, 10, randomBoolean());
+            original.boost(boost);
+        }
+        if (randomBoolean()) {
+            String queryName = randomAlphaOfLength(5);
+            original.queryName(queryName);
+        }
         testRewrittenInferenceQuery(context, original);
     }
 
@@ -72,6 +80,14 @@ public class SemanticKnnVectorQueryRewriteInterceptorTests extends ESTestCase {
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY);
         KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
+        if (randomBoolean()) {
+            float boost = randomFloatBetween(1, 10, randomBoolean());
+            original.boost(boost);
+        }
+        if (randomBoolean()) {
+            String queryName = randomAlphaOfLength(5);
+            original.queryName(queryName);
+        }
         testRewrittenInferenceQuery(context, original);
     }
 
@@ -82,14 +98,23 @@ public class SemanticKnnVectorQueryRewriteInterceptorTests extends ESTestCase {
             rewritten instanceof InterceptedQueryBuilderWrapper
         );
         InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertEquals(original.boost(), intercepted.boost(), 0.0f);
+        assertEquals(original.queryName(), intercepted.queryName());
         assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
+
         NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
+        assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
+        assertEquals(original.queryName(), nestedQueryBuilder.queryName());
         assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
+
         QueryBuilder innerQuery = nestedQueryBuilder.query();
         assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
+        assertEquals(1.0f, knnVectorQueryBuilder.boost(), 0.0f);
+        assertNull(knnVectorQueryBuilder.queryName());
         assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
         assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder);
+
         TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder
             .queryVectorBuilder();
         assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText());

+ 25 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java

@@ -36,6 +36,8 @@ public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {
 
     private static final String FIELD_NAME = "fieldName";
     private static final String VALUE = "value";
+    private static final String QUERY_NAME = "match_query";
+    private static final float BOOST = 5.0f;
 
     @Before
     public void setup() {
@@ -79,6 +81,29 @@ public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {
         assertEquals(original, rewritten);
     }
 
+    public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException {
+        Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
+            FIELD_NAME,
+            new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
+        );
+        QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
+        QueryBuilder original = createTestQueryBuilder();
+        original.boost(BOOST);
+        original.queryName(QUERY_NAME);
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof InterceptedQueryBuilderWrapper
+        );
+        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertEquals(BOOST, intercepted.boost(), 0.0f);
+        assertEquals(QUERY_NAME, intercepted.queryName());
+        assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder);
+        SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder;
+        assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName());
+        assertEquals(VALUE, semanticQueryBuilder.getQuery());
+    }
+
     private MatchQueryBuilder createTestQueryBuilder() {
         return new MatchQueryBuilder(FIELD_NAME, VALUE);
     }

+ 40 - 26
x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java

@@ -58,21 +58,15 @@ public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase
         );
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
-        QueryBuilder rewritten = original.rewrite(context);
-        assertTrue(
-            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
-            rewritten instanceof InterceptedQueryBuilderWrapper
-        );
-        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
-        assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
-        NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
-        assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
-        QueryBuilder innerQuery = nestedQueryBuilder.query();
-        assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
-        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
-        assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
-        assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
-        assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
+        if (randomBoolean()) {
+            float boost = randomFloatBetween(1, 10, randomBoolean());
+            original.boost(boost);
+        }
+        if (randomBoolean()) {
+            String queryName = randomAlphaOfLength(5);
+            original.queryName(queryName);
+        }
+        testRewrittenInferenceQuery(context, original);
     }
 
     public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
@@ -82,32 +76,52 @@ public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase
         );
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
+        if (randomBoolean()) {
+            float boost = randomFloatBetween(1, 10, randomBoolean());
+            original.boost(boost);
+        }
+        if (randomBoolean()) {
+            String queryName = randomAlphaOfLength(5);
+            original.queryName(queryName);
+        }
+        testRewrittenInferenceQuery(context, original);
+    }
+
+    public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
+        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
+        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof SparseVectorQueryBuilder
+        );
+        assertEquals(original, rewritten);
+    }
+
+    private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException {
         QueryBuilder rewritten = original.rewrite(context);
         assertTrue(
             "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
             rewritten instanceof InterceptedQueryBuilderWrapper
         );
         InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertEquals(original.boost(), intercepted.boost(), 0.0f);
+        assertEquals(original.queryName(), intercepted.queryName());
+
         assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
         NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
         assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
+        assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
+        assertEquals(original.queryName(), nestedQueryBuilder.queryName());
+
         QueryBuilder innerQuery = nestedQueryBuilder.query();
         assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
         SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
         assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
         assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
         assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
-    }
-
-    public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
-        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
-        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
-        QueryBuilder rewritten = original.rewrite(context);
-        assertTrue(
-            "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
-            rewritten instanceof SparseVectorQueryBuilder
-        );
-        assertEquals(original, rewritten);
+        assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f);
+        assertNull(sparseVectorQueryBuilder.queryName());
     }
 
     private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {

+ 123 - 0
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/45_semantic_text_match.yml

@@ -277,3 +277,126 @@ setup:
                 query: "inference test"
 
   - match: { hits.total.value: 0 }
+
+---
+"Apply boost and query name on single index":
+  - requires:
+      cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+      reason: fix boosting and query name for semantic text match queries.
+
+  - skip:
+      features: [ "headers", "close_to" ]
+
+  - do:
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          inference_field: [ "It was a beautiful game", "Very competitive" ]
+          non_inference_field: "non inference test"
+        refresh: true
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-sparse-index
+        body:
+          query:
+            match:
+              inference_field:
+                query: "soccer"
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - close_to: { hits.hits.0._score: { value: 5.700229E18, error: 1e15 } }
+  - not_exists: hits.hits.0.matched_queries
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-sparse-index
+        body:
+          query:
+            match:
+              inference_field:
+                query: "soccer"
+                boost: 5.0
+                _name: i-like-naming-my-queries
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - close_to: { hits.hits.0._score: { value: 2.8501142E19, error: 1e16 } }
+  - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
+
+---
+"Apply boost and query name on multiple indices":
+  - requires:
+      cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+      reason: fix boosting and query name for semantic text match queries.
+
+  - skip:
+      features: [ "headers", "close_to" ]
+
+  - do:
+      index:
+        index: test-sparse-index
+        id: doc_1
+        body:
+          inference_field: [ "It was a beautiful game", "Very competitive" ]
+          non_inference_field: "non inference test"
+        refresh: true
+
+  - do:
+      index:
+        index: test-text-only-index
+        id: doc_2
+        body:
+          inference_field: [ "It was a beautiful game", "Very competitive" ]
+          non_inference_field: "non inference test"
+        refresh: true
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-sparse-index,test-text-only-index
+        body:
+          query:
+            match:
+              inference_field:
+                query: "beautiful"
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+  - close_to: { hits.hits.0._score: { value: 1.1140361E19, error: 1e16 } }
+  - not_exists: hits.hits.0.matched_queries
+  - close_to: { hits.hits.1._score: { value: 0.2876821, error: 1e-4 } }
+  - not_exists: hits.hits.1.matched_queries
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-sparse-index,test-text-only-index
+        body:
+          query:
+            match:
+              inference_field:
+                query: "beautiful"
+                boost: 5.0
+                _name: i-like-naming-my-queries
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+  - close_to: { hits.hits.0._score: { value: 5.5701804E19, error: 1e16 } }
+  - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
+  - close_to: { hits.hits.1._score: { value: 1.4384103, error: 1e-4 } }
+  - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] }

+ 97 - 0
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml

@@ -247,3 +247,100 @@ setup:
 
   - match: { hits.total.value: 2 }
 
+---
+"Apply boost and query name on single index":
+  - requires:
+      cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+      reason: fix boosting and query name for semantic text sparse vector queries.
+
+  - skip:
+      features: [ "headers", "close_to" ]
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - close_to: { hits.hits.0._score: { value: 3.7837332E17, error: 1e14 } }
+  - not_exists: hits.hits.0.matched_queries
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+              boost: 5.0
+              _name: i-like-naming-my-queries
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - close_to: { hits.hits.0._score: { value: 1.8918664E18, error: 1e15 } }
+  - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
+
+---
+"Apply boost and query name on multiple indices":
+  - requires:
+      cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+      reason: fix boosting and query name for semantic text sparse vector queries.
+
+  - skip:
+      features: [ "headers", "close_to" ]
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index,test-sparse-vector-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+              inference_id: sparse-inference-id
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+  - close_to: { hits.hits.0._score: { value: 3.7837332E17, error: 1e14 } }
+  - not_exists: hits.hits.0.matched_queries
+  - close_to: { hits.hits.1._score: { value: 7.314424E8, error: 1e5 } }
+  - not_exists: hits.hits.1.matched_queries
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index,test-sparse-vector-index
+        body:
+          query:
+            sparse_vector:
+              field: inference_field
+              query: "inference test"
+              inference_id: sparse-inference-id
+              boost: 5.0
+              _name: i-like-naming-my-queries
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+  - close_to: { hits.hits.0._score: { value: 1.8918664E18, error: 1e15 } }
+  - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
+  - close_to: { hits.hits.1._score: { value: 3.657212E9, error: 1e6 } }
+  - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] }

+ 112 - 0
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml

@@ -404,4 +404,116 @@ setup:
 
   - match: { hits.total.value: 4 }
 
+---
+"Apply boost and query name on single index":
+  - requires:
+      cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+      reason: fix boosting and query name for semantic text knn queries.
+
+  - skip:
+      features: [ "headers", "close_to" ]
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            knn:
+              field: inference_field
+              k: 2
+              num_candidates: 100
+              query_vector_builder:
+                text_embedding:
+                  model_text: test
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - close_to: { hits.hits.0._score: { value: 0.9990483, error: 1e-4 } }
+  - not_exists: hits.hits.0.matched_queries
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index
+        body:
+          query:
+            knn:
+              field: inference_field
+              k: 2
+              num_candidates: 100
+              query_vector_builder:
+                text_embedding:
+                  model_text: test
+              boost: 5.0
+              _name: i-like-naming-my-queries
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - close_to: { hits.hits.0._score: { value: 4.9952416, error: 1e-3 } }
+  - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
 
+---
+"Apply boost and query name on multiple indices":
+  - requires:
+      cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+      reason: fix boosting and query name for semantic text knn queries.
+
+  - skip:
+      features: [ "headers", "close_to" ]
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index,test-dense-vector-index
+        body:
+          query:
+            knn:
+              field: inference_field
+              k: 2
+              num_candidates: 100
+              query_vector_builder:
+                text_embedding:
+                  model_text: test
+                  model_id: dense-inference-id
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_3" }
+  - close_to: { hits.hits.0._score: { value: 0.9990483, error: 1e-4 } }
+  - not_exists: hits.hits.0.matched_queries
+  - close_to: { hits.hits.1._score: { value: 0.9439374, error: 1e-4 } }
+  - not_exists: hits.hits.1.matched_queries
+
+  - do:
+      headers:
+        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
+        Content-Type: application/json
+      search:
+        index: test-semantic-text-index,test-dense-vector-index
+        body:
+          query:
+            knn:
+              field: inference_field
+              k: 2
+              num_candidates: 100
+              query_vector_builder:
+                text_embedding:
+                  model_text: test
+                  model_id: dense-inference-id
+              boost: 5.0
+              _name: i-like-naming-my-queries
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_3" }
+  - close_to: { hits.hits.0._score: { value: 4.9952416, error: 1e-3 } }
+  - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
+  - close_to: { hits.hits.1._score: { value: 4.719687, error: 1e-3 } }
+  - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] }