瀏覽代碼

release v0.4.1: fix search partition not working and add unit test (#79)

* fix partition in search not working and add unittest

* update CHANGELOG and version

* update

* update

* update
Zhiru Zhu 5 年之前
父節點
當前提交
6b4f0c7945

+ 6 - 0
CHANGELOG.md

@@ -1,5 +1,11 @@
 # Changelog     
 # Changelog     
 
 
+## milvus-sdk-java 0.4.1 (2019-12-16)
+
+### Bug
+---
+- \#78 - Partition tag not working when searching
+
 ## milvus-sdk-java 0.4.0 (2019-12-7)
 ## milvus-sdk-java 0.4.0 (2019-12-7)
 
 
 ### Bug
 ### Bug

+ 3 - 3
README.md

@@ -17,7 +17,7 @@ Milvus version compatibility:
 | 0.5.1 | 0.2.2 | 
 | 0.5.1 | 0.2.2 | 
 | 0.5.2 | 0.2.2 | 
 | 0.5.2 | 0.2.2 | 
 | 0.5.3 | 0.3.0 | 
 | 0.5.3 | 0.3.0 | 
-| 0.6.0 | 0.4.0 | 
+| 0.6.0 | 0.4.1 | 
 
 
 ### Dependency 
 ### Dependency 
 
 
@@ -26,13 +26,13 @@ Apache Maven
 <dependency>
 <dependency>
     <groupId>io.milvus</groupId>
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.4.0</version>
+    <version>0.4.1</version>
 </dependency>
 </dependency>
 ```
 ```
 
 
 Gradle/Grails 
 Gradle/Grails 
 
 
-`compile 'io.milvus:milvus-sdk-java:0.4.0'`
+`compile 'io.milvus:milvus-sdk-java:0.4.1'`
 
 
 ### Examples
 ### Examples
 
 

+ 1 - 1
examples/pom.xml

@@ -63,7 +63,7 @@
         <dependency>
         <dependency>
             <groupId>io.milvus</groupId>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>0.4.0</version>
+            <version>0.4.1</version>
         </dependency>
         </dependency>
     </dependencies>
     </dependencies>
 
 

+ 1 - 1
pom.xml

@@ -25,7 +25,7 @@
 
 
     <groupId>io.milvus</groupId>
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.4.0</version>
+    <version>0.4.1</version>
     <packaging>jar</packaging>
     <packaging>jar</packaging>
 
 
     <name>io.milvus:milvus-sdk-java</name>
     <name>io.milvus:milvus-sdk-java</name>

+ 3 - 1
src/main/java/io/milvus/client/MilvusClient.java

@@ -22,7 +22,7 @@ package io.milvus.client;
 /** The Milvus Client Interface */
 /** The Milvus Client Interface */
 public interface MilvusClient {
 public interface MilvusClient {
 
 
-  String clientVersion = "0.4.0";
+  String clientVersion = "0.4.1";
 
 
   /** @return the current Milvus client version */
   /** @return the current Milvus client version */
   default String getClientVersion() {
   default String getClientVersion() {
@@ -212,6 +212,7 @@ public interface MilvusClient {
    *                                          .withTopK(topK)
    *                                          .withTopK(topK)
    *                                          .withNProbe(nProbe)
    *                                          .withNProbe(nProbe)
    *                                          .withDateRanges(dateRanges)
    *                                          .withDateRanges(dateRanges)
+   *                                          .withPartitionTags(partitionTagsList)
    *                                          .build();
    *                                          .build();
    * </code>
    * </code>
    * </pre>
    * </pre>
@@ -236,6 +237,7 @@ public interface MilvusClient {
    *                                          .withTopK(topK)
    *                                          .withTopK(topK)
    *                                          .withNProbe(nProbe)
    *                                          .withNProbe(nProbe)
    *                                          .withDateRanges(dateRanges)
    *                                          .withDateRanges(dateRanges)
+   *                                          .withPartitionTags(partitionTagsList)
    *                                          .build();
    *                                          .build();
    * SearchInFilesParam searchInFilesParam = new SearchInFilesParam.Builder(fileIds, searchParam)
    * SearchInFilesParam searchInFilesParam = new SearchInFilesParam.Builder(fileIds, searchParam)
    *                                                               .build();
    *                                                               .build();

+ 2 - 0
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -479,6 +479,7 @@ public class MilvusGrpcClient implements MilvusClient {
             .addAllQueryRangeArray(queryRangeList)
             .addAllQueryRangeArray(queryRangeList)
             .setTopk(searchParam.getTopK())
             .setTopk(searchParam.getTopK())
             .setNprobe(searchParam.getNProbe())
             .setNprobe(searchParam.getNProbe())
+            .addAllPartitionTagArray(searchParam.getPartitionTags())
             .build();
             .build();
 
 
     io.milvus.grpc.TopKQueryResult response;
     io.milvus.grpc.TopKQueryResult response;
@@ -533,6 +534,7 @@ public class MilvusGrpcClient implements MilvusClient {
             .addAllQueryRangeArray(queryRangeList)
             .addAllQueryRangeArray(queryRangeList)
             .setTopk(searchParam.getTopK())
             .setTopk(searchParam.getTopK())
             .setNprobe(searchParam.getNProbe())
             .setNprobe(searchParam.getNProbe())
+            .addAllPartitionTagArray(searchParam.getPartitionTags())
             .build();
             .build();
 
 
     io.milvus.grpc.SearchInFilesParam request =
     io.milvus.grpc.SearchInFilesParam request =

+ 61 - 31
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -25,6 +25,7 @@ import java.util.*;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
 import java.util.stream.DoubleStream;
+import java.util.stream.LongStream;
 
 
 import static org.junit.jupiter.api.Assertions.*;
 import static org.junit.jupiter.api.Assertions.*;
 
 
@@ -174,54 +175,83 @@ class MilvusClientTest {
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
   void partitionTest() throws InterruptedException {
   void partitionTest() throws InterruptedException {
-    final String partitionName = "partition";
-    final String tag = "tag";
+    final String partitionName1 = "partition1";
+    final String tag1 = "tag1";
 
 
-    Partition partition = new Partition.Builder(randomTableName, partitionName, tag).build();
+    Partition partition = new Partition.Builder(randomTableName, partitionName1, tag1).build();
     Response createPartitionResponse = client.createPartition(partition);
     Response createPartitionResponse = client.createPartition(partition);
     assertTrue(createPartitionResponse.ok());
     assertTrue(createPartitionResponse.ok());
 
 
-    List<List<Float>> vectors = generateVectors(size, dimension);
+    final String partitionName2 = "partition2";
+    final String tag2 = "tag2";
+
+    Partition partition2 = new Partition.Builder(randomTableName, partitionName2, tag2).build();
+    createPartitionResponse = client.createPartition(partition2);
+    assertTrue(createPartitionResponse.ok());
+
+    ShowPartitionsResponse showPartitionsResponse = client.showPartitions(randomTableName);
+    assertTrue(showPartitionsResponse.ok());
+    assertEquals(2, showPartitionsResponse.getPartitionList().size());
+
+    List<List<Float>> vectors1 = generateVectors(size, dimension);
+    List<Long> vectorIds1 = LongStream.range(0, size).boxed().collect(Collectors.toList());
     InsertParam insertParam =
     InsertParam insertParam =
-        new InsertParam.Builder(randomTableName, vectors).withPartitionTag(tag).build();
+        new InsertParam.Builder(randomTableName, vectors1)
+            .withVectorIds(vectorIds1)
+            .withPartitionTag(tag1)
+            .build();
     InsertResponse insertResponse = client.insert(insertParam);
     InsertResponse insertResponse = client.insert(insertParam);
     assertTrue(insertResponse.ok());
     assertTrue(insertResponse.ok());
+    List<List<Float>> vectors2 = generateVectors(size, dimension);
+    List<Long> vectorIds2 = LongStream.range(size, size * 2).boxed().collect(Collectors.toList());
+    insertParam =
+        new InsertParam.Builder(randomTableName, vectors2)
+            .withVectorIds(vectorIds2)
+            .withPartitionTag(tag2)
+            .build();
+    insertResponse = client.insert(insertParam);
+    assertTrue(insertResponse.ok());
 
 
     TimeUnit.SECONDS.sleep(1);
     TimeUnit.SECONDS.sleep(1);
 
 
-    final int searchSize = 5;
-    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
+    assertEquals(size * 2, client.getTableRowCount(randomTableName).getTableRowCount());
 
 
-    List<String> partitionTags = new ArrayList<>();
-    partitionTags.add(tag);
+    final int searchSize = 1;
     final long topK = 10;
     final long topK = 10;
-    SearchParam searchParam =
-        new SearchParam.Builder(randomTableName, vectorsToSearch)
+
+    List<List<Float>> vectorsToSearch1 = vectors1.subList(0, searchSize);
+    List<String> partitionTags1 = new ArrayList<>();
+    partitionTags1.add(tag1);
+    SearchParam searchParam1 =
+        new SearchParam.Builder(randomTableName, vectorsToSearch1)
             .withTopK(topK)
             .withTopK(topK)
             .withNProbe(20)
             .withNProbe(20)
-            .withPartitionTags(partitionTags)
+            .withPartitionTags(partitionTags1)
             .build();
             .build();
-    SearchResponse searchResponse = client.search(searchParam);
-    assertTrue(searchResponse.ok());
-    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
-    assertEquals(searchSize, resultIdsList.size());
-    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
-    assertEquals(searchSize, resultDistancesList.size());
-    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
-    assertEquals(searchSize, queryResultsList.size());
-
-    final String partitionName2 = "partition2";
-    final String tag2 = "tag2";
-
-    Partition partition2 = new Partition.Builder(randomTableName, partitionName2, tag2).build();
-    createPartitionResponse = client.createPartition(partition2);
-    assertTrue(createPartitionResponse.ok());
+    SearchResponse searchResponse1 = client.search(searchParam1);
+    assertTrue(searchResponse1.ok());
+    List<List<Long>> resultIdsList1 = searchResponse1.getResultIdsList();
+    assertEquals(searchSize, resultIdsList1.size());
+    assertTrue(vectorIds1.containsAll(resultIdsList1.get(0)));
+
+    List<List<Float>> vectorsToSearch2 = vectors2.subList(0, searchSize);
+    List<String> partitionTags2 = new ArrayList<>();
+    partitionTags2.add(tag2);
+    SearchParam searchParam2 =
+        new SearchParam.Builder(randomTableName, vectorsToSearch2)
+            .withTopK(topK)
+            .withNProbe(20)
+            .withPartitionTags(partitionTags2)
+            .build();
+    SearchResponse searchResponse2 = client.search(searchParam2);
+    assertTrue(searchResponse2.ok());
+    List<List<Long>> resultIdsList2 = searchResponse2.getResultIdsList();
+    assertEquals(searchSize, resultIdsList2.size());
+    assertTrue(vectorIds2.containsAll(resultIdsList2.get(0)));
 
 
-    ShowPartitionsResponse showPartitionsResponse = client.showPartitions(randomTableName);
-    assertTrue(showPartitionsResponse.ok());
-    assertEquals(2, showPartitionsResponse.getPartitionList().size());
+    assertTrue(Collections.disjoint(resultIdsList1, resultIdsList2));
 
 
-    Response dropPartitionResponse = client.dropPartition(partitionName);
+    Response dropPartitionResponse = client.dropPartition(partitionName1);
     assertTrue(dropPartitionResponse.ok());
     assertTrue(dropPartitionResponse.ok());
 
 
     dropPartitionResponse = client.dropPartition(randomTableName, tag2);
     dropPartitionResponse = client.dropPartition(randomTableName, tag2);