Browse Source

Simplify `getEntityByID`

jianghua 4 years ago
parent
commit
1788fc2739

+ 4 - 5
src/main/java/io/milvus/client/MilvusClient.java

@@ -25,6 +25,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils;
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.List;
+import java.util.Map;
 import java.util.Properties;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
@@ -341,11 +342,9 @@ public interface MilvusClient {
    * @param ids a <code>List</code> of entity ids
    * @param fieldNames  a <code>List</code> of field names. Server will only return entity
    *                    information for these fields.
-   * @return <code>GetEntityByIDResponse</code>
-   * @see GetEntityByIDResponse
-   * @see Response
+   * @return a list of property maps.
    */
-  GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames);
+  List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames);
 
   /**
    * Gets entities data by id array
@@ -356,7 +355,7 @@ public interface MilvusClient {
    * @see GetEntityByIDResponse
    * @see Response
    */
-  GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids);
+  List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids);
 
   /**
    * Gets all entity ids in a segment

+ 44 - 78
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -47,11 +47,13 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 /** Actual implementation of interface <code>MilvusClient</code> */
 public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
@@ -429,90 +431,54 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
       return response.getJsonInfo();
     });
   }
-
+  
   @Override
-  public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      return new GetEntityByIDResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
-    }
-
-    EntityIdentity request =
-        EntityIdentity.newBuilder()
-            .setCollectionName(collectionName)
-            .addAllIdArray(ids)
-            .addAllFieldNames(fieldNames)
-            .build();
-    Entities response;
-
-    try {
-      response = blockingStub().getEntityByID(request);
-
-      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-
-        logInfo("getEntityByID in collection `{}` returned successfully!", collectionName);
-
-        List<Map<String, Object>> fieldsMap = new ArrayList<>();
-        List<Boolean> isValid = response.getValidRowList();
-        for (int i = 0; i < isValid.size(); i++) {
-          fieldsMap.add(new HashMap<>());
-        }
-        List<FieldValue> fieldValueList = response.getFieldsList();
-        for (FieldValue fieldValue : fieldValueList) {
-          String fieldName = fieldValue.getFieldName();
-          for (int j = 0; j < isValid.size(); j++) {
-            if (!isValid.get(j)) continue;
-            if (fieldValue.getAttrRecord().getInt32ValueCount() > 0) {
-              fieldsMap.get(j)
-                  .put(fieldName, fieldValue.getAttrRecord().getInt32ValueList().get(j));
-            } else if (fieldValue.getAttrRecord().getInt64ValueCount() > 0) {
-              fieldsMap.get(j)
-                  .put(fieldName, fieldValue.getAttrRecord().getInt64ValueList().get(j));
-            } else if (fieldValue.getAttrRecord().getDoubleValueCount() > 0) {
-              fieldsMap.get(j)
-                  .put(fieldName, fieldValue.getAttrRecord().getDoubleValueList().get(j));
-            } else if (fieldValue.getAttrRecord().getFloatValueCount() > 0) {
-              fieldsMap.get(j)
-                  .put(fieldName, fieldValue.getAttrRecord().getFloatValueList().get(j));
-            } else {
-              // the object is vector
-              List<VectorRowRecord> vectorRowRecordList =
-                  fieldValue.getVectorRecord().getRecordsList();
-              if (vectorRowRecordList.get(j).getFloatDataCount() > 0) {
-                fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getFloatDataList());
-              } else {
-                ByteBuffer bb = vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer();
-                byte[] b = new byte[bb.remaining()];
-                bb.get(b);
-                fieldsMap.get(j).put(fieldName, Arrays.asList(ArrayUtils.toObject(b)));
-              }
-            }
-          }
-        }
-        return new GetEntityByIDResponse(
-            new Response(Response.Status.SUCCESS), fieldsMap);
-
-      } else {
-        logError(
-            "getEntityByID in collection `{}` failed:\n{}",
-            collectionName,
-            response.getStatus().toString());
-        return new GetEntityByIDResponse(
-            new Response(
-                Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            Collections.emptyList());
+  public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
+    return translateExceptions(() -> {
+      EntityIdentity request = EntityIdentity.newBuilder()
+          .setCollectionName(collectionName)
+          .addAllIdArray(ids)
+          .addAllFieldNames(fieldNames)
+          .build();
+      Entities response = blockingStub().getEntityByID(request);
+      checkResponseStatus(response.getStatus());
+      Map<String, Iterator<?>> fieldIterators = response.getFieldsList()
+          .stream()
+          .collect(Collectors.toMap(FieldValue::getFieldName, this::fieldValueIterator));
+      return response.getValidRowList().stream()
+          .map(valid -> valid ? toMap(fieldIterators) : Collections.<String, Object>emptyMap())
+          .collect(Collectors.toList());
+    });
+  }
+  
+  private Map<String, Object> toMap(Map<String, Iterator<?>> fieldIterators) {
+    return fieldIterators.entrySet().stream()
+        .collect(Collectors.toMap(
+            entry -> entry.getKey(),
+            entry -> entry.getValue().next()));
+  }
+
+  private Iterator<?> fieldValueIterator(FieldValue fieldValue) {
+    if (fieldValue.hasAttrRecord()) {
+      AttrRecord record = fieldValue.getAttrRecord();
+      if (record.getInt32ValueCount() > 0) {
+        return record.getInt32ValueList().iterator();
+      } else if (record.getInt64ValueCount() > 0) {
+        return record.getInt64ValueList().iterator();
+      } else if (record.getFloatValueCount() > 0) {
+        return record.getFloatValueList().iterator();
+      } else if (record.getDoubleValueCount() > 0) {
+        return record.getDoubleValueList().iterator();
       }
-    } catch (StatusRuntimeException e) {
-      logError("getEntityByID RPC failed:\n{}", e.getStatus().toString());
-      return new GetEntityByIDResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
     }
+    VectorRecord record = fieldValue.getVectorRecord();
+    return record.getRecordsList().stream()
+        .map(row -> row.getFloatDataCount() > 0 ? row.getFloatDataList() : row.getBinaryData().asReadOnlyByteBuffer())
+        .iterator();
   }
 
   @Override
-  public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids) {
+  public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids) {
     return getEntityByID(collectionName, ids, Collections.emptyList());
   }
 

+ 6 - 14
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -687,11 +687,9 @@ class MilvusClientTest {
 
     assertTrue(client.flush(randomCollectionName).ok());
 
-    GetEntityByIDResponse getEntityByIDResponse =
+    List<Map<String, Object>> fieldsMap =
         client.getEntityByID(randomCollectionName, entityIds.subList(0, 100));
-    assertTrue(getEntityByIDResponse.ok());
     int vecIndex = 0;
-    List<Map<String, Object>> fieldsMap = getEntityByIDResponse.getFieldsMap();
     assertTrue(fieldsMap.get(vecIndex).get("float_vec") instanceof List);
     List<Float> first = (List<Float>) (fieldsMap.get(vecIndex).get("float_vec"));
 
@@ -721,18 +719,12 @@ class MilvusClientTest {
 
     assertTrue(client.flush(binaryCollectionName).ok());
 
-    GetEntityByIDResponse getEntityByIDResponse =
+    List<Map<String, Object>> fieldsMap =
         client.getEntityByID(binaryCollectionName, entityIds.subList(0, 100));
-    assertTrue(getEntityByIDResponse.ok());
-    assertEquals(getEntityByIDResponse.getFieldsMap().size(), 100);
-    List<Map<String, Object>> fieldsMap = getEntityByIDResponse.getFieldsMap();
-    assertTrue(fieldsMap.get(0).get("binary_vec") instanceof List);
-    List<Byte> first = (List<Byte>) (fieldsMap.get(0).get("binary_vec"));
-    byte[] bytes = new byte[first.size()];
-    for (int i = 0; i < first.size(); i++) {
-      bytes[i] = first.get(i);
-    }
-    assertEquals(ByteBuffer.wrap(bytes), vectors.get(0));
+    assertEquals(100, fieldsMap.size());
+    assertTrue(fieldsMap.get(0).get("binary_vec") instanceof ByteBuffer);
+    ByteBuffer first = (ByteBuffer) (fieldsMap.get(0).get("binary_vec"));
+    assertEquals(vectors.get(0), first);
   }
 
   @org.junit.jupiter.api.Test