|
@@ -13,6 +13,8 @@ import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.indices.TermsLookup;
|
|
|
import org.elasticsearch.search.SearchHit;
|
|
|
+import org.elasticsearch.search.aggregations.AggregationBuilders;
|
|
|
+import org.elasticsearch.search.aggregations.metrics.InternalStats;
|
|
|
import org.elasticsearch.search.vectors.KnnSearchBuilder;
|
|
|
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
|
|
import org.elasticsearch.test.ESSingleNodeTestCase;
|
|
@@ -20,8 +22,11 @@ import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.xcontent.XContentFactory;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.util.List;
|
|
|
|
|
|
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
|
|
|
+import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.greaterThan;
|
|
|
|
|
|
public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
|
|
|
private static final int VECTOR_DIMENSION = 10;
|
|
@@ -56,7 +61,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
|
|
|
float[] queryVector = randomVector();
|
|
|
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
|
|
|
SearchResponse response = client().prepareSearch("index")
|
|
|
- .setKnnSearch(knnSearch)
|
|
|
+ .setKnnSearch(List.of(knnSearch))
|
|
|
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
|
|
|
.addFetchField("*")
|
|
|
.setSize(10)
|
|
@@ -101,7 +106,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
|
|
|
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
|
|
|
QueryBuilders.termsQuery("field", "second")
|
|
|
);
|
|
|
- SearchResponse response = client().prepareSearch("index").setKnnSearch(knnSearch).addFetchField("*").setSize(10).get();
|
|
|
+ SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10).get();
|
|
|
|
|
|
assertHitCount(response, 5);
|
|
|
assertEquals(5, response.getHits().getHits().length);
|
|
@@ -144,12 +149,145 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
|
|
|
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
|
|
|
QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field"))
|
|
|
);
|
|
|
- SearchResponse response = client().prepareSearch("index").setKnnSearch(knnSearch).setSize(10).get();
|
|
|
+ SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10).get();
|
|
|
|
|
|
assertHitCount(response, 5);
|
|
|
assertEquals(5, response.getHits().getHits().length);
|
|
|
}
|
|
|
|
|
|
+ public void testMultiKnnClauses() throws IOException {
|
|
|
+ // This tests the recall from vectors being searched in different docs
|
|
|
+ int numShards = 1 + randomInt(3);
|
|
|
+ Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
|
|
|
+
|
|
|
+ XContentBuilder builder = XContentFactory.jsonBuilder()
|
|
|
+ .startObject()
|
|
|
+ .startObject("properties")
|
|
|
+ .startObject("vector")
|
|
|
+ .field("type", "dense_vector")
|
|
|
+ .field("dims", VECTOR_DIMENSION)
|
|
|
+ .field("index", true)
|
|
|
+ .field("similarity", "l2_norm")
|
|
|
+ .endObject()
|
|
|
+ .startObject("vector_2")
|
|
|
+ .field("type", "dense_vector")
|
|
|
+ .field("dims", VECTOR_DIMENSION)
|
|
|
+ .field("index", true)
|
|
|
+ .field("similarity", "l2_norm")
|
|
|
+ .endObject()
|
|
|
+ .startObject("text")
|
|
|
+ .field("type", "text")
|
|
|
+ .endObject()
|
|
|
+ .startObject("number")
|
|
|
+ .field("type", "long")
|
|
|
+ .endObject()
|
|
|
+ .endObject()
|
|
|
+ .endObject();
|
|
|
+ createIndex("index", indexSettings, builder);
|
|
|
+
|
|
|
+ for (int doc = 0; doc < 10; doc++) {
|
|
|
+ client().prepareIndex("index").setSource("vector", randomVector(), "text", "hello world", "number", 1).get();
|
|
|
+ client().prepareIndex("index").setSource("vector_2", randomVector(), "text", "hello world", "number", 2).get();
|
|
|
+ client().prepareIndex("index").setSource("text", "goodnight world", "number", 3).get();
|
|
|
+ }
|
|
|
+ client().admin().indices().prepareRefresh("index").get();
|
|
|
+
|
|
|
+ float[] queryVector = randomVector();
|
|
|
+ KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
|
|
|
+ KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50).boost(10.0f);
|
|
|
+ SearchResponse response = client().prepareSearch("index")
|
|
|
+ .setKnnSearch(List.of(knnSearch, knnSearch2))
|
|
|
+ .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
|
|
|
+ .addFetchField("*")
|
|
|
+ .setSize(10)
|
|
|
+ .addAggregation(AggregationBuilders.stats("stats").field("number"))
|
|
|
+ .get();
|
|
|
+
|
|
|
+ // The total hits is k plus the number of text matches
|
|
|
+ assertHitCount(response, 20);
|
|
|
+ assertEquals(10, response.getHits().getHits().length);
|
|
|
+ InternalStats agg = response.getAggregations().get("stats");
|
|
|
+ assertThat(agg.getCount(), equalTo(20L));
|
|
|
+ assertThat(agg.getMax(), equalTo(3.0));
|
|
|
+ assertThat(agg.getMin(), equalTo(1.0));
|
|
|
+ assertThat(agg.getAvg(), equalTo(2.25));
|
|
|
+ assertThat(agg.getSum(), equalTo(45.0));
|
|
|
+
|
|
|
+ // Because of the boost, vector_2 results should appear first
|
|
|
+ assertNotNull(response.getHits().getAt(0).field("vector_2"));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testMultiKnnClausesSameDoc() throws IOException {
|
|
|
+ int numShards = 1 + randomInt(3);
|
|
|
+ Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
|
|
|
+
|
|
|
+ XContentBuilder builder = XContentFactory.jsonBuilder()
|
|
|
+ .startObject()
|
|
|
+ .startObject("properties")
|
|
|
+ .startObject("vector")
|
|
|
+ .field("type", "dense_vector")
|
|
|
+ .field("dims", VECTOR_DIMENSION)
|
|
|
+ .field("index", true)
|
|
|
+ .field("similarity", "l2_norm")
|
|
|
+ .endObject()
|
|
|
+ .startObject("vector_2")
|
|
|
+ .field("type", "dense_vector")
|
|
|
+ .field("dims", VECTOR_DIMENSION)
|
|
|
+ .field("index", true)
|
|
|
+ .field("similarity", "l2_norm")
|
|
|
+ .endObject()
|
|
|
+ .startObject("number")
|
|
|
+ .field("type", "long")
|
|
|
+ .endObject()
|
|
|
+ .endObject()
|
|
|
+ .endObject();
|
|
|
+ createIndex("index", indexSettings, builder);
|
|
|
+
|
|
|
+ for (int doc = 0; doc < 10; doc++) {
|
|
|
+ // Make them have hte same vector. This will allow us to test the recall is the same but scores take into account both fields
|
|
|
+ float[] vector = randomVector();
|
|
|
+ client().prepareIndex("index").setSource("vector", vector, "vector_2", vector, "number", doc).get();
|
|
|
+ }
|
|
|
+ client().admin().indices().prepareRefresh("index").get();
|
|
|
+
|
|
|
+ float[] queryVector = randomVector();
|
|
|
+ // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched
|
|
|
+ KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50);
|
|
|
+ KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50);
|
|
|
+ SearchResponse responseOneKnn = client().prepareSearch("index")
|
|
|
+ .setKnnSearch(List.of(knnSearch))
|
|
|
+ .addFetchField("*")
|
|
|
+ .setSize(10)
|
|
|
+ .addAggregation(AggregationBuilders.stats("stats").field("number"))
|
|
|
+ .get();
|
|
|
+ SearchResponse responseBothKnn = client().prepareSearch("index")
|
|
|
+ .setKnnSearch(List.of(knnSearch, knnSearch2))
|
|
|
+ .addFetchField("*")
|
|
|
+ .setSize(10)
|
|
|
+ .addAggregation(AggregationBuilders.stats("stats").field("number"))
|
|
|
+ .get();
|
|
|
+
|
|
|
+ // The total hits is k matched docs
|
|
|
+ assertHitCount(responseOneKnn, 5);
|
|
|
+ assertHitCount(responseBothKnn, 5);
|
|
|
+ assertEquals(5, responseOneKnn.getHits().getHits().length);
|
|
|
+ assertEquals(5, responseBothKnn.getHits().getHits().length);
|
|
|
+
|
|
|
+ for (int i = 0; i < responseOneKnn.getHits().getHits().length; i++) {
|
|
|
+ SearchHit oneHit = responseOneKnn.getHits().getHits()[i];
|
|
|
+ SearchHit bothHit = responseBothKnn.getHits().getHits()[i];
|
|
|
+ assertThat(bothHit.getId(), equalTo(oneHit.getId()));
|
|
|
+ assertThat(bothHit.getScore(), greaterThan(oneHit.getScore()));
|
|
|
+ }
|
|
|
+ InternalStats oneAgg = responseOneKnn.getAggregations().get("stats");
|
|
|
+ InternalStats bothAgg = responseBothKnn.getAggregations().get("stats");
|
|
|
+ assertThat(bothAgg.getCount(), equalTo(oneAgg.getCount()));
|
|
|
+ assertThat(bothAgg.getAvg(), equalTo(oneAgg.getAvg()));
|
|
|
+ assertThat(bothAgg.getMax(), equalTo(oneAgg.getMax()));
|
|
|
+ assertThat(bothAgg.getSum(), equalTo(oneAgg.getSum()));
|
|
|
+ assertThat(bothAgg.getMin(), equalTo(oneAgg.getMin()));
|
|
|
+ }
|
|
|
+
|
|
|
public void testKnnFilteredAlias() throws IOException {
|
|
|
int numShards = 1 + randomInt(3);
|
|
|
Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
|
|
@@ -184,7 +322,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
|
|
|
|
|
|
float[] queryVector = randomVector();
|
|
|
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50);
|
|
|
- SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(knnSearch).setSize(10).get();
|
|
|
+ SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10).get();
|
|
|
|
|
|
assertHitCount(response, expectedHits);
|
|
|
assertEquals(expectedHits, response.getHits().getHits().length);
|