|  | @@ -13,6 +13,8 @@ import org.elasticsearch.common.settings.Settings;
 | 
	
		
			
				|  |  |  import org.elasticsearch.index.query.QueryBuilders;
 | 
	
		
			
				|  |  |  import org.elasticsearch.indices.TermsLookup;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.SearchHit;
 | 
	
		
			
				|  |  | +import org.elasticsearch.search.aggregations.AggregationBuilders;
 | 
	
		
			
				|  |  | +import org.elasticsearch.search.aggregations.metrics.InternalStats;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.vectors.KnnSearchBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ESSingleNodeTestCase;
 | 
	
	
		
			
				|  | @@ -20,8 +22,11 @@ import org.elasticsearch.xcontent.XContentBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xcontent.XContentFactory;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.io.IOException;
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.equalTo;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.greaterThan;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 | 
	
		
			
				|  |  |      private static final int VECTOR_DIMENSION = 10;
 | 
	
	
		
			
				|  | @@ -56,7 +61,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 | 
	
		
			
				|  |  |          float[] queryVector = randomVector();
 | 
	
		
			
				|  |  |          KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
 | 
	
		
			
				|  |  |          SearchResponse response = client().prepareSearch("index")
 | 
	
		
			
				|  |  | -            .setKnnSearch(knnSearch)
 | 
	
		
			
				|  |  | +            .setKnnSearch(List.of(knnSearch))
 | 
	
		
			
				|  |  |              .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
 | 
	
		
			
				|  |  |              .addFetchField("*")
 | 
	
		
			
				|  |  |              .setSize(10)
 | 
	
	
		
			
				|  | @@ -101,7 +106,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 | 
	
		
			
				|  |  |          KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
 | 
	
		
			
				|  |  |              QueryBuilders.termsQuery("field", "second")
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  | -        SearchResponse response = client().prepareSearch("index").setKnnSearch(knnSearch).addFetchField("*").setSize(10).get();
 | 
	
		
			
				|  |  | +        SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10).get();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertHitCount(response, 5);
 | 
	
		
			
				|  |  |          assertEquals(5, response.getHits().getHits().length);
 | 
	
	
		
			
				|  | @@ -144,12 +149,145 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 | 
	
		
			
				|  |  |          KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
 | 
	
		
			
				|  |  |              QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field"))
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  | -        SearchResponse response = client().prepareSearch("index").setKnnSearch(knnSearch).setSize(10).get();
 | 
	
		
			
				|  |  | +        SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10).get();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertHitCount(response, 5);
 | 
	
		
			
				|  |  |          assertEquals(5, response.getHits().getHits().length);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    public void testMultiKnnClauses() throws IOException {
 | 
	
		
			
				|  |  | +        // This tests the recall from vectors being searched in different docs
 | 
	
		
			
				|  |  | +        int numShards = 1 + randomInt(3);
 | 
	
		
			
				|  |  | +        Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        XContentBuilder builder = XContentFactory.jsonBuilder()
 | 
	
		
			
				|  |  | +            .startObject()
 | 
	
		
			
				|  |  | +            .startObject("properties")
 | 
	
		
			
				|  |  | +            .startObject("vector")
 | 
	
		
			
				|  |  | +            .field("type", "dense_vector")
 | 
	
		
			
				|  |  | +            .field("dims", VECTOR_DIMENSION)
 | 
	
		
			
				|  |  | +            .field("index", true)
 | 
	
		
			
				|  |  | +            .field("similarity", "l2_norm")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .startObject("vector_2")
 | 
	
		
			
				|  |  | +            .field("type", "dense_vector")
 | 
	
		
			
				|  |  | +            .field("dims", VECTOR_DIMENSION)
 | 
	
		
			
				|  |  | +            .field("index", true)
 | 
	
		
			
				|  |  | +            .field("similarity", "l2_norm")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .startObject("text")
 | 
	
		
			
				|  |  | +            .field("type", "text")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .startObject("number")
 | 
	
		
			
				|  |  | +            .field("type", "long")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .endObject();
 | 
	
		
			
				|  |  | +        createIndex("index", indexSettings, builder);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for (int doc = 0; doc < 10; doc++) {
 | 
	
		
			
				|  |  | +            client().prepareIndex("index").setSource("vector", randomVector(), "text", "hello world", "number", 1).get();
 | 
	
		
			
				|  |  | +            client().prepareIndex("index").setSource("vector_2", randomVector(), "text", "hello world", "number", 2).get();
 | 
	
		
			
				|  |  | +            client().prepareIndex("index").setSource("text", "goodnight world", "number", 3).get();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        client().admin().indices().prepareRefresh("index").get();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        float[] queryVector = randomVector();
 | 
	
		
			
				|  |  | +        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
 | 
	
		
			
				|  |  | +        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50).boost(10.0f);
 | 
	
		
			
				|  |  | +        SearchResponse response = client().prepareSearch("index")
 | 
	
		
			
				|  |  | +            .setKnnSearch(List.of(knnSearch, knnSearch2))
 | 
	
		
			
				|  |  | +            .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
 | 
	
		
			
				|  |  | +            .addFetchField("*")
 | 
	
		
			
				|  |  | +            .setSize(10)
 | 
	
		
			
				|  |  | +            .addAggregation(AggregationBuilders.stats("stats").field("number"))
 | 
	
		
			
				|  |  | +            .get();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // The total hits is k plus the number of text matches
 | 
	
		
			
				|  |  | +        assertHitCount(response, 20);
 | 
	
		
			
				|  |  | +        assertEquals(10, response.getHits().getHits().length);
 | 
	
		
			
				|  |  | +        InternalStats agg = response.getAggregations().get("stats");
 | 
	
		
			
				|  |  | +        assertThat(agg.getCount(), equalTo(20L));
 | 
	
		
			
				|  |  | +        assertThat(agg.getMax(), equalTo(3.0));
 | 
	
		
			
				|  |  | +        assertThat(agg.getMin(), equalTo(1.0));
 | 
	
		
			
				|  |  | +        assertThat(agg.getAvg(), equalTo(2.25));
 | 
	
		
			
				|  |  | +        assertThat(agg.getSum(), equalTo(45.0));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Because of the boost, vector_2 results should appear first
 | 
	
		
			
				|  |  | +        assertNotNull(response.getHits().getAt(0).field("vector_2"));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testMultiKnnClausesSameDoc() throws IOException {
 | 
	
		
			
				|  |  | +        int numShards = 1 + randomInt(3);
 | 
	
		
			
				|  |  | +        Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        XContentBuilder builder = XContentFactory.jsonBuilder()
 | 
	
		
			
				|  |  | +            .startObject()
 | 
	
		
			
				|  |  | +            .startObject("properties")
 | 
	
		
			
				|  |  | +            .startObject("vector")
 | 
	
		
			
				|  |  | +            .field("type", "dense_vector")
 | 
	
		
			
				|  |  | +            .field("dims", VECTOR_DIMENSION)
 | 
	
		
			
				|  |  | +            .field("index", true)
 | 
	
		
			
				|  |  | +            .field("similarity", "l2_norm")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .startObject("vector_2")
 | 
	
		
			
				|  |  | +            .field("type", "dense_vector")
 | 
	
		
			
				|  |  | +            .field("dims", VECTOR_DIMENSION)
 | 
	
		
			
				|  |  | +            .field("index", true)
 | 
	
		
			
				|  |  | +            .field("similarity", "l2_norm")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .startObject("number")
 | 
	
		
			
				|  |  | +            .field("type", "long")
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +            .endObject();
 | 
	
		
			
				|  |  | +        createIndex("index", indexSettings, builder);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for (int doc = 0; doc < 10; doc++) {
 | 
	
		
			
				|  |  | +            // Make them have hte same vector. This will allow us to test the recall is the same but scores take into account both fields
 | 
	
		
			
				|  |  | +            float[] vector = randomVector();
 | 
	
		
			
				|  |  | +            client().prepareIndex("index").setSource("vector", vector, "vector_2", vector, "number", doc).get();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        client().admin().indices().prepareRefresh("index").get();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        float[] queryVector = randomVector();
 | 
	
		
			
				|  |  | +        // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched
 | 
	
		
			
				|  |  | +        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50);
 | 
	
		
			
				|  |  | +        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50);
 | 
	
		
			
				|  |  | +        SearchResponse responseOneKnn = client().prepareSearch("index")
 | 
	
		
			
				|  |  | +            .setKnnSearch(List.of(knnSearch))
 | 
	
		
			
				|  |  | +            .addFetchField("*")
 | 
	
		
			
				|  |  | +            .setSize(10)
 | 
	
		
			
				|  |  | +            .addAggregation(AggregationBuilders.stats("stats").field("number"))
 | 
	
		
			
				|  |  | +            .get();
 | 
	
		
			
				|  |  | +        SearchResponse responseBothKnn = client().prepareSearch("index")
 | 
	
		
			
				|  |  | +            .setKnnSearch(List.of(knnSearch, knnSearch2))
 | 
	
		
			
				|  |  | +            .addFetchField("*")
 | 
	
		
			
				|  |  | +            .setSize(10)
 | 
	
		
			
				|  |  | +            .addAggregation(AggregationBuilders.stats("stats").field("number"))
 | 
	
		
			
				|  |  | +            .get();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // The total hits is k matched docs
 | 
	
		
			
				|  |  | +        assertHitCount(responseOneKnn, 5);
 | 
	
		
			
				|  |  | +        assertHitCount(responseBothKnn, 5);
 | 
	
		
			
				|  |  | +        assertEquals(5, responseOneKnn.getHits().getHits().length);
 | 
	
		
			
				|  |  | +        assertEquals(5, responseBothKnn.getHits().getHits().length);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for (int i = 0; i < responseOneKnn.getHits().getHits().length; i++) {
 | 
	
		
			
				|  |  | +            SearchHit oneHit = responseOneKnn.getHits().getHits()[i];
 | 
	
		
			
				|  |  | +            SearchHit bothHit = responseBothKnn.getHits().getHits()[i];
 | 
	
		
			
				|  |  | +            assertThat(bothHit.getId(), equalTo(oneHit.getId()));
 | 
	
		
			
				|  |  | +            assertThat(bothHit.getScore(), greaterThan(oneHit.getScore()));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        InternalStats oneAgg = responseOneKnn.getAggregations().get("stats");
 | 
	
		
			
				|  |  | +        InternalStats bothAgg = responseBothKnn.getAggregations().get("stats");
 | 
	
		
			
				|  |  | +        assertThat(bothAgg.getCount(), equalTo(oneAgg.getCount()));
 | 
	
		
			
				|  |  | +        assertThat(bothAgg.getAvg(), equalTo(oneAgg.getAvg()));
 | 
	
		
			
				|  |  | +        assertThat(bothAgg.getMax(), equalTo(oneAgg.getMax()));
 | 
	
		
			
				|  |  | +        assertThat(bothAgg.getSum(), equalTo(oneAgg.getSum()));
 | 
	
		
			
				|  |  | +        assertThat(bothAgg.getMin(), equalTo(oneAgg.getMin()));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      public void testKnnFilteredAlias() throws IOException {
 | 
	
		
			
				|  |  |          int numShards = 1 + randomInt(3);
 | 
	
		
			
				|  |  |          Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
 | 
	
	
		
			
				|  | @@ -184,7 +322,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          float[] queryVector = randomVector();
 | 
	
		
			
				|  |  |          KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50);
 | 
	
		
			
				|  |  | -        SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(knnSearch).setSize(10).get();
 | 
	
		
			
				|  |  | +        SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10).get();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertHitCount(response, expectedHits);
 | 
	
		
			
				|  |  |          assertEquals(expectedHits, response.getHits().getHits().length);
 |