Browse Source

Fix vector search bug

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 4 years ago
parent
commit
6090ace479

+ 26 - 22
CHANGELOG.md

@@ -1,43 +1,49 @@
 # Changelog   
 
-## milvus-sdk-java 0.9.0 (2020-10-16)
+## v0.9.1 (2020-10-29)
+
+### Bug
+
+- [\#4086](https://github.com/milvus-io/milvus/issues/4086) - Fix vector search error when topK is small
+
+## v0.9.0 (2020-10-16)
 
 ### Feature
 
-- \#2976 Scalar-field filtering support
+- [\#2976](https://github.com/milvus-io/milvus/pull/2976) Scalar-field filtering support
 
 ### Improvement
 
-- \#134 - Simplify the client code
+- [\#134](https://github.com/milvus-io/milvus-sdk-java/pull/134) - Simplify the client code
 
-## milvus-sdk-java 0.8.5 (2020-08-26)
+## v0.8.5 (2020-08-26)
 
 ### Feature
 
-- \#128 - GRPC timeout support
-- \#129 - Support GRPC name resolver and load balancing
+- [\#128](https://github.com/milvus-io/milvus-sdk-java/pull/128) - GRPC timeout support
+- [\#129](https://github.com/milvus-io/milvus-sdk-java/pull/129) - Support GRPC name resolver and load balancing
 
-## milvus-sdk-java 0.8.3 (2020-07-15)
+## v0.8.3 (2020-07-15)
 
 ### Improvement
 
-- \#118 - Remove isConnect() API
+- [\#118](https://github.com/milvus-io/milvus-sdk-java/pull/118) - Remove isConnect() API
 
-## milvus-sdk-java 0.8.0 (2020-05-15)
+## v0.8.0 (2020-05-15)
 
 ### Feature
 
-- \#93 - Add/Improve getVectorByID, collectionInfo and hasPartition API
-- \#2295 - Rename SDK interfaces
+- [\#93](https://github.com/milvus-io/milvus-sdk-java/pull/93) - Add/Improve getVectorByID, collectionInfo and hasPartition API
+- [\#2295](https://github.com/milvus-io/milvus/issues/2295) - Rename SDK interfaces
 
-## milvus-sdk-java 0.7.0 (2020-04-15)
+## v0.7.0 (2020-04-15)
 
 ### Feature
 
-- \#261 - Integrate ANNOY into Milvus
-- \#1828 - Add searchAsync / createIndexAsync / insertAsync / flushAsync / compactAsync API
+- [\#261](https://github.com/milvus-io/milvus/issues/261) - Integrate ANNOY into Milvus
+- [\#1828](https://github.com/milvus-io/milvus/issues/1828) - Add searchAsync / createIndexAsync / insertAsync / flushAsync / compactAsync API
 
-## milvus-sdk-java 0.6.0 (2020-03-31)
+## v0.6.0 (2020-03-31)
 
 ### Bug
 
@@ -48,15 +54,13 @@
 
 - \#1603 - Add binary metrics: Substructure & Superstructure
 
-## milvus-sdk-java 0.5.0 (2020-03-11)
-
-## milvus-sdk-java 0.4.1 (2019-12-16)
+## v0.4.1 (2019-12-16)
 
 ### Bug
 
 - \#78 - Partition tag not working when searching
 
-## milvus-sdk-java 0.4.0 (2019-12-7)
+## v0.4.0 (2019-12-7)
 
 ### Bug
 
@@ -69,7 +73,7 @@
 - \#72 - Add more getters in ShowPartitionResponse
 - \#73 - Add @Deprecated for DateRanges in SearchParam
 
-## milvus-sdk-java 0.3.0 (2019-11-13)
+## v0.3.0 (2019-11-13)
 
 ### Bug
 
@@ -82,7 +86,7 @@
 - \#62 - Change GRPC proto (and related code) to increase search result's transmission speed
 - \#63 - Make some functions and constructors package-private if necessary
 
-## milvus-sdk-java 0.2.2 (2019-11-4)
+## v0.2.2 (2019-11-4)
 
 ### Improvement
 
@@ -90,7 +94,7 @@
 - \#51 - Change connect waitTime to timeout
 - \#52 - Change IVF_SQ8H to IVF_SQ8_H
 
-## milvus-sdk-java 0.2.0 (2019-10-21)
+## v0.2.0 (2019-10-21)
 
 ### Bug
 

+ 2 - 2
README.md

@@ -43,7 +43,7 @@ You can use **Apache Maven** or **Gradle**/**Grails** to download the SDK.
    - Gradle/Grails
 
         ```gradle
-        compile 'io.milvus:milvus-sdk-java:0.9.2'
+        compile group: 'io.milvus', name: 'milvus-sdk-java', version: '0.9.2'
         ```
 
 ### Examples
@@ -64,7 +64,7 @@ Please refer to [examples](https://github.com/milvus-io/milvus-sdk-java/tree/0.9
     ```
   This is because SLF4J jar files need to be added into your application's classpath. SLF4J is used by Java SDK for logging purpose.
   
-  To fix this issue, you can use **Apache Maven** or **Gradle**/**Grails** to download the required jar files.
+  To fix this issue, you can download the required jar files.
                                                                                                          
     - Apache Maven
     

+ 6 - 0
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -715,6 +715,12 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
         resultDistancesList.add(queryDistancesList.subList(i * topK, pos));
         resultFieldsMap.add(fieldsMap.subList(i * topK, pos));
       }
+    } else {
+      for (int i = 0; i < numQueries; i++) {
+        resultIdsList.add(new ArrayList<>());
+        resultDistancesList.add(new ArrayList<>());
+        resultFieldsMap.add(new ArrayList<>());
+      }
     }
 
     return new SearchResult(numQueries, topK, resultIdsList, resultDistancesList, resultFieldsMap);

+ 48 - 0
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -608,6 +608,54 @@ class MilvusClientTest {
     client.dropCollection(binaryCollectionName);
   }
 
+  @org.junit.jupiter.api.Test
+  void searchEmptyResult() {
+    List<Long> intValues = new ArrayList<>(size);
+    List<Float> floatValues = new ArrayList<>(size);
+    List<List<Float>> vectors = generateFloatVectors(size, dimension);
+    for (int i = 0; i < size; i++) {
+      intValues.add((long) i);
+      floatValues.add((float) i);
+    }
+    vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
+
+    List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
+    InsertParam insertParam =
+        InsertParam.create(randomCollectionName)
+            .addField("int64", DataType.INT64, intValues)
+            .addField("float", DataType.FLOAT, floatValues)
+            .addVectorField("float_vec", DataType.VECTOR_FLOAT, vectors)
+            .setEntityIds(insertIds);
+    List<Long> entityIds = client.insert(insertParam);
+    assertEquals(size, entityIds.size());
+
+    client.flush(randomCollectionName);
+
+    final int searchSize = 5;
+    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
+
+    final long topK = 10;
+    final String tag = "tag";
+    List<String> partitionTags = new ArrayList<>();
+    partitionTags.add(tag);
+    SearchParam searchParam =
+        SearchParam.create(randomCollectionName)
+            .setDsl(generateComplexDSL(topK, vectorsToSearch.toString()))
+            .setPartitionTags(partitionTags)
+            .setParamsInJson(
+                new JsonBuilder()
+                    .param("fields", new ArrayList<>(Arrays.asList("int64", "float_vec")))
+                    .build());
+    SearchResult searchResult = client.search(searchParam);
+    assertEquals(searchSize, searchResult.getQueryResultsList().size());
+    List<List<Long>> resultIdsList = searchResult.getResultIdsList();
+    assertEquals(searchSize, resultIdsList.size());
+    assertEquals(0, resultIdsList.get(0).size());
+    List<List<Float>> resultDistancesList = searchResult.getResultDistancesList();
+    assertEquals(searchSize, resultDistancesList.size());
+    assertEquals(0, resultDistancesList.get(0).size());
+  }
+
   @org.junit.jupiter.api.Test
   void getCollectionInfo() {
     CollectionMapping collectionMapping = client.getCollectionInfo(randomCollectionName);

+ 28 - 0
src/test/java/io/milvus/client/dsl/SearchDslTest.java

@@ -1,6 +1,7 @@
 package io.milvus.client.dsl;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import io.milvus.client.ConnectParam;
@@ -11,6 +12,7 @@ import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusGrpcClient;
 import io.milvus.client.SearchParam;
 import io.milvus.client.SearchResult;
+import io.milvus.client.exception.InvalidDsl;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
@@ -22,6 +24,7 @@ import java.util.stream.LongStream;
 import java.util.stream.Stream;
 import org.apache.commons.lang3.RandomUtils;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
 import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
@@ -300,4 +303,29 @@ public class SearchDslTest {
                   .collect(Collectors.toList()));
         });
   }
+
+  @Test
+  public void testMultipleVectorsQuery() {
+    withMilvusServiceFloat(
+        service -> {
+          testCreateIndexFloat();
+
+          List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
+
+          Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
+
+          List<List<Float>> vectors =
+              entities.values().stream()
+                  .map(e -> e.get(floatSchema.floatVectorField))
+                  .collect(Collectors.toList());
+
+          Query query =
+              Query.bool(
+                  Query.must(
+                      floatSchema.floatVectorField.query(vectors).param("nprobe", 16).top(1),
+                      floatSchema.floatVectorField.query(vectors).param("nprobe", 16).top(1)));
+
+          assertThrows(InvalidDsl.class, () -> service.buildSearchParam(query));
+        });
+  }
 }