Browse Source

Fix GetVectorsByIdsResponse

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 5 years ago
parent
commit
a0bc6f96b9

+ 2 - 2
examples/src/main/java/MilvusClientExample.java

@@ -229,10 +229,10 @@ public class MilvusClientExample {
     flushResponse = client.flush(collectionName);
 
     // Try to get the corresponding vector of the first id you just deleted.
-    List<GetVectorByIdResponse> getVectorByIdResponse =
+    GetVectorsByIdsResponse getVectorsByIdsResponse =
         client.getVectorsByIds(collectionName, vectorIds.subList(0, searchBatchSize));
     // Obviously you won't get anything
-    if (getVectorByIdResponse.get(0).exists()) {
+    if (!getVectorsByIdsResponse.getFloatVectors().get(0).isEmpty()) {
       throw new AssertionError("This can never happen!");
     }
 

+ 0 - 59
src/main/java/io/milvus/client/GetVectorByIdResponse.java

@@ -1,59 +0,0 @@
-package io.milvus.client;
-
-import java.nio.ByteBuffer;
-import java.util.List;
-import java.util.Optional;
-
-/**
- * Contains the returned <code>response</code> and either a <code>floatVector</code> or a <code>
- * binaryVector</code> for each vector of <code>getVectorsByIds</code>. If the id does not exist, both returned
- * vectors will be empty.
- */
-public class GetVectorByIdResponse {
-  private final Response response;
-  private final List<Float> floatVector;
-  private final ByteBuffer binaryVector;
-
-  GetVectorByIdResponse(Response response, List<Float> floatVector, ByteBuffer binaryVector) {
-    this.response = response;
-    this.floatVector = floatVector;
-    this.binaryVector = binaryVector;
-  }
-
-  public List<Float> getFloatVector() {
-    return floatVector;
-  }
-
-  /**
-   * @return an <code>Optional</code> object which may or may not contain a <code>ByteBuffer</code>
-   *     object
-   * @see Optional
-   */
-  public Optional<ByteBuffer> getBinaryVector() {
-    return Optional.ofNullable(binaryVector);
-  }
-
-  /** @return <code>true</code> if the id corresponds to a float vector */
-  public boolean isFloatVector() {
-    return !floatVector.isEmpty();
-  }
-
-  /** @return <code>true</code> if the id corresponds to a binary vector */
-  public boolean isBinaryVector() {
-    return binaryVector != null && binaryVector.hasRemaining();
-  }
-
-  /** @return <code>true</code> if the id's corresponding vector exists */
-  public boolean exists() {
-    return isFloatVector() || isBinaryVector();
-  }
-
-  public Response getResponse() {
-    return response;
-  }
-
-  /** @return <code>true</code> if the response status equals SUCCESS */
-  public boolean ok() {
-    return response.ok();
-  }
-}

+ 43 - 0
src/main/java/io/milvus/client/GetVectorsByIdsResponse.java

@@ -0,0 +1,43 @@
+package io.milvus.client;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Contains the returned <code>response</code> and either a <code>List</code> of <code>floatVectors</code> or <code>
+ * binaryVectors</code> for <code>getVectorsByIds</code>. If the id does not exist, both float and binary
+ * vectors corresponding to the id will be empty.
+ */
+public class GetVectorsByIdsResponse {
+  private final Response response;
+  private final List<List<Float>> floatVectors;
+  private final List<ByteBuffer> binaryVectors;
+
+  GetVectorsByIdsResponse(Response response, List<List<Float>> floatVectors, List<ByteBuffer> binaryVectors) {
+    this.response = response;
+    this.floatVectors = floatVectors;
+    this.binaryVectors = binaryVectors;
+  }
+
+  public List<List<Float>> getFloatVectors() {
+    return floatVectors;
+  }
+
+  /**
+   * @return a <code>List</code> of <code>ByteBuffer</code> object
+   */
+  public List<ByteBuffer> getBinaryVectors() {
+    return binaryVectors;
+  }
+
+  public Response getResponse() {
+    return response;
+  }
+
+  /** @return <code>true</code> if the response status equals SUCCESS */
+  public boolean ok() {
+    return response.ok();
+  }
+
+}

+ 3 - 3
src/main/java/io/milvus/client/MilvusClient.java

@@ -436,11 +436,11 @@ public interface MilvusClient {
    *
    * @param collectionName collection to get vectors from
    * @param ids a <code>List</code> of vector ids
-   * @return <code>List<GetVectorByIdResponse></code>
-   * @see GetVectorByIdResponse
+   * @return <code>GetVectorsByIdsResponse</code>
+   * @see GetVectorsByIdsResponse
    * @see Response
    */
-  List<GetVectorByIdResponse> getVectorsByIds(String collectionName, List<Long> ids);
+  GetVectorsByIdsResponse getVectorsByIds(String collectionName, List<Long> ids);
 
   /**
    * Gets all vector ids in a segment

+ 13 - 16
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -1108,13 +1108,11 @@ public class MilvusGrpcClient implements MilvusClient {
   }
 
   @Override
-  public List<GetVectorByIdResponse> getVectorsByIds(String collectionName, List<Long> ids) {
-    List<GetVectorByIdResponse> res = new ArrayList<>();
+  public GetVectorsByIdsResponse getVectorsByIds(String collectionName, List<Long> ids) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      res.add(new GetVectorByIdResponse(
-              new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null));
-      return res;
+      return new GetVectorsByIdsResponse(
+              new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null);
     }
 
     VectorsIdentity request =
@@ -1129,31 +1127,30 @@ public class MilvusGrpcClient implements MilvusClient {
         logInfo(
             "getVectorsByIds in collection `{0}` returned successfully!", collectionName);
 
+        List<List<Float>> floatVectors = new ArrayList<>();
+        List<ByteBuffer> binaryVectors = new ArrayList<>();
         for (int i = 0; i < ids.size(); i++) {
-          res.add(new GetVectorByIdResponse(
-                  new Response(Response.Status.SUCCESS),
-                  response.getVectorsData(i).getFloatDataList(),
-                  response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer()));
+          floatVectors.add(response.getVectorsData(i).getFloatDataList());
+          binaryVectors.add(response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer());
         }
-        return res;
+        return new GetVectorsByIdsResponse(
+                new Response(Response.Status.SUCCESS), floatVectors, binaryVectors);
 
       } else {
         logSevere(
             "getVectorsByIds in collection `{0}` failed:\n{1}",
             collectionName, response.getStatus().toString());
-        res.add(new GetVectorByIdResponse(
+        return new GetVectorsByIdsResponse(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
             new ArrayList<>(),
-            null));
-        return res;
+            null);
       }
     } catch (StatusRuntimeException e) {
       logSevere("getVectorsByIds RPC failed:\n{0}", e.getStatus().toString());
-      res.add(new GetVectorByIdResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null));
-      return res;
+      return new GetVectorsByIdsResponse(
+          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null);
     }
   }
 

+ 6 - 7
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -641,14 +641,13 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    List<GetVectorByIdResponse> getVectorByIdResponse =
+    GetVectorsByIdsResponse getVectorsByIdsResponse =
         client.getVectorsByIds(randomCollectionName, vectorIds.subList(0, 100));
-    assertTrue(getVectorByIdResponse.size() == 100);
-    assertTrue(getVectorByIdResponse.get(0).ok());
-    assertTrue(getVectorByIdResponse.get(0).exists());
-    assertTrue(getVectorByIdResponse.get(0).isFloatVector());
-    assertFalse(getVectorByIdResponse.get(0).isBinaryVector());
-    assertArrayEquals(getVectorByIdResponse.get(0).getFloatVector().toArray(), vectors.get(0).toArray());
+    assertTrue(getVectorsByIdsResponse.ok());
+    ByteBuffer bb = getVectorsByIdsResponse.getBinaryVectors().get(0);
+    assertTrue(bb == null || bb.remaining() == 0);
+
+    assertArrayEquals(getVectorsByIdsResponse.getFloatVectors().get(0).toArray(), vectors.get(0).toArray());
   }
 
   @org.junit.jupiter.api.Test