|
@@ -19,10 +19,12 @@
|
|
|
|
|
|
package io.milvus.client;
|
|
|
|
|
|
+import com.google.common.util.concurrent.ListenableFuture;
|
|
|
import org.apache.commons.text.RandomStringGenerator;
|
|
|
|
|
|
import java.nio.ByteBuffer;
|
|
|
import java.util.*;
|
|
|
+import java.util.concurrent.ExecutionException;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
import java.util.stream.Collectors;
|
|
|
import java.util.stream.DoubleStream;
|
|
@@ -361,6 +363,48 @@ class MilvusClientTest {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @org.junit.jupiter.api.Test
|
|
|
+ void searchAsync() throws ExecutionException, InterruptedException {
|
|
|
+ List<List<Float>> vectors = generateFloatVectors(size, dimension);
|
|
|
+ vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
|
|
|
+ InsertParam insertParam =
|
|
|
+ new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
|
|
|
+ InsertResponse insertResponse = client.insert(insertParam);
|
|
|
+ assertTrue(insertResponse.ok());
|
|
|
+ List<Long> vectorIds = insertResponse.getVectorIds();
|
|
|
+ assertEquals(size, vectorIds.size());
|
|
|
+
|
|
|
+ assertTrue(client.flush(randomCollectionName).ok());
|
|
|
+
|
|
|
+ final int searchSize = 5;
|
|
|
+ List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
|
|
|
+
|
|
|
+ final long topK = 10;
|
|
|
+ SearchParam searchParam =
|
|
|
+ new SearchParam.Builder(randomCollectionName)
|
|
|
+ .withFloatVectors(vectorsToSearch)
|
|
|
+ .withTopK(topK)
|
|
|
+ .withParamsInJson("{\"nprobe\": 20}")
|
|
|
+ .build();
|
|
|
+ ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
|
|
|
+ SearchResponse searchResponse = searchResponseFuture.get();
|
|
|
+ assertTrue(searchResponse.ok());
|
|
|
+ List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
|
|
|
+ assertEquals(searchSize, resultIdsList.size());
|
|
|
+ List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
|
|
|
+ assertEquals(searchSize, resultDistancesList.size());
|
|
|
+ List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
|
|
|
+ assertEquals(searchSize, queryResultsList.size());
|
|
|
+ final double epsilon = 0.001;
|
|
|
+ for (int i = 0; i < searchSize; i++) {
|
|
|
+ SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
|
|
|
+ assertEquals(vectorIds.get(i), firstQueryResult.getVectorId());
|
|
|
+ assertEquals(vectorIds.get(i), resultIdsList.get(i).get(0));
|
|
|
+ assertTrue(Math.abs(1 - firstQueryResult.getDistance()) < epsilon);
|
|
|
+ assertTrue(Math.abs(1 - resultDistancesList.get(i).get(0)) < epsilon);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
@org.junit.jupiter.api.Test
|
|
|
void searchBinary() {
|
|
|
final long binaryDimension = 10000;
|