ソースを参照

Add new APIs

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 5 年 前
コミット
a39cac7b7f

+ 1 - 1
examples/pom.xml

@@ -63,7 +63,7 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>0.7.0</version>
+            <version>0.8.0-SNAPSHOT</version>
         </dependency>
         <dependency>
             <groupId>com.google.code.gson</groupId>

+ 41 - 3
examples/src/main/java/MilvusClientExample.java

@@ -140,6 +140,21 @@ public class MilvusClientExample {
     // Describe the index for your collection
     DescribeIndexResponse describeIndexResponse = client.describeIndex(collectionName);
 
+    // Get collection info
+    Response showCollectionInfoResponse = client.showCollectionInfo(collectionName);
+    if (showCollectionInfoResponse.ok()) {
+      // Collection info is sent back with JSON type string
+      String jsonString = showCollectionInfoResponse.getMessage();
+      System.out.println(jsonString);
+    }
+
+    // Check whether a partition exists in collection
+    // Obviously we do not have partition "tag" now
+    HasPartitionResponse testHasPartition = client.hasPartition(collectionName, "tag");
+    if (testHasPartition.ok() && testHasPartition.hasPartition()) {
+      throw new AssertionError("Wrong results!");
+    }
+
     // Search vectors
     // Searching the first 5 vectors of the vectors we just inserted
     final int searchBatchSize = 5;
@@ -175,6 +190,29 @@ public class MilvusClientExample {
     List<List<Long>> resultIds = searchResponse.getResultIdsList();
     List<List<Float>> resultDistances = searchResponse.getResultDistancesList();
 
+    // SearchByIDs
+    // Searching the first 5 vectors of the vectors we just inserted by ID
+    SearchByIDParam searchByIDParam =
+            new SearchByIDParam.Builder(collectionName)
+                    .withIDs(vectorIds.subList(0, searchBatchSize))
+                    .withTopK(topK)
+                    .withParamsInJson(searchParamsJson.toString())
+                    .build();
+    SearchResponse searchByIDResponse = client.searchByID(searchByIDParam);
+    if (searchByIDResponse.ok()) {
+      List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+      List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+      List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+      final double epsilon = 0.001;
+      for (int i = 0; i < searchBatchSize; i++) {
+        SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+        if (firstQueryResult.getVectorId() != vectorIds.get(i)
+                || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
+          throw new AssertionError("Wrong results!");
+        }
+      }
+    }
+
     // You can send search request asynchronously, which returns a ListenableFuture object
     ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
     try {
@@ -193,10 +231,10 @@ public class MilvusClientExample {
     flushResponse = client.flush(collectionName);
 
     // Try to get the corresponding vector of the first id you just deleted.
-    GetVectorByIdResponse getVectorByIdResponse =
-        client.getVectorById(collectionName, vectorIds.get(0));
+    List<GetVectorByIdResponse> getVectorByIdResponse =
+        client.getVectorsById(collectionName, vectorIds.subList(0, searchBatchSize));
     // Obviously you won't get anything
-    if (getVectorByIdResponse.exists()) {
+    if (getVectorByIdResponse.get(0).exists()) {
       throw new AssertionError("This can never happen!");
     }
 

+ 13 - 1
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.7.0</version>
+    <version>0.8.0-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>
@@ -47,6 +47,12 @@
             <organization>Milvus</organization>
             <organizationUrl>http://www.milvus.io</organizationUrl>
         </developer>
+        <developer>
+            <name>Xiaohai Xu</name>
+            <email>xiaohai.xu@zilliz.com</email>
+            <organization>Milvus</organization>
+            <organizationUrl>http://www.milvus.io</organizationUrl>
+        </developer>
     </developers>
 
     <scm>
@@ -139,6 +145,12 @@
             <artifactId>commons-collections4</artifactId>
             <version>4.4</version>
         </dependency>
+        <dependency>
+            <groupId>org.json</groupId>
+            <artifactId>json</artifactId>
+            <version>20190722</version>
+        </dependency>
+
     </dependencies>
 
     <profiles>

+ 1 - 1
src/main/java/io/milvus/client/GetVectorByIdResponse.java

@@ -6,7 +6,7 @@ import java.util.Optional;
 
 /**
  * Contains the returned <code>response</code> and either a <code>floatVector</code> or a <code>
- * binaryVector</code> for <code>getVectorById</code>. If the id does not exist, both returned
+ * binaryVector</code> for each vector of <code>getVectorsById</code>. If the id does not exist, both returned
  * vectors will be empty.
  */
 public class GetVectorByIdResponse {

+ 53 - 0
src/main/java/io/milvus/client/HasPartitionResponse.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.client;
+
+/**
+ * Contains the returned <code>response</code> and <code>hasPartition</code> for <code>
+ * hasPartition</code>
+ */
+public class HasPartitionResponse {
+    private final Response response;
+    private final boolean hasPartition;
+
+    HasPartitionResponse(Response response, boolean hasPartition) {
+        this.response = response;
+        this.hasPartition = hasPartition;
+    }
+
+    public boolean hasPartition() {
+        return hasPartition;
+    }
+
+    public Response getResponse() {
+        return response;
+    }
+
+    /** @return <code>true</code> if the response status equals SUCCESS */
+    public boolean ok() {
+        return response.ok();
+    }
+
+    @Override
+    public String toString() {
+        return String.format(
+                "HasPartitionResponse {%s, has partition = %s}", response.toString(), hasPartition);
+    }
+}

+ 44 - 13
src/main/java/io/milvus/client/MilvusClient.java

@@ -26,9 +26,9 @@ import java.util.List;
 /** The Milvus Client Interface */
 public interface MilvusClient {
 
-  String clientVersion = "0.7.0";
+  String clientVersion = "0.8.0";
 
-  /** @return current Milvus client version: 0.7.0 */
+  /** @return current Milvus client version: 0.8.0 */
   default String getClientVersion() {
     return clientVersion;
   }
@@ -166,6 +166,16 @@ public interface MilvusClient {
    */
   Response createPartition(String collectionName, String tag);
 
+  /**
+   * Checks whether the partition exists
+   *
+   * @param collectionName collection name
+   * @param tag partition tag
+   * @return <code>HasPartitionResponse</code>
+   * @see Response
+   */
+  HasPartitionResponse hasPartition(String collectionName, String tag);
+
   /**
    * Shows current partitions of a collection
    *
@@ -254,6 +264,30 @@ public interface MilvusClient {
    */
   SearchResponse search(SearchParam searchParam);
 
+  /**
+   * Searches vectors specified by <code>searchByIDParam</code>
+   *
+   * @param searchByIDParam the <code>SearchByIDParam</code> object
+   *     <pre>
+   * example usage:
+   * <code>
+   * SearchByIDParam searchByIDParam = new SearchByIDParam.Builder(collectionName)
+   *                                          .withIDs(ids)
+   *                                          .withTopK(topK)
+   *                                          .withPartitionTags(partitionTagsList)
+   *                                          .withParamsInJson("{\"nprobe\": 20}")
+   *                                          .build();
+   * </code>
+   * </pre>
+   *
+   * @return <code>SearchResponse</code>
+   * @see SearchByIDParam
+   * @see SearchResponse
+   * @see SearchResponse.QueryResult
+   * @see Response
+   */
+  SearchResponse searchByID(SearchByIDParam searchByIDParam);
+
   /**
    * Searches vectors specified by <code>searchParam</code> asynchronously
    *
@@ -389,27 +423,24 @@ public interface MilvusClient {
    * Shows collection information. A collection consists of one or multiple partitions (including
    * the default partition), and a partitions consists of one or more segments. Each partition or
    * segment can be uniquely identified by its partition tag or segment name respectively.
+   * The result will be returned as JSON string.
    *
    * @param collectionName collection to show info from
-   * @return <code>ShowCollectionInfoResponse</code>
-   * @see ShowCollectionInfoResponse
-   * @see CollectionInfo
-   * @see CollectionInfo.PartitionInfo
-   * @see CollectionInfo.PartitionInfo.SegmentInfo
+   * @return <code>Response</code>
    * @see Response
    */
-  ShowCollectionInfoResponse showCollectionInfo(String collectionName);
+  Response showCollectionInfo(String collectionName);
 
   /**
-   * Gets either a float or binary vector by id.
+   * Gets vectors data by id array
    *
-   * @param collectionName collection to get vector from
-   * @param id vector id
-   * @return <code>GetVectorByIdResponse</code>
+   * @param collectionName collection to get vectors from
+   * @param ids a <code>List</code> of vector ids
+   * @return <code>List<GetVectorByIdResponse></code>
    * @see GetVectorByIdResponse
    * @see Response
    */
-  GetVectorByIdResponse getVectorById(String collectionName, Long id);
+  List<GetVectorByIdResponse> getVectorsById(String collectionName, List<Long> ids);
 
   /**
    * Gets all vector ids in a segment

+ 125 - 59
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -367,6 +367,40 @@ public class MilvusGrpcClient implements MilvusClient {
     }
   }
 
+  @Override
+  public HasPartitionResponse hasPartition(String collectionName, String tag) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      return new HasPartitionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
+    }
+
+    PartitionParam request =
+        PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();
+    BoolReply response;
+
+    try {
+      response = blockingStub.hasPartition(request);
+
+      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
+        logInfo("hasPartition with tag `{0}` in `{1}` = {2}", tag, collectionName, response.getBoolReply());
+        return new HasPartitionResponse(
+                new Response(Response.Status.SUCCESS), response.getBoolReply());
+      } else {
+        logSevere("hasPartition with tag `{0}` in `{1}` failed:\n{2}", tag, collectionName, response.toString());
+        return new HasPartitionResponse(
+                new Response(
+                        Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
+                        response.getStatus().getReason()),
+                false);
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("hasPartition RPC failed:\n{0}", e.getStatus().toString());
+      return new HasPartitionResponse(
+              new Response(Response.Status.RPC_ERROR, e.toString()), false);
+    }
+  }
+
   @Override
   public ShowPartitionsResponse showPartitions(String collectionName) {
 
@@ -599,6 +633,62 @@ public class MilvusGrpcClient implements MilvusClient {
     }
   }
 
+  @Override
+  public SearchResponse searchByID(@Nonnull SearchByIDParam searchByIDParam) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
+      return searchResponse;
+    }
+
+    List<Long> idList = searchByIDParam.getIds();
+
+    KeyValuePair extraParam =
+            KeyValuePair.newBuilder()
+                    .setKey(extraParamKey)
+                    .setValue(searchByIDParam.getParamsInJson())
+                    .build();
+
+    io.milvus.grpc.SearchByIDParam request =
+            io.milvus.grpc.SearchByIDParam.newBuilder()
+                    .setCollectionName(searchByIDParam.getCollectionName())
+                    .addAllIdArray(idList)
+                    .addAllPartitionTagArray(searchByIDParam.getPartitionTags())
+                    .setTopk(searchByIDParam.getTopK())
+                    .addExtraParams(extraParam)
+                    .build();
+
+    TopKQueryResult response;
+
+    try {
+      response = blockingStub.searchByID(request);
+
+      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
+        SearchResponse searchResponse = buildSearchResponse(response);
+        searchResponse.setResponse(new Response(Response.Status.SUCCESS));
+        logInfo(
+                "Search completed successfully! Returned results for {0} queries",
+                searchResponse.getNumQueries());
+        return searchResponse;
+      } else {
+        logSevere("Search failed:\n{0}", response.getStatus().toString());
+        SearchResponse searchResponse = new SearchResponse();
+        searchResponse.setResponse(
+                new Response(
+                        Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
+                        response.getStatus().getReason()));
+        return searchResponse;
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("search RPC failed:\n{0}", e.getStatus().toString());
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
+      return searchResponse;
+    }
+  }
+
   @Override
   public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
 
@@ -988,11 +1078,10 @@ public class MilvusGrpcClient implements MilvusClient {
   }
 
   @Override
-  public ShowCollectionInfoResponse showCollectionInfo(String collectionName) {
+  public Response showCollectionInfo(String collectionName) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      return new ShowCollectionInfoResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
+      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
 
     CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
@@ -1002,92 +1091,69 @@ public class MilvusGrpcClient implements MilvusClient {
       response = blockingStub.showCollectionInfo(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-
-        List<CollectionInfo.PartitionInfo> partitionInfos = new ArrayList<>();
-
-        for (PartitionStat partitionStat : response.getPartitionsStatList()) {
-
-          List<CollectionInfo.PartitionInfo.SegmentInfo> segmentInfos = new ArrayList<>();
-
-          for (SegmentStat segmentStat : partitionStat.getSegmentsStatList()) {
-
-            CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-                new CollectionInfo.PartitionInfo.SegmentInfo(
-                    segmentStat.getSegmentName(),
-                    segmentStat.getRowCount(),
-                    segmentStat.getIndexName(),
-                    segmentStat.getDataSize());
-            segmentInfos.add(segmentInfo);
-          }
-
-          CollectionInfo.PartitionInfo partitionInfo =
-              new CollectionInfo.PartitionInfo(
-                  partitionStat.getTag(), partitionStat.getTotalRowCount(), segmentInfos);
-          partitionInfos.add(partitionInfo);
-        }
-
-        CollectionInfo collectionInfo =
-            new CollectionInfo(response.getTotalRowCount(), partitionInfos);
-
         logInfo("ShowCollectionInfo for `{0}` returned successfully!", collectionName);
-        return new ShowCollectionInfoResponse(
-            new Response(Response.Status.SUCCESS), collectionInfo);
+        return new Response(Response.Status.SUCCESS, response.getJsonInfo());
       } else {
         logSevere(
             "ShowCollectionInfo for `{0}` failed:\n{1}",
             collectionName, response.getStatus().toString());
-        return new ShowCollectionInfoResponse(
-            new Response(
+        return new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            null);
+                response.getStatus().getReason());
       }
     } catch (StatusRuntimeException e) {
-      logSevere("describeIndex RPC failed:\n{0}", e.getStatus().toString());
-      return new ShowCollectionInfoResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), null);
+      logSevere("showCollectionInfo RPC failed:\n{0}", e.getStatus().toString());
+      return new Response(Response.Status.RPC_ERROR, e.toString());
     }
   }
 
   @Override
-  public GetVectorByIdResponse getVectorById(String collectionName, Long id) {
+  public List<GetVectorByIdResponse> getVectorsById(String collectionName, List<Long> ids) {
+    List<GetVectorByIdResponse> res = new ArrayList<>();
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      return new GetVectorByIdResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null);
+      res.add(new GetVectorByIdResponse(
+              new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null));
+      return res;
     }
 
-    VectorIdentity request =
-        VectorIdentity.newBuilder().setCollectionName(collectionName).setId(id).build();
-    VectorData response;
+    VectorsIdentity request =
+        VectorsIdentity.newBuilder().setCollectionName(collectionName).addAllIdArray(ids).build();
+    VectorsData response;
 
     try {
-      response = blockingStub.getVectorByID(request);
+      response = blockingStub.getVectorsByID(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
         logInfo(
-            "getVectorById for id={0} in collection `{1}` returned successfully!",
-            String.valueOf(id), collectionName);
-        return new GetVectorByIdResponse(
-            new Response(Response.Status.SUCCESS),
-            response.getVectorData().getFloatDataList(),
-            response.getVectorData().getBinaryData().asReadOnlyByteBuffer());
+            "getVectorsById in collection `{0}` returned successfully!", collectionName);
+
+        for (int i = 0; i < ids.size(); i++) {
+          res.add(new GetVectorByIdResponse(
+                  new Response(Response.Status.SUCCESS),
+                  response.getVectorsData(i).getFloatDataList(),
+                  response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer()));
+        }
+        return res;
+
       } else {
         logSevere(
-            "getVectorById for `{0}` in collection `{1}` failed:\n{2}",
-            String.valueOf(id), collectionName, response.getStatus().toString());
-        return new GetVectorByIdResponse(
+            "getVectorsById in collection `{0}` failed:\n{1}",
+            collectionName, response.getStatus().toString());
+        res.add(new GetVectorByIdResponse(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
             new ArrayList<>(),
-            null);
+            null));
+        return res;
       }
     } catch (StatusRuntimeException e) {
-      logSevere("getVectorById RPC failed:\n{0}", e.getStatus().toString());
-      return new GetVectorByIdResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null);
+      logSevere("getVectorsById RPC failed:\n{0}", e.getStatus().toString());
+      res.add(new GetVectorByIdResponse(
+          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null));
+      return res;
     }
   }
 

+ 146 - 0
src/main/java/io/milvus/client/SearchByIDParam.java

@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.client;
+
+import javax.annotation.Nonnull;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Contains parameters for <code>searchByID</code> */
+public class SearchByIDParam {
+
+    private final String collectionName;
+    private final List<String> partitionTags;
+    private final List<Long> ids;
+    private final long topK;
+    private final String paramsInJson;
+
+    private SearchByIDParam(@Nonnull Builder builder) {
+        this.collectionName = builder.collectionName;
+        this.partitionTags = builder.partitionTags;
+        this.ids = builder.ids;
+        this.topK = builder.topK;
+        this.paramsInJson = builder.paramsInJson;
+    }
+
+    public String getCollectionName() {
+        return collectionName;
+    }
+
+    public List<String> getPartitionTags() {
+        return partitionTags;
+    }
+
+    public List<Long> getIds() {
+        return ids;
+    }
+
+    public long getTopK() {
+        return topK;
+    }
+
+    public String getParamsInJson() {
+        return paramsInJson;
+    }
+
+    /** Builder for <code>SearchByIDParam</code> */
+    public static class Builder {
+        // Required parameters
+        private final String collectionName;
+
+        // Optional parameters - initialized to default values
+        private List<String> partitionTags = new ArrayList<>();
+        private List<Long> ids = new ArrayList<>();
+        private long topK = 1024;
+        private String paramsInJson;
+
+        /** @param collectionName collection to search from */
+        public Builder(@Nonnull String collectionName) {
+            this.collectionName = collectionName;
+        }
+
+        /**
+         * Search vectors IDs. Default to an empty <code>List</code>
+         *
+         * @param ids IDs of vectors
+         * @return <code>Builder</code>
+         */
+        public Builder withIDs(@Nonnull List<Long> ids) {
+            this.ids = ids;
+            return this;
+        }
+
+        /**
+         * Optional. Search vectors with corresponding <code>partitionTags</code>. Default to an empty
+         * <code>List</code>
+         *
+         * @param partitionTags a <code>List</code> of partition tags
+         * @return <code>Builder</code>
+         */
+        public Builder withPartitionTags(@Nonnull List<String> partitionTags) {
+            this.partitionTags = partitionTags;
+            return this;
+        }
+
+        /**
+         * Optional. Limits search result to <code>topK</code>. Default to 1024.
+         *
+         * @param topK a topK number
+         * @return <code>Builder</code>
+         */
+        public Builder withTopK(long topK) {
+            this.topK = topK;
+            return this;
+        }
+
+        /**
+         * Optional. Default to empty <code>String</code>. Search parameters are different for different
+         * index types. Refer to <a
+         * href="https://milvus.io/docs/v0.8.0/guides/milvus_operation.md">https://milvus.io/docs/v0.8.0/guides/milvus_operation.md</a>
+         * for more information.
+         *
+         * <pre>
+         *   FLAT/IVFLAT/SQ8/IVFPQ: {"nprobe": 32}
+         *   nprobe range:[1,999999]
+         *
+         *   NSG: {"search_length": 100}
+         *   search_length range:[10, 300]
+         *
+         *   HNSW: {"ef": 64}
+         *   ef range:[topk, 4096]
+         *
+         *   ANNOY: {search_k", 0.05 * totalDataCount}
+         *   search_k range: none
+         * </pre>
+         *
+         * @param paramsInJson extra parameters in JSON format
+         * @return <code>Builder</code>
+         */
+        public SearchByIDParam.Builder withParamsInJson(@Nonnull String paramsInJson) {
+            this.paramsInJson = paramsInJson;
+            return this;
+        }
+
+        public SearchByIDParam build() {
+            return new SearchByIDParam(this);
+        }
+    }
+}

+ 20 - 31
src/main/proto/milvus.proto

@@ -112,7 +112,7 @@ message SearchInFilesParam {
 message SearchByIDParam {
     string collection_name = 1;
     repeated string partition_tag_array = 2;
-    int64 id = 3;
+    repeated int64 id_array = 3;
     int64 topk = 4;
     repeated KeyValuePair extra_params = 5;
 }
@@ -184,48 +184,28 @@ message DeleteByIDParam {
     repeated int64 id_array = 2;
 }
 
-/**
- * @brief segment statistics
- */
-message SegmentStat {
-    string segment_name = 1;
-    int64 row_count = 2;
-    string index_name = 3;
-    int64 data_size = 4;
-}
-
-/**
- * @brief collection statistics
- */
-message PartitionStat {
-    string tag = 1;
-    int64 total_row_count = 2;
-    repeated SegmentStat segments_stat = 3;
-}
-
 /**
  * @brief collection information
  */
 message CollectionInfo {
     Status status = 1;
-    int64 total_row_count = 2;
-    repeated PartitionStat partitions_stat = 3;
+    string json_info = 2;
 }
 
 /**
- * @brief vector identity
+ * @brief vectors identity
  */
-message VectorIdentity {
+message VectorsIdentity {
     string collection_name = 1;
-    int64 id = 2;
+    repeated int64 id_array = 2;
 }
 
 /**
  * @brief vector data
  */
-message VectorData {
+message VectorsData {
     Status status = 1;
-    RowRecord vector_data = 2;
+    repeated RowRecord vectors_data = 2;
 }
 
 /**
@@ -336,6 +316,15 @@ service MilvusService {
      */
     rpc CreatePartition(PartitionParam) returns (Status) {}
 
+    /**
+     * @brief This method is used to test partition existence.
+     *
+     * @param PartitionParam, target partition.
+     *
+     * @return BoolReply
+     */
+    rpc HasPartition(PartitionParam) returns (BoolReply) {}
+
     /**
      * @brief This method is used to show partition information
      *
@@ -364,13 +353,13 @@ service MilvusService {
     rpc Insert(InsertParam) returns (VectorIds) {}
 
     /**
-     * @brief This method is used to get vector data by id.
+     * @brief This method is used to get vectors data by id array.
      *
-     * @param VectorIdentity, target vector id.
+     * @param VectorsIdentity, target vector id array.
      *
-     * @return VectorData
+     * @return VectorsData
      */
-    rpc GetVectorByID(VectorIdentity) returns (VectorData) {}
+    rpc GetVectorsByID(VectorsIdentity) returns (VectorsData) {}
 
     /**
      * @brief This method is used to get vector ids from a segment

+ 116 - 77
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -21,6 +21,7 @@ package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
 import org.apache.commons.text.RandomStringGenerator;
+import org.json.*;
 
 import java.nio.ByteBuffer;
 import java.util.*;
@@ -268,9 +269,15 @@ class MilvusClientTest {
 
     assertTrue(Collections.disjoint(resultIdsList1, resultIdsList2));
 
+    HasPartitionResponse testHasPartition = client.hasPartition(randomCollectionName, tag1);
+    assertTrue(testHasPartition.hasPartition());
+
     Response dropPartitionResponse = client.dropPartition(randomCollectionName, tag1);
     assertTrue(dropPartitionResponse.ok());
 
+    testHasPartition = client.hasPartition(randomCollectionName, tag1);
+    assertFalse(testHasPartition.hasPartition());
+
     dropPartitionResponse = client.dropPartition(randomCollectionName, tag2);
     assertTrue(dropPartitionResponse.ok());
   }
@@ -389,6 +396,45 @@ class MilvusClientTest {
     }
   }
 
+  @org.junit.jupiter.api.Test
+  void searchById() {
+    List<List<Float>> vectors = generateFloatVectors(size, dimension);
+    vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
+    InsertParam insertParam =
+            new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
+    InsertResponse insertResponse = client.insert(insertParam);
+    assertTrue(insertResponse.ok());
+    List<Long> vectorIds = insertResponse.getVectorIds();
+    assertEquals(size, vectorIds.size());
+
+    assertTrue(client.flush(randomCollectionName).ok());
+
+    final long topK = 10;
+    final int queryLength = 5;
+    SearchByIDParam searchByIDParam =
+            new SearchByIDParam.Builder(randomCollectionName)
+                    .withIDs(vectorIds.subList(0, queryLength))
+                    .withTopK(topK)
+                    .withParamsInJson("{\"nprobe\": 20}")
+                    .build();
+    SearchResponse searchResponse = client.searchByID(searchByIDParam);
+    assertTrue(searchResponse.ok());
+    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+    assertEquals(queryLength, resultIdsList.size());
+    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    assertEquals(queryLength, resultDistancesList.size());
+    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    assertEquals(queryLength, queryResultsList.size());
+    final double epsilon = 0.001;
+    for (int i = 0; i < queryLength; i++) {
+      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      assertEquals(vectorIds.get(i), firstQueryResult.getVectorId());
+      assertEquals(vectorIds.get(i), resultIdsList.get(i).get(0));
+      assertTrue(Math.abs(1 - firstQueryResult.getDistance()) < epsilon);
+      assertTrue(Math.abs(1 - resultDistancesList.get(i).get(0)) < epsilon);
+    }
+  }
+
   @org.junit.jupiter.api.Test
   void searchAsync() throws ExecutionException, InterruptedException {
     List<List<Float>> vectors = generateFloatVectors(size, dimension);
@@ -506,7 +552,7 @@ class MilvusClientTest {
   void showCollections() {
     ShowCollectionsResponse showCollectionsResponse = client.showCollections();
     assertTrue(showCollectionsResponse.ok());
-    assertEquals(showCollectionsResponse.getCollectionNames().get(0), randomCollectionName);
+    assertTrue(showCollectionsResponse.getCollectionNames().contains(randomCollectionName));
   }
 
   @org.junit.jupiter.api.Test
@@ -564,25 +610,27 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
+    Response showCollectionInfoResponse =
         client.showCollectionInfo(randomCollectionName);
     assertTrue(showCollectionInfoResponse.ok());
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
 
-    CollectionInfo collectionInfo = showCollectionInfoResponse.getCollectionInfo().get();
-    assertEquals(collectionInfo.getRowCount(), size);
+    String jsonString = showCollectionInfoResponse.getMessage();
+    JSONObject jsonInfo = new JSONObject(jsonString);
+    assertTrue(jsonInfo.getInt("row_count") == size);
 
-    CollectionInfo.PartitionInfo partitionInfo = collectionInfo.getPartitionInfos().get(0);
-    assertEquals(partitionInfo.getTag(), "_default");
-    assertEquals(partitionInfo.getRowCount(), size);
+    JSONArray partitions = jsonInfo.getJSONArray("partitions");
+    JSONObject partitionInfo = partitions.getJSONObject(0);
+    assertEquals(partitionInfo.getString("tag"), "_default");
+    assertEquals(partitionInfo.getInt("row_count"), size);
 
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo = partitionInfo.getSegmentInfos().get(0);
-    assertEquals(segmentInfo.getRowCount(), size);
-    assertEquals(segmentInfo.getIndexName(), "IDMAP");
+    JSONArray segments = partitionInfo.getJSONArray("segments");
+    JSONObject segmentInfo = segments.getJSONObject(0);
+    assertEquals(segmentInfo.getString("index_name"), "IDMAP");
+    assertEquals(segmentInfo.getInt("row_count"), size);
   }
 
   @org.junit.jupiter.api.Test
-  void getVectorById() {
+  void getVectorsById() {
     List<List<Float>> vectors = generateFloatVectors(size, dimension);
     InsertParam insertParam =
         new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
@@ -593,13 +641,14 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    GetVectorByIdResponse getVectorByIdResponse =
-        client.getVectorById(randomCollectionName, vectorIds.get(0));
-    assertTrue(getVectorByIdResponse.ok());
-    assertTrue(getVectorByIdResponse.exists());
-    assertTrue(getVectorByIdResponse.isFloatVector());
-    assertFalse(getVectorByIdResponse.isBinaryVector());
-    assertArrayEquals(getVectorByIdResponse.getFloatVector().toArray(), vectors.get(0).toArray());
+    List<GetVectorByIdResponse> getVectorByIdResponse =
+        client.getVectorsById(randomCollectionName, vectorIds.subList(0, 100));
+    assertTrue(getVectorByIdResponse.size() == 100);
+    assertTrue(getVectorByIdResponse.get(0).ok());
+    assertTrue(getVectorByIdResponse.get(0).exists());
+    assertTrue(getVectorByIdResponse.get(0).isFloatVector());
+    assertFalse(getVectorByIdResponse.get(0).isBinaryVector());
+    assertArrayEquals(getVectorByIdResponse.get(0).getFloatVector().toArray(), vectors.get(0).toArray());
   }
 
   @org.junit.jupiter.api.Test
@@ -608,21 +657,19 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
+    Response showCollectionInfoResponse =
         client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
+    assertTrue(showCollectionInfoResponse.ok());
 
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
+    JSONObject jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    JSONObject segmentInfo = jsonInfo
+                                 .getJSONArray("partitions")
+                                 .getJSONObject(0)
+                                 .getJSONArray("segments")
+                                 .getJSONObject(0);
 
     GetVectorIdsResponse getVectorIdsResponse =
-        client.getVectorIds(randomCollectionName, segmentInfo.getSegmentName());
+        client.getVectorIds(randomCollectionName,segmentInfo.getString("name"));
     assertTrue(getVectorIdsResponse.ok());
     assertFalse(getVectorIdsResponse.getIds().isEmpty());
   }
@@ -684,36 +731,32 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
+    Response showCollectionInfoResponse =
         client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long previousSegmentSize = segmentInfo.getDataSize();
+    assertTrue(showCollectionInfoResponse.ok());
 
-    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    JSONObject jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    JSONObject segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
 
-    assertTrue(client.flush(randomCollectionName).ok());
+    long previousSegmentSize = segmentInfo.getLong("data_size");
 
+    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.compact(randomCollectionName).ok());
 
     showCollectionInfoResponse = client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long currentSegmentSize = segmentInfo.getDataSize();
+    assertTrue(showCollectionInfoResponse.ok());
+    jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
+    long currentSegmentSize = segmentInfo.getLong("data_size");
 
     assertTrue(currentSegmentSize < previousSegmentSize);
   }
@@ -730,36 +773,32 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    ShowCollectionInfoResponse showCollectionInfoResponse =
-        client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    CollectionInfo.PartitionInfo.SegmentInfo segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long previousSegmentSize = segmentInfo.getDataSize();
+    Response showCollectionInfoResponse =
+            client.showCollectionInfo(randomCollectionName);
+    assertTrue(showCollectionInfoResponse.ok());
 
-    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    JSONObject jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    JSONObject segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
 
-    assertTrue(client.flush(randomCollectionName).ok());
+    long previousSegmentSize = segmentInfo.getLong("data_size");
 
+    assertTrue(client.deleteByIds(randomCollectionName, vectorIds.subList(0, 100)).ok());
+    assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.compactAsync(randomCollectionName).get().ok());
 
     showCollectionInfoResponse = client.showCollectionInfo(randomCollectionName);
-    assertTrue(showCollectionInfoResponse.getCollectionInfo().isPresent());
-    segmentInfo =
-        showCollectionInfoResponse
-            .getCollectionInfo()
-            .get()
-            .getPartitionInfos()
-            .get(0)
-            .getSegmentInfos()
-            .get(0);
-    long currentSegmentSize = segmentInfo.getDataSize();
+    assertTrue(showCollectionInfoResponse.ok());
+    jsonInfo = new JSONObject(showCollectionInfoResponse.getMessage());
+    segmentInfo = jsonInfo
+            .getJSONArray("partitions")
+            .getJSONObject(0)
+            .getJSONArray("segments")
+            .getJSONObject(0);
+    long currentSegmentSize = segmentInfo.getLong("data_size");
 
     assertTrue(currentSegmentSize < previousSegmentSize);
   }