Browse Source

Fix vector search bug

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 4 years ago
parent
commit
6090ace479

+ 26 - 22
CHANGELOG.md

@@ -1,43 +1,49 @@
 # Changelog   
 # Changelog   
 
 
-## milvus-sdk-java 0.9.0 (2020-10-16)
+## v0.9.1 (2020-10-29)
+
+### Bug
+
+- [\#4086](https://github.com/milvus-io/milvus/issues/4086) - Fix vector search error when topK is small
+
+## v0.9.0 (2020-10-16)
 
 
 ### Feature
 ### Feature
 
 
-- \#2976 Scalar-field filtering support
+- [\#2976](https://github.com/milvus-io/milvus/pull/2976) - Scalar-field filtering support
 
 
 ### Improvement
 ### Improvement
 
 
-- \#134 - Simplify the client code
+- [\#134](https://github.com/milvus-io/milvus-sdk-java/pull/134) - Simplify the client code
 
 
-## milvus-sdk-java 0.8.5 (2020-08-26)
+## v0.8.5 (2020-08-26)
 
 
 ### Feature
 ### Feature
 
 
-- \#128 - GRPC timeout support
-- \#129 - Support GRPC name resolver and load balancing
+- [\#128](https://github.com/milvus-io/milvus-sdk-java/pull/128) - GRPC timeout support
+- [\#129](https://github.com/milvus-io/milvus-sdk-java/pull/129) - Support GRPC name resolver and load balancing
 
 
-## milvus-sdk-java 0.8.3 (2020-07-15)
+## v0.8.3 (2020-07-15)
 
 
 ### Improvement
 ### Improvement
 
 
-- \#118 - Remove isConnect() API
+- [\#118](https://github.com/milvus-io/milvus-sdk-java/pull/118) - Remove isConnect() API
 
 
-## milvus-sdk-java 0.8.0 (2020-05-15)
+## v0.8.0 (2020-05-15)
 
 
 ### Feature
 ### Feature
 
 
-- \#93 - Add/Improve getVectorByID, collectionInfo and hasPartition API
-- \#2295 - Rename SDK interfaces
+- [\#93](https://github.com/milvus-io/milvus-sdk-java/pull/93) - Add/Improve getVectorByID, collectionInfo and hasPartition API
+- [\#2295](https://github.com/milvus-io/milvus/issues/2295) - Rename SDK interfaces
 
 
-## milvus-sdk-java 0.7.0 (2020-04-15)
+## v0.7.0 (2020-04-15)
 
 
 ### Feature
 ### Feature
 
 
-- \#261 - Integrate ANNOY into Milvus
-- \#1828 - Add searchAsync / createIndexAsync / insertAsync / flushAsync / compactAsync API
+- [\#261](https://github.com/milvus-io/milvus/issues/261) - Integrate ANNOY into Milvus
+- [\#1828](https://github.com/milvus-io/milvus/issues/1828) - Add searchAsync / createIndexAsync / insertAsync / flushAsync / compactAsync API
 
 
-## milvus-sdk-java 0.6.0 (2020-03-31)
+## v0.6.0 (2020-03-31)
 
 
 ### Bug
 ### Bug
 
 
@@ -48,15 +54,13 @@
 
 
 - \#1603 - Add binary metrics: Substructure & Superstructure
 - \#1603 - Add binary metrics: Substructure & Superstructure
 
 
-## milvus-sdk-java 0.5.0 (2020-03-11)
-
-## milvus-sdk-java 0.4.1 (2019-12-16)
+## v0.4.1 (2019-12-16)
 
 
 ### Bug
 ### Bug
 
 
 - \#78 - Partition tag not working when searching
 - \#78 - Partition tag not working when searching
 
 
-## milvus-sdk-java 0.4.0 (2019-12-7)
+## v0.4.0 (2019-12-07)
 
 
 ### Bug
 ### Bug
 
 
@@ -69,7 +73,7 @@
 - \#72 - Add more getters in ShowPartitionResponse
 - \#72 - Add more getters in ShowPartitionResponse
 - \#73 - Add @Deprecated for DateRanges in SearchParam
 - \#73 - Add @Deprecated for DateRanges in SearchParam
 
 
-## milvus-sdk-java 0.3.0 (2019-11-13)
+## v0.3.0 (2019-11-13)
 
 
 ### Bug
 ### Bug
 
 
@@ -82,7 +86,7 @@
 - \#62 - Change GRPC proto (and related code) to increase search result's transmission speed
 - \#62 - Change GRPC proto (and related code) to increase search result's transmission speed
 - \#63 - Make some functions and constructors package-private if necessary
 - \#63 - Make some functions and constructors package-private if necessary
 
 
-## milvus-sdk-java 0.2.2 (2019-11-4)
+## v0.2.2 (2019-11-04)
 
 
 ### Improvement
 ### Improvement
 
 
@@ -90,7 +94,7 @@
 - \#51 - Change connect waitTime to timeout
 - \#51 - Change connect waitTime to timeout
 - \#52 - Change IVF_SQ8H to IVF_SQ8_H
 - \#52 - Change IVF_SQ8H to IVF_SQ8_H
 
 
-## milvus-sdk-java 0.2.0 (2019-10-21)
+## v0.2.0 (2019-10-21)
 
 
 ### Bug
 ### Bug
 
 

+ 2 - 2
README.md

@@ -43,7 +43,7 @@ You can use **Apache Maven** or **Gradle**/**Grails** to download the SDK.
    - Gradle/Grails
    - Gradle/Grails
 
 
         ```gradle
         ```gradle
-        compile 'io.milvus:milvus-sdk-java:0.9.2'
+        compile group: 'io.milvus', name: 'milvus-sdk-java', version: '0.9.2'
         ```
         ```
 
 
 ### Examples
 ### Examples
@@ -64,7 +64,7 @@ Please refer to [examples](https://github.com/milvus-io/milvus-sdk-java/tree/0.9
     ```
     ```
   This is because SLF4J jar files need to be added to your application's classpath. SLF4J is used by the Java SDK for logging purposes.
   This is because SLF4J jar files need to be added to your application's classpath. SLF4J is used by the Java SDK for logging purposes.
   
   
-  To fix this issue, you can use **Apache Maven** or **Gradle**/**Grails** to download the required jar files.
+  To fix this issue, you can download the required jar files.
                                                                                                          
                                                                                                          
     - Apache Maven
     - Apache Maven
     
     

+ 6 - 0
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -715,6 +715,12 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
         resultDistancesList.add(queryDistancesList.subList(i * topK, pos));
         resultDistancesList.add(queryDistancesList.subList(i * topK, pos));
         resultFieldsMap.add(fieldsMap.subList(i * topK, pos));
         resultFieldsMap.add(fieldsMap.subList(i * topK, pos));
       }
       }
+    } else {
+      for (int i = 0; i < numQueries; i++) {
+        resultIdsList.add(new ArrayList<>());
+        resultDistancesList.add(new ArrayList<>());
+        resultFieldsMap.add(new ArrayList<>());
+      }
     }
     }
 
 
     return new SearchResult(numQueries, topK, resultIdsList, resultDistancesList, resultFieldsMap);
     return new SearchResult(numQueries, topK, resultIdsList, resultDistancesList, resultFieldsMap);

+ 48 - 0
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -608,6 +608,54 @@ class MilvusClientTest {
     client.dropCollection(binaryCollectionName);
     client.dropCollection(binaryCollectionName);
   }
   }
 
 
+  @org.junit.jupiter.api.Test
+  void searchEmptyResult() { // Verifies that a search still returns one (empty) result list per query vector.
+    List<Long> intValues = new ArrayList<>(size);
+    List<Float> floatValues = new ArrayList<>(size);
+    List<List<Float>> vectors = generateFloatVectors(size, dimension);
+    for (int i = 0; i < size; i++) { // scalar field values simply mirror the entity index
+      intValues.add((long) i);
+      floatValues.add((float) i);
+    }
+    vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
+
+    List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList()); // explicit entity ids 0..size-1
+    InsertParam insertParam =
+        InsertParam.create(randomCollectionName)
+            .addField("int64", DataType.INT64, intValues)
+            .addField("float", DataType.FLOAT, floatValues)
+            .addVectorField("float_vec", DataType.VECTOR_FLOAT, vectors)
+            .setEntityIds(insertIds);
+    List<Long> entityIds = client.insert(insertParam);
+    assertEquals(size, entityIds.size());
+
+    client.flush(randomCollectionName);
+
+    final int searchSize = 5;
+    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize); // query with the first searchSize inserted vectors
+
+    final long topK = 10;
+    final String tag = "tag"; // NOTE(review): this test never creates a partition with this tag; presumably that is why the result is empty - confirm against the fixture setup
+    List<String> partitionTags = new ArrayList<>();
+    partitionTags.add(tag);
+    SearchParam searchParam =
+        SearchParam.create(randomCollectionName)
+            .setDsl(generateComplexDSL(topK, vectorsToSearch.toString()))
+            .setPartitionTags(partitionTags)
+            .setParamsInJson(
+                new JsonBuilder()
+                    .param("fields", new ArrayList<>(Arrays.asList("int64", "float_vec")))
+                    .build());
+    SearchResult searchResult = client.search(searchParam);
+    assertEquals(searchSize, searchResult.getQueryResultsList().size()); // one result entry per query vector is expected
+    List<List<Long>> resultIdsList = searchResult.getResultIdsList();
+    assertEquals(searchSize, resultIdsList.size());
+    assertEquals(0, resultIdsList.get(0).size()); // only the first query's id list is checked for emptiness
+    List<List<Float>> resultDistancesList = searchResult.getResultDistancesList();
+    assertEquals(searchSize, resultDistancesList.size());
+    assertEquals(0, resultDistancesList.get(0).size()); // likewise, only the first query's distance list is checked
+  }
+
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
   void getCollectionInfo() {
   void getCollectionInfo() {
     CollectionMapping collectionMapping = client.getCollectionInfo(randomCollectionName);
     CollectionMapping collectionMapping = client.getCollectionInfo(randomCollectionName);

+ 28 - 0
src/test/java/io/milvus/client/dsl/SearchDslTest.java

@@ -1,6 +1,7 @@
 package io.milvus.client.dsl;
 package io.milvus.client.dsl;
 
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 
 import io.milvus.client.ConnectParam;
 import io.milvus.client.ConnectParam;
@@ -11,6 +12,7 @@ import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusGrpcClient;
 import io.milvus.client.MilvusGrpcClient;
 import io.milvus.client.SearchParam;
 import io.milvus.client.SearchParam;
 import io.milvus.client.SearchResult;
 import io.milvus.client.SearchResult;
+import io.milvus.client.exception.InvalidDsl;
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.List;
 import java.util.List;
@@ -22,6 +24,7 @@ import java.util.stream.LongStream;
 import java.util.stream.Stream;
 import java.util.stream.Stream;
 import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.RandomUtils;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
 import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
 import org.testcontainers.junit.jupiter.Testcontainers;
@@ -300,4 +303,29 @@ public class SearchDslTest {
                   .collect(Collectors.toList()));
                   .collect(Collectors.toList()));
         });
         });
   }
   }
+
+  @Test
+  public void testMultipleVectorsQuery() { // Expects InvalidDsl when a bool query holds two vector clauses.
+    withMilvusServiceFloat(
+        service -> {
+          testCreateIndexFloat(); // NOTE(review): presumably prepares the collection, data and index used below - verify
+
+          List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList()); // ids 0..9
+
+          Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
+
+          List<List<Float>> vectors =
+              entities.values().stream()
+                  .map(e -> e.get(floatSchema.floatVectorField))
+                  .collect(Collectors.toList());
+
+          Query query =
+              Query.bool(
+                  Query.must(
+                      floatSchema.floatVectorField.query(vectors).param("nprobe", 16).top(1),
+                      floatSchema.floatVectorField.query(vectors).param("nprobe", 16).top(1))); // second vector clause, identical to the first
+
+          assertThrows(InvalidDsl.class, () -> service.buildSearchParam(query)); // building the search param must throw InvalidDsl
+        });
+  }
 }
 }