|
@@ -58,21 +58,15 @@ public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase
|
|
|
);
|
|
|
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
|
|
|
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
|
|
|
- QueryBuilder rewritten = original.rewrite(context);
|
|
|
- assertTrue(
|
|
|
- "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
|
|
|
- rewritten instanceof InterceptedQueryBuilderWrapper
|
|
|
- );
|
|
|
- InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
|
|
|
- assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
|
|
|
- NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
|
|
|
- assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
|
|
|
- QueryBuilder innerQuery = nestedQueryBuilder.query();
|
|
|
- assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
|
|
|
- SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
|
|
|
- assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
|
|
|
- assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
|
|
|
- assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
|
|
|
+ if (randomBoolean()) {
|
|
|
+ float boost = randomFloatBetween(1, 10, randomBoolean());
|
|
|
+ original.boost(boost);
|
|
|
+ }
|
|
|
+ if (randomBoolean()) {
|
|
|
+ String queryName = randomAlphaOfLength(5);
|
|
|
+ original.queryName(queryName);
|
|
|
+ }
|
|
|
+ testRewrittenInferenceQuery(context, original);
|
|
|
}
|
|
|
|
|
|
public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
|
|
@@ -82,32 +76,52 @@ public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase
|
|
|
);
|
|
|
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
|
|
|
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
|
|
|
+ if (randomBoolean()) {
|
|
|
+ float boost = randomFloatBetween(1, 10, randomBoolean());
|
|
|
+ original.boost(boost);
|
|
|
+ }
|
|
|
+ if (randomBoolean()) {
|
|
|
+ String queryName = randomAlphaOfLength(5);
|
|
|
+ original.queryName(queryName);
|
|
|
+ }
|
|
|
+ testRewrittenInferenceQuery(context, original);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
|
|
|
+ QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
|
|
|
+ QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
|
|
|
+ QueryBuilder rewritten = original.rewrite(context);
|
|
|
+ assertTrue(
|
|
|
+ "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
|
|
|
+ rewritten instanceof SparseVectorQueryBuilder
|
|
|
+ );
|
|
|
+ assertEquals(original, rewritten);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException {
|
|
|
QueryBuilder rewritten = original.rewrite(context);
|
|
|
assertTrue(
|
|
|
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
|
|
|
rewritten instanceof InterceptedQueryBuilderWrapper
|
|
|
);
|
|
|
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
|
|
|
+ assertEquals(original.boost(), intercepted.boost(), 0.0f);
|
|
|
+ assertEquals(original.queryName(), intercepted.queryName());
|
|
|
+
|
|
|
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
|
|
|
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
|
|
|
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
|
|
|
+ assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
|
|
|
+ assertEquals(original.queryName(), nestedQueryBuilder.queryName());
|
|
|
+
|
|
|
QueryBuilder innerQuery = nestedQueryBuilder.query();
|
|
|
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
|
|
|
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
|
|
|
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
|
|
|
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
|
|
|
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
|
|
|
- }
|
|
|
-
|
|
|
- public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
|
|
|
- QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
|
|
|
- QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
|
|
|
- QueryBuilder rewritten = original.rewrite(context);
|
|
|
- assertTrue(
|
|
|
- "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
|
|
|
- rewritten instanceof SparseVectorQueryBuilder
|
|
|
- );
|
|
|
- assertEquals(original, rewritten);
|
|
|
+ assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f);
|
|
|
+ assertNull(sparseVectorQueryBuilder.queryName());
|
|
|
}
|
|
|
|
|
|
private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
|