Browse Source

Merge pull request #94 from sahuang/master

Add new APIs
Xiaohai Xu 5 years ago
parent
commit
1f2bb8f519

+ 1 - 1
examples/pom.xml

@@ -63,7 +63,7 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>0.7.0</version>
+            <version>0.8.0-SNAPSHOT</version>
         </dependency>
         <dependency>
             <groupId>com.google.code.gson</groupId>

+ 39 - 3
examples/src/main/java/MilvusClientExample.java

@@ -140,6 +140,21 @@ public class MilvusClientExample {
     // Describe the index for your collection
     DescribeIndexResponse describeIndexResponse = client.describeIndex(collectionName);
 
+    // Get collection info
+    Response showCollectionInfoResponse = client.showCollectionInfo(collectionName);
+    if (showCollectionInfoResponse.ok()) {
+      // Collection info is sent back with JSON type string
+      String jsonString = showCollectionInfoResponse.getMessage();
+      System.out.println(jsonString);
+    }
+
+    // Check whether a partition exists in collection
+    // Obviously we do not have partition "tag" now
+    HasPartitionResponse testHasPartition = client.hasPartition(collectionName, "tag");
+    if (testHasPartition.ok() && testHasPartition.hasPartition()) {
+      throw new AssertionError("Wrong results!");
+    }
+
     // Search vectors
     // Searching the first 5 vectors of the vectors we just inserted
     final int searchBatchSize = 5;
@@ -175,6 +190,27 @@ public class MilvusClientExample {
     List<List<Long>> resultIds = searchResponse.getResultIdsList();
     List<List<Float>> resultDistances = searchResponse.getResultDistancesList();
 
+    // SearchByIds
+    // Searching the first 5 vectors of the vectors we just inserted by ID
+    SearchByIdsParam searchByIdsParam =
+            new SearchByIdsParam.Builder(collectionName)
+                    .withIDs(vectorIds.subList(0, searchBatchSize))
+                    .withTopK(topK)
+                    .withParamsInJson(searchParamsJson.toString())
+                    .build();
+    SearchResponse searchByIDResponse = client.searchByIds(searchByIdsParam);
+    if (searchByIDResponse.ok()) {
+      List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+      final double epsilon = 0.001;
+      for (int i = 0; i < searchBatchSize; i++) {
+        SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+        if (firstQueryResult.getVectorId() != vectorIds.get(i)
+                || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
+          throw new AssertionError("Wrong results!");
+        }
+      }
+    }
+
     // You can send search request asynchronously, which returns a ListenableFuture object
     ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
     try {
@@ -193,10 +229,10 @@ public class MilvusClientExample {
     flushResponse = client.flush(collectionName);
 
     // Try to get the corresponding vector of the first id you just deleted.
-    GetVectorByIdResponse getVectorByIdResponse =
-        client.getVectorById(collectionName, vectorIds.get(0));
+    GetVectorsByIdsResponse getVectorsByIdsResponse =
+        client.getVectorsByIds(collectionName, vectorIds.subList(0, searchBatchSize));
     // Obviously you won't get anything
-    if (getVectorByIdResponse.exists()) {
+    if (!getVectorsByIdsResponse.getFloatVectors().get(0).isEmpty()) {
       throw new AssertionError("This can never happen!");
     }
 

+ 13 - 1
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.7.0</version>
+    <version>0.8.0-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>
@@ -47,6 +47,12 @@
             <organization>Milvus</organization>
             <organizationUrl>http://www.milvus.io</organizationUrl>
         </developer>
+        <developer>
+            <name>Xiaohai Xu</name>
+            <email>xiaohai.xu@zilliz.com</email>
+            <organization>Milvus</organization>
+            <organizationUrl>http://www.milvus.io</organizationUrl>
+        </developer>
     </developers>
 
     <scm>
@@ -139,6 +145,12 @@
             <artifactId>commons-collections4</artifactId>
             <version>4.4</version>
         </dependency>
+        <dependency>
+            <groupId>org.json</groupId>
+            <artifactId>json</artifactId>
+            <version>20190722</version>
+        </dependency>
+
     </dependencies>
 
     <profiles>

+ 0 - 59
src/main/java/io/milvus/client/GetVectorByIdResponse.java

@@ -1,59 +0,0 @@
-package io.milvus.client;
-
-import java.nio.ByteBuffer;
-import java.util.List;
-import java.util.Optional;
-
-/**
- * Contains the returned <code>response</code> and either a <code>floatVector</code> or a <code>
- * binaryVector</code> for <code>getVectorById</code>. If the id does not exist, both returned
- * vectors will be empty.
- */
-public class GetVectorByIdResponse {
-  private final Response response;
-  private final List<Float> floatVector;
-  private final ByteBuffer binaryVector;
-
-  GetVectorByIdResponse(Response response, List<Float> floatVector, ByteBuffer binaryVector) {
-    this.response = response;
-    this.floatVector = floatVector;
-    this.binaryVector = binaryVector;
-  }
-
-  public List<Float> getFloatVector() {
-    return floatVector;
-  }
-
-  /**
-   * @return an <code>Optional</code> object which may or may not contain a <code>ByteBuffer</code>
-   *     object
-   * @see Optional
-   */
-  public Optional<ByteBuffer> getBinaryVector() {
-    return Optional.ofNullable(binaryVector);
-  }
-
-  /** @return <code>true</code> if the id corresponds to a float vector */
-  public boolean isFloatVector() {
-    return !floatVector.isEmpty();
-  }
-
-  /** @return <code>true</code> if the id corresponds to a binary vector */
-  public boolean isBinaryVector() {
-    return binaryVector != null && binaryVector.hasRemaining();
-  }
-
-  /** @return <code>true</code> if the id's corresponding vector exists */
-  public boolean exists() {
-    return isFloatVector() || isBinaryVector();
-  }
-
-  public Response getResponse() {
-    return response;
-  }
-
-  /** @return <code>true</code> if the response status equals SUCCESS */
-  public boolean ok() {
-    return response.ok();
-  }
-}

+ 43 - 0
src/main/java/io/milvus/client/GetVectorsByIdsResponse.java

@@ -0,0 +1,43 @@
+package io.milvus.client;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Contains the returned <code>response</code> and either a <code>List</code> of <code>floatVectors</code> or <code>
+ * binaryVectors</code> for <code>getVectorsByIds</code>. If the id does not exist, both float and binary
+ * vectors corresponding to the id will be empty.
+ */
+public class GetVectorsByIdsResponse {
+  private final Response response;
+  private final List<List<Float>> floatVectors;
+  private final List<ByteBuffer> binaryVectors;
+
+  GetVectorsByIdsResponse(Response response, List<List<Float>> floatVectors, List<ByteBuffer> binaryVectors) {
+    this.response = response;
+    this.floatVectors = floatVectors;
+    this.binaryVectors = binaryVectors;
+  }
+
+  public List<List<Float>> getFloatVectors() {
+    return floatVectors;
+  }
+
+  /**
+   * @return a <code>List</code> of <code>ByteBuffer</code> object
+   */
+  public List<ByteBuffer> getBinaryVectors() {
+    return binaryVectors;
+  }
+
+  public Response getResponse() {
+    return response;
+  }
+
+  /** @return <code>true</code> if the response status equals SUCCESS */
+  public boolean ok() {
+    return response.ok();
+  }
+
+}

+ 53 - 0
src/main/java/io/milvus/client/HasPartitionResponse.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.client;
+
+/**
+ * Contains the returned <code>response</code> and <code>hasPartition</code> for <code>
+ * hasPartition</code>
+ */
+public class HasPartitionResponse {
+    private final Response response;
+    private final boolean hasPartition;
+
+    HasPartitionResponse(Response response, boolean hasPartition) {
+        this.response = response;
+        this.hasPartition = hasPartition;
+    }
+
+    public boolean hasPartition() {
+        return hasPartition;
+    }
+
+    public Response getResponse() {
+        return response;
+    }
+
+    /** @return <code>true</code> if the response status equals SUCCESS */
+    public boolean ok() {
+        return response.ok();
+    }
+
+    @Override
+    public String toString() {
+        return String.format(
+                "HasPartitionResponse {%s, has partition = %s}", response.toString(), hasPartition);
+    }
+}

+ 45 - 14
src/main/java/io/milvus/client/MilvusClient.java

@@ -26,9 +26,9 @@ import java.util.List;
 /** The Milvus Client Interface */
 public interface MilvusClient {
 
-  String clientVersion = "0.7.0";
+  String clientVersion = "0.8.0";
 
-  /** @return current Milvus client version: 0.7.0 */
+  /** @return current Milvus client version: 0.8.0 */
   default String getClientVersion() {
     return clientVersion;
   }
@@ -166,6 +166,16 @@ public interface MilvusClient {
    */
   Response createPartition(String collectionName, String tag);
 
+  /**
+   * Checks whether the partition exists
+   *
+   * @param collectionName collection name
+   * @param tag partition tag
+   * @return <code>HasPartitionResponse</code>
+   * @see Response
+   */
+  HasPartitionResponse hasPartition(String collectionName, String tag);
+
   /**
    * Shows current partitions of a collection
    *
@@ -254,6 +264,30 @@ public interface MilvusClient {
    */
   SearchResponse search(SearchParam searchParam);
 
+  /**
+   * Searches vectors specified by <code>searchByIdsParam</code>
+   *
+   * @param searchByIdsParam the <code>SearchByIdsParam</code> object
+   *     <pre>
+   * example usage:
+   * <code>
+   * SearchByIdsParam searchByIdsParam = new SearchByIdsParam.Builder(collectionName)
+   *                                          .withIDs(ids)
+   *                                          .withTopK(topK)
+   *                                          .withPartitionTags(partitionTagsList)
+   *                                          .withParamsInJson("{\"nprobe\": 20}")
+   *                                          .build();
+   * </code>
+   * </pre>
+   *
+   * @return <code>SearchResponse</code>
+   * @see SearchByIdsParam
+   * @see SearchResponse
+   * @see SearchResponse.QueryResult
+   * @see Response
+   */
+  SearchResponse searchByIds(SearchByIdsParam searchByIdsParam);
+
   /**
    * Searches vectors specified by <code>searchParam</code> asynchronously
    *
@@ -389,27 +423,24 @@ public interface MilvusClient {
    * Shows collection information. A collection consists of one or multiple partitions (including
    * the default partition), and a partitions consists of one or more segments. Each partition or
    * segment can be uniquely identified by its partition tag or segment name respectively.
+   * The result will be returned as JSON string.
    *
    * @param collectionName collection to show info from
-   * @return <code>ShowCollectionInfoResponse</code>
-   * @see ShowCollectionInfoResponse
-   * @see CollectionInfo
-   * @see CollectionInfo.PartitionInfo
-   * @see CollectionInfo.PartitionInfo.SegmentInfo
+   * @return <code>Response</code>
    * @see Response
    */
-  ShowCollectionInfoResponse showCollectionInfo(String collectionName);
+  Response showCollectionInfo(String collectionName);
 
   /**
-   * Gets either a float or binary vector by id.
+   * Gets vectors data by id array
    *
-   * @param collectionName collection to get vector from
-   * @param id vector id
-   * @return <code>GetVectorByIdResponse</code>
-   * @see GetVectorByIdResponse
+   * @param collectionName collection to get vectors from
+   * @param ids a <code>List</code> of vector ids
+   * @return <code>GetVectorsByIdsResponse</code>
+   * @see GetVectorsByIdsResponse
    * @see Response
    */
-  GetVectorByIdResponse getVectorById(String collectionName, Long id);
+  GetVectorsByIdsResponse getVectorsByIds(String collectionName, List<Long> ids);
 
   /**
    * Gets all vector ids in a segment

+ 130 - 59
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -367,6 +367,40 @@ public class MilvusGrpcClient implements MilvusClient {
     }
   }
 
+  @Override
+  public HasPartitionResponse hasPartition(String collectionName, String tag) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      return new HasPartitionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
+    }
+
+    PartitionParam request =
+        PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();
+    BoolReply response;
+
+    try {
+      response = blockingStub.hasPartition(request);
+
+      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
+        logInfo("hasPartition with tag `{0}` in `{1}` = {2}", tag, collectionName, response.getBoolReply());
+        return new HasPartitionResponse(
+                new Response(Response.Status.SUCCESS), response.getBoolReply());
+      } else {
+        logSevere("hasPartition with tag `{0}` in `{1}` failed:\n{2}", tag, collectionName, response.toString());
+        return new HasPartitionResponse(
+                new Response(
+                        Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
+                        response.getStatus().getReason()),
+                false);
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("hasPartition RPC failed:\n{0}", e.getStatus().toString());
+      return new HasPartitionResponse(
+              new Response(Response.Status.RPC_ERROR, e.toString()), false);
+    }
+  }
+
   @Override
   public ShowPartitionsResponse showPartitions(String collectionName) {
 
@@ -599,6 +633,62 @@ public class MilvusGrpcClient implements MilvusClient {
     }
   }
 
+  @Override
+  public SearchResponse searchByIds(@Nonnull SearchByIdsParam searchByIdsParam) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
+      return searchResponse;
+    }
+
+    List<Long> idList = searchByIdsParam.getIds();
+
+    KeyValuePair extraParam =
+            KeyValuePair.newBuilder()
+                    .setKey(extraParamKey)
+                    .setValue(searchByIdsParam.getParamsInJson())
+                    .build();
+
+    io.milvus.grpc.SearchByIDParam request =
+            io.milvus.grpc.SearchByIDParam.newBuilder()
+                    .setCollectionName(searchByIdsParam.getCollectionName())
+                    .addAllIdArray(idList)
+                    .addAllPartitionTagArray(searchByIdsParam.getPartitionTags())
+                    .setTopk(searchByIdsParam.getTopK())
+                    .addExtraParams(extraParam)
+                    .build();
+
+    TopKQueryResult response;
+
+    try {
+      response = blockingStub.searchByID(request);
+
+      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
+        SearchResponse searchResponse = buildSearchResponse(response);
+        searchResponse.setResponse(new Response(Response.Status.SUCCESS));
+        logInfo(
+                "Search by ids completed successfully! Returned results for {0} queries",
+                searchResponse.getNumQueries());
+        return searchResponse;
+      } else {
+        logSevere("Search by ids failed:\n{0}", response.getStatus().toString());
+        SearchResponse searchResponse = new SearchResponse();
+        searchResponse.setResponse(
+                new Response(
+                        Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
+                        response.getStatus().getReason()));
+        return searchResponse;
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("search by ids RPC failed:\n{0}", e.getStatus().toString());
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
+      return searchResponse;
+    }
+  }
+
   @Override
   public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
 
@@ -988,11 +1078,10 @@ public class MilvusGrpcClient implements MilvusClient {
   }
 
   @Override
-  public ShowCollectionInfoResponse showCollectionInfo(String collectionName) {
+  public Response showCollectionInfo(String collectionName) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      return new ShowCollectionInfoResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
+      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
 
     CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
@@ -1002,82 +1091,56 @@ public class MilvusGrpcClient implements MilvusClient {
       response = blockingStub.showCollectionInfo(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-
-        List<CollectionInfo.PartitionInfo> partitionInfos = new ArrayList<>();
-
-        for (PartitionStat partitionStat : response.getPartitionsStatList()) {
-
-          List<CollectionInfo.PartitionInfo.SegmentInfo> segmentInfos = new ArrayList<>();
-
-          for (SegmentStat segmentStat : partitionStat.getSegmentsStatList()) {
-
-            CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-                new CollectionInfo.PartitionInfo.SegmentInfo(
-                    segmentStat.getSegmentName(),
-                    segmentStat.getRowCount(),
-                    segmentStat.getIndexName(),
-                    segmentStat.getDataSize());
-            segmentInfos.add(segmentInfo);
-          }
-
-          CollectionInfo.PartitionInfo partitionInfo =
-              new CollectionInfo.PartitionInfo(
-                  partitionStat.getTag(), partitionStat.getTotalRowCount(), segmentInfos);
-          partitionInfos.add(partitionInfo);
-        }
-
-        CollectionInfo collectionInfo =
-            new CollectionInfo(response.getTotalRowCount(), partitionInfos);
-
         logInfo("ShowCollectionInfo for `{0}` returned successfully!", collectionName);
-        return new ShowCollectionInfoResponse(
-            new Response(Response.Status.SUCCESS), collectionInfo);
+        return new Response(Response.Status.SUCCESS, response.getJsonInfo());
       } else {
         logSevere(
             "ShowCollectionInfo for `{0}` failed:\n{1}",
             collectionName, response.getStatus().toString());
-        return new ShowCollectionInfoResponse(
-            new Response(
+        return new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            null);
+                response.getStatus().getReason());
       }
     } catch (StatusRuntimeException e) {
-      logSevere("describeIndex RPC failed:\n{0}", e.getStatus().toString());
-      return new ShowCollectionInfoResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), null);
+      logSevere("showCollectionInfo RPC failed:\n{0}", e.getStatus().toString());
+      return new Response(Response.Status.RPC_ERROR, e.toString());
     }
   }
 
   @Override
-  public GetVectorByIdResponse getVectorById(String collectionName, Long id) {
+  public GetVectorsByIdsResponse getVectorsByIds(String collectionName, List<Long> ids) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      return new GetVectorByIdResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null);
+      return new GetVectorsByIdsResponse(
+              new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null);
     }
 
-    VectorIdentity request =
-        VectorIdentity.newBuilder().setCollectionName(collectionName).setId(id).build();
-    VectorData response;
+    VectorsIdentity request =
+        VectorsIdentity.newBuilder().setCollectionName(collectionName).addAllIdArray(ids).build();
+    VectorsData response;
 
     try {
-      response = blockingStub.getVectorByID(request);
+      response = blockingStub.getVectorsByID(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
         logInfo(
-            "getVectorById for id={0} in collection `{1}` returned successfully!",
-            String.valueOf(id), collectionName);
-        return new GetVectorByIdResponse(
-            new Response(Response.Status.SUCCESS),
-            response.getVectorData().getFloatDataList(),
-            response.getVectorData().getBinaryData().asReadOnlyByteBuffer());
+            "getVectorsByIds in collection `{0}` returned successfully!", collectionName);
+
+        List<List<Float>> floatVectors = new ArrayList<>();
+        List<ByteBuffer> binaryVectors = new ArrayList<>();
+        for (int i = 0; i < ids.size(); i++) {
+          floatVectors.add(response.getVectorsData(i).getFloatDataList());
+          binaryVectors.add(response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer());
+        }
+        return new GetVectorsByIdsResponse(
+                new Response(Response.Status.SUCCESS), floatVectors, binaryVectors);
+
       } else {
         logSevere(
-            "getVectorById for `{0}` in collection `{1}` failed:\n{2}",
-            String.valueOf(id), collectionName, response.getStatus().toString());
-        return new GetVectorByIdResponse(
+            "getVectorsByIds in collection `{0}` failed:\n{1}",
+            collectionName, response.getStatus().toString());
+        return new GetVectorsByIdsResponse(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
@@ -1085,8 +1148,8 @@ public class MilvusGrpcClient implements MilvusClient {
             null);
       }
     } catch (StatusRuntimeException e) {
-      logSevere("getVectorById RPC failed:\n{0}", e.getStatus().toString());
-      return new GetVectorByIdResponse(
+      logSevere("getVectorsByIds RPC failed:\n{0}", e.getStatus().toString());
+      return new GetVectorsByIdsResponse(
           new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null);
     }
   }
@@ -1369,9 +1432,17 @@ public class MilvusGrpcClient implements MilvusClient {
 
     List<List<Long>> resultIdsList = new ArrayList<>();
     List<List<Float>> resultDistancesList = new ArrayList<>();
+
     if (topK > 0) {
-      resultIdsList = ListUtils.partition(topKQueryResult.getIdsList(), topK);
-      resultDistancesList = ListUtils.partition(topKQueryResult.getDistancesList(), topK);
+      for (int i = 0; i < numQueries; i++) {
+        // Process result of query i
+        int pos = i * topK;
+        while (pos < i * topK + topK && topKQueryResult.getIdsList().get(pos) != -1) {
+          pos++;
+        }
+        resultIdsList.add(topKQueryResult.getIdsList().subList(i * topK, pos));
+        resultDistancesList.add(topKQueryResult.getDistancesList().subList(i * topK, pos));
+      }
     }
 
     SearchResponse searchResponse = new SearchResponse();

+ 145 - 0
src/main/java/io/milvus/client/SearchByIdsParam.java

@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.client;
+
+import javax.annotation.Nonnull;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Contains parameters for <code>searchByIds</code> */
+public class SearchByIdsParam {
+
+    private final String collectionName;
+    private final List<String> partitionTags;
+    private final List<Long> ids;
+    private final long topK;
+    private final String paramsInJson;
+
+    private SearchByIdsParam(@Nonnull Builder builder) {
+        this.collectionName = builder.collectionName;
+        this.partitionTags = builder.partitionTags;
+        this.ids = builder.ids;
+        this.topK = builder.topK;
+        this.paramsInJson = builder.paramsInJson;
+    }
+
+    public String getCollectionName() {
+        return collectionName;
+    }
+
+    public List<String> getPartitionTags() {
+        return partitionTags;
+    }
+
+    public List<Long> getIds() {
+        return ids;
+    }
+
+    public long getTopK() {
+        return topK;
+    }
+
+    public String getParamsInJson() {
+        return paramsInJson;
+    }
+
+    /** Builder for <code>SearchByIdsParam</code> */
+    public static class Builder {
+        // Required parameters
+        private final String collectionName;
+
+        // Optional parameters - initialized to default values
+        private List<String> partitionTags = new ArrayList<>();
+        private List<Long> ids = new ArrayList<>();
+        private long topK = 1024;
+        private String paramsInJson;
+
+        /** @param collectionName collection to search from */
+        public Builder(@Nonnull String collectionName) {
+            this.collectionName = collectionName;
+        }
+
+        /**
+         * Search vectors IDs. Default to an empty <code>List</code>
+         *
+         * @param ids IDs of vectors
+         * @return <code>Builder</code>
+         */
+        public Builder withIDs(@Nonnull List<Long> ids) {
+            this.ids = ids;
+            return this;
+        }
+
+        /**
+         * Optional. Search vectors with corresponding <code>partitionTags</code>. Default to an empty
+         * <code>List</code>
+         *
+         * @param partitionTags a <code>List</code> of partition tags
+         * @return <code>Builder</code>
+         */
+        public Builder withPartitionTags(@Nonnull List<String> partitionTags) {
+            this.partitionTags = partitionTags;
+            return this;
+        }
+
+        /**
+         * Optional. Limits search result to <code>topK</code>. Default to 1024.
+         *
+         * @param topK a topK number
+         * @return <code>Builder</code>
+         */
+        public Builder withTopK(long topK) {
+            this.topK = topK;
+            return this;
+        }
+
+        /**
+         * Optional. Default to empty <code>String</code>. Search parameters are different for different
+         * index types. Refer to <a
+         * href="https://milvus.io/docs/v0.8.0/guides/milvus_operation.md">https://milvus.io/docs/v0.8.0/guides/milvus_operation.md</a>
+         * for more information.
+         *
+         * <pre>
+         *   FLAT/IVFLAT/SQ8/IVFPQ: {"nprobe": 32}
+         *   nprobe range:[1,999999]
+         *
+         *   NSG: {"search_length": 100}
+         *   search_length range:[10, 300]
+         *
+         *   HNSW: {"ef": 64}
+         *   ef range:[topk, 4096]
+         *
+         *   ANNOY: {search_k", 0.05 * totalDataCount}
+         *   search_k range: none
+         * </pre>
+         *
+         * @param paramsInJson extra parameters in JSON format
+         * @return <code>Builder</code>
+         */
+        public SearchByIdsParam.Builder withParamsInJson(@Nonnull String paramsInJson) {
+            this.paramsInJson = paramsInJson;
+            return this;
+        }
+
+        public SearchByIdsParam build() {
+            return new SearchByIdsParam(this);
+        }
+    }
+}

+ 1 - 1
src/main/java/io/milvus/client/SearchResponse.java

@@ -60,7 +60,7 @@ public class SearchResponse {
     return IntStream.range(0, numQueries)
         .mapToObj(
             i ->
-                LongStream.range(0, topK)
+                LongStream.range(0, resultIdsList.get(i).size())
                     .mapToObj(
                         j ->
                             new QueryResult(

+ 20 - 31
src/main/proto/milvus.proto

@@ -112,7 +112,7 @@ message SearchInFilesParam {
 message SearchByIDParam {
     string collection_name = 1;
     repeated string partition_tag_array = 2;
-    int64 id = 3;
+    repeated int64 id_array = 3;
     int64 topk = 4;
     repeated KeyValuePair extra_params = 5;
 }
@@ -184,48 +184,28 @@ message DeleteByIDParam {
     repeated int64 id_array = 2;
 }
 
-/**
- * @brief segment statistics
- */
-message SegmentStat {
-    string segment_name = 1;
-    int64 row_count = 2;
-    string index_name = 3;
-    int64 data_size = 4;
-}
-
-/**
- * @brief collection statistics
- */
-message PartitionStat {
-    string tag = 1;
-    int64 total_row_count = 2;
-    repeated SegmentStat segments_stat = 3;
-}
-
 /**
  * @brief collection information
  */
 message CollectionInfo {
     Status status = 1;
-    int64 total_row_count = 2;
-    repeated PartitionStat partitions_stat = 3;
+    string json_info = 2;
 }
 
 /**
- * @brief vector identity
+ * @brief vectors identity
  */
-message VectorIdentity {
+message VectorsIdentity {
     string collection_name = 1;
-    int64 id = 2;
+    repeated int64 id_array = 2;
 }
 
 /**
  * @brief vector data
  */
-message VectorData {
+message VectorsData {
     Status status = 1;
-    RowRecord vector_data = 2;
+    repeated RowRecord vectors_data = 2;
 }
 
 /**
@@ -336,6 +316,15 @@ service MilvusService {
      */
     rpc CreatePartition(PartitionParam) returns (Status) {}
 
+    /**
+     * @brief This method is used to test partition existence.
+     *
+     * @param PartitionParam, target partition.
+     *
+     * @return BoolReply
+     */
+    rpc HasPartition(PartitionParam) returns (BoolReply) {}
+
     /**
      * @brief This method is used to show partition information
      *
@@ -364,13 +353,13 @@ service MilvusService {
     rpc Insert(InsertParam) returns (VectorIds) {}
 
     /**
-     * @brief This method is used to get vector data by id.
+     * @brief This method is used to get vectors data by id array.
      *
-     * @param VectorIdentity, target vector id.
+     * @param VectorsIdentity, target vector id array.
      *
-     * @return VectorData
+     * @return VectorsData
      */
-    rpc GetVectorByID(VectorIdentity) returns (VectorData) {}
+    rpc GetVectorsByID(VectorsIdentity) returns (VectorsData) {}
 
     /**
      * @brief This method is used to get vector ids from a segment

+ 116 - 77
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -21,6 +21,7 @@ package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
 import org.apache.commons.text.RandomStringGenerator;
+import org.json.*;
 
 import java.nio.ByteBuffer;
 import java.util.*;
@@ -268,9 +269,15 @@ class MilvusClientTest {
 
     assertTrue(Collections.disjoint(resultIdsList1, resultIdsList2));
 
+    HasPartitionResponse testHasPartition = client.hasPartition(randomCollectionName, tag1);
+    assertTrue(testHasPartition.hasPartition());
+
     Response dropPartitionResponse = client.dropPartition(randomCollectionName, tag1);
     assertTrue(dropPartitionResponse.ok());
 
+    testHasPartition = client.hasPartition(randomCollectionName, tag1);
+    assertFalse(testHasPartition.hasPartition());
+
     dropPartitionResponse = client.dropPartition(randomCollectionName, tag2);
     assertTrue(dropPartitionResponse.ok());
   }
@@ -379,6 +386,7 @@ class MilvusClientTest {
     assertEquals(searchSize, resultDistancesList.size());
     List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
     assertEquals(searchSize, queryResultsList.size());
+
     final double epsilon = 0.001;
     for (int i = 0; i < searchSize; i++) {
       SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
@@ -389,6 +397,45 @@ class MilvusClientTest {
     }
   }
 
+  @org.junit.jupiter.api.Test
+  void searchByIds() {
+    List<List<Float>> vectors = generateFloatVectors(size, dimension);
+    vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
+    InsertParam insertParam =
+            new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
+    InsertResponse insertResponse = client.insert(insertParam);
+    assertTrue(insertResponse.ok());
+    List<Long> vectorIds = insertResponse.getVectorIds();
+    assertEquals(size, vectorIds.size());
+
+    assertTrue(client.flush(randomCollectionName).ok());
+
+    final long topK = 10;
+    final int queryLength = 5;
+    SearchByIdsParam searchByIdsParam =
+            new SearchByIdsParam.Builder(randomCollectionName)
+                    .withIDs(vectorIds.subList(0, queryLength))
+                    .withTopK(topK)
+                    .withParamsInJson("{\"nprobe\": 20}")
+                    .build();
+    SearchResponse searchResponse = client.searchByIds(searchByIdsParam);
+    assertTrue(searchResponse.ok());
+    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+    assertEquals(queryLength, resultIdsList.size());
+    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    assertEquals(queryLength, resultDistancesList.size());
+    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    assertEquals(queryLength, queryResultsList.size());
+    final double epsilon = 0.001;
+    for (int i = 0; i < queryLength; i++) {
+      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      assertEquals(vectorIds.get(i), firstQueryResult.getVectorId());
+      assertEquals(vectorIds.get(i), resultIdsList.get(i).get(0));
+      assertTrue(Math.abs(1 - firstQueryResult.getDistance()) < epsilon);
+      assertTrue(Math.abs(1 - resultDistancesList.get(i).get(0)) < epsilon);
+    }
+  }
+
   @org.junit.jupiter.api.Test
   void searchAsync() throws ExecutionException, InterruptedException {
     List<List<Float>> vectors = generateFloatVectors(size, dimension);
@@ -506,7 +553,7 @@ class MilvusClientTest {
   void showCollections() {
     ShowCollectionsResponse showCollectionsResponse = client.showCollections();
     assertTrue(showCollectionsResponse.ok());
-    assertEquals(showCollectionsResponse.getCollectionNames().get(0), randomCollectionName);
+    assertTrue(showCollectionsResponse.getCollectionNames().contains(randomCollectionName));
   }
 
   @org.junit.jupiter.api.Test
@@ -564,25 +611,27 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
+    Response showCollectionInfoResponse =
         client.showCollectionInfo(randomCollectionName);
     assertTrue(showCollectionInfoResponse.ok());
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
 
-    CollectionInfo collectionInfo = showCollectionInfoResponse.getCollectionInfo().get();
-    assertEquals(collectionInfo.getRowCount(), size);
+    String jsonString = showCollectionInfoResponse.getMessage();
+    JSONObject jsonInfo = new JSONObject(jsonString);
+    assertTrue(jsonInfo.getInt("row_count") == size);
 
-    CollectionInfo.PartitionInfo partitionInfo = collectionInfo.getPartitionInfos().get(0);
-    assertEquals(partitionInfo.getTag(), "_default");
-    assertEquals(partitionInfo.getRowCount(), size);
+    JSONArray partitions = jsonInfo.getJSONArray("partitions");
+    JSONObject partitionInfo = partitions.getJSONObject(0);
+    assertEquals(partitionInfo.getString("tag"), "_default");
+    assertEquals(partitionInfo.getInt("row_count"), size);
 
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo = partitionInfo.getSegmentInfos().get(0);
-    assertEquals(segmentInfo.getRowCount(), size);
-    assertEquals(segmentInfo.getIndexName(), "IDMAP");
+    JSONArray segments = partitionInfo.getJSONArray("segments");
+    JSONObject segmentInfo = segments.getJSONObject(0);
+    assertEquals(segmentInfo.getString("index_name"), "IDMAP");
+    assertEquals(segmentInfo.getInt("row_count"), size);
   }
 
   @org.junit.jupiter.api.Test
-  void getVectorById() {
+  void getVectorsByIds() {
     List<List<Float>> vectors = generateFloatVectors(size, dimension);
     InsertParam insertParam =
         new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
@@ -593,13 +642,13 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    GetVectorByIdResponse getVectorByIdResponse =
-        client.getVectorById(randomCollectionName, vectorIds.get(0));
-    assertTrue(getVectorByIdResponse.ok());
-    assertTrue(getVectorByIdResponse.exists());
-    assertTrue(getVectorByIdResponse.isFloatVector());
-    assertFalse(getVectorByIdResponse.isBinaryVector());
-    assertArrayEquals(getVectorByIdResponse.getFloatVector().toArray(), vectors.get(0).toArray());
+    GetVectorsByIdsResponse getVectorsByIdsResponse =
+        client.getVectorsByIds(randomCollectionName, vectorIds.subList(0, 100));
+    assertTrue(getVectorsByIdsResponse.ok());
+    ByteBuffer bb = getVectorsByIdsResponse.getBinaryVectors().get(0);
+    assertTrue(bb == null || bb.remaining() == 0);
+
+    assertArrayEquals(getVectorsByIdsResponse.getFloatVectors().get(0).toArray(), vectors.get(0).toArray());
   }
 
   @org.junit.jupiter.api.Test
@@ -608,21 +657,19 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
+    Response showCollectionInfoResponse =
         client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
+    assertTrue(showCollectionInfoResponse.ok());
 
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
+    JSONObject jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    JSONObject segmentInfo = jsonInfo
+                                 .getJSONArray("partitions")
+                                 .getJSONObject(0)
+                                 .getJSONArray("segments")
+                                 .getJSONObject(0);
 
     GetVectorIdsResponse getVectorIdsResponse =
-        client.getVectorIds(randomCollectionName, segmentInfo.getSegmentName());
+        client.getVectorIds(randomCollectionName,segmentInfo.getString("name"));
     assertTrue(getVectorIdsResponse.ok());
     assertFalse(getVectorIdsResponse.getIds().isEmpty());
   }
@@ -684,36 +731,32 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
+    Response showCollectionInfoResponse =
         client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long previousSegmentSize = segmentInfo.getDataSize();
+    assertTrue(showCollectionInfoResponse.ok());
 
-    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    JSONObject jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    JSONObject segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
 
-    assertTrue(client.flush(randomCollectionName).ok());
+    long previousSegmentSize = segmentInfo.getLong("data_size");
 
+    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.compact(randomCollectionName).ok());
 
     showCollectionInfoResponse = client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long currentSegmentSize = segmentInfo.getDataSize();
+    assertTrue(showCollectionInfoResponse.ok());
+    jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
+    long currentSegmentSize = segmentInfo.getLong("data_size");
 
     assertTrue(currentSegmentSize < previousSegmentSize);
   }
@@ -730,36 +773,32 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
-        client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long previousSegmentSize = segmentInfo.getDataSize();
+    Response showCollectionInfoResponse =
+            client.showCollectionInfo(randomCollectionName);
+    assertTrue(showCollectionInfoResponse.ok());
 
-    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    JSONObject jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    JSONObject segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
 
-    assertTrue(client.flush(randomCollectionName).ok());
+    long previousSegmentSize = segmentInfo.getLong("data_size");
 
+    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.compactAsync(randomCollectionName).get().ok());
 
     showCollectionInfoResponse = client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long currentSegmentSize = segmentInfo.getDataSize();
+    assertTrue(showCollectionInfoResponse.ok());
+    jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
+    long currentSegmentSize = segmentInfo.getLong("data_size");
 
     assertTrue(currentSegmentSize < previousSegmentSize);
   }