Browse Source

Make `getEntityByID` return a map of entity ids to entity property maps

jianghua 4 năm trước cách đây
mục cha
commit
4138103627

+ 4 - 4
src/main/java/io/milvus/client/MilvusClient.java

@@ -342,18 +342,18 @@ public interface MilvusClient {
    * @param ids a <code>List</code> of entity ids
    * @param ids a <code>List</code> of entity ids
    * @param fieldNames  a <code>List</code> of field names. Server will only return entity
    * @param fieldNames  a <code>List</code> of field names. Server will only return entity
    *                    information for these fields.
    *                    information for these fields.
-   * @return a list of property maps.
+   * @return a map of entity id to entity properties
    */
    */
-  List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames);
+  Map<Long, Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames);
 
 
   /**
   /**
    * Gets entities data by id array
    * Gets entities data by id array
    *
    *
    * @param collectionName collection to get entities from
    * @param collectionName collection to get entities from
    * @param ids a <code>List</code> of entity ids
    * @param ids a <code>List</code> of entity ids
-   * @return a list of property maps.
+   * @return a map of entity id to entity properties
    */
    */
-  List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids);
+  Map<Long, Map<String, Object>> getEntityByID(String collectionName, List<Long> ids);
 
 
   /**
   /**
    * Gets all entity ids in a segment
    * Gets all entity ids in a segment

+ 11 - 8
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -34,14 +34,11 @@ import io.milvus.client.exception.MilvusException;
 import io.milvus.client.exception.ServerSideMilvusException;
 import io.milvus.client.exception.ServerSideMilvusException;
 import io.milvus.client.exception.UnsupportedServerVersion;
 import io.milvus.client.exception.UnsupportedServerVersion;
 import io.milvus.grpc.*;
 import io.milvus.grpc.*;
-import org.apache.commons.lang3.ArrayUtils;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nonnull;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Iterator;
@@ -405,7 +402,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
   }
   }
   
   
   @Override
   @Override
-  public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
+  public Map<Long, Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
     return translateExceptions(() -> {
     return translateExceptions(() -> {
       EntityIdentity request = EntityIdentity.newBuilder()
       EntityIdentity request = EntityIdentity.newBuilder()
           .setCollectionName(collectionName)
           .setCollectionName(collectionName)
@@ -417,9 +414,15 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
       Map<String, Iterator<?>> fieldIterators = response.getFieldsList()
       Map<String, Iterator<?>> fieldIterators = response.getFieldsList()
           .stream()
           .stream()
           .collect(Collectors.toMap(FieldValue::getFieldName, this::fieldValueIterator));
           .collect(Collectors.toMap(FieldValue::getFieldName, this::fieldValueIterator));
-      return response.getValidRowList().stream()
-          .map(valid -> valid ? toMap(fieldIterators) : Collections.<String, Object>emptyMap())
-          .collect(Collectors.toList());
+      Iterator<Long> idIterator = ids.iterator();
+      Map<Long, Map<String, Object>> entities = new HashMap<>(response.getValidRowList().size());
+      for (boolean valid : response.getValidRowList()) {
+        long id = idIterator.next();
+        if (valid) {
+          entities.put(id, toMap(fieldIterators));
+        }
+      }
+      return entities;
     });
     });
   }
   }
   
   
@@ -450,7 +453,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
   }
   }
 
 
   @Override
   @Override
-  public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids) {
+  public Map<Long, Map<String, Object>> getEntityByID(String collectionName, List<Long> ids) {
     return getEntityByID(collectionName, ids, Collections.emptyList());
     return getEntityByID(collectionName, ids, Collections.emptyList());
   }
   }
 
 

+ 20 - 12
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -53,7 +53,6 @@ import java.util.stream.DoubleStream;
 import java.util.stream.IntStream;
 import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 import java.util.stream.LongStream;
 
 
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -680,15 +679,19 @@ class MilvusClientTest {
     List<Long> entityIds = client.insert(insertParam);
     List<Long> entityIds = client.insert(insertParam);
     assertEquals(size, entityIds.size());
     assertEquals(size, entityIds.size());
 
 
+    client.deleteEntityByID(randomCollectionName, entityIds.subList(10, 20));
+
     client.flush(randomCollectionName);
     client.flush(randomCollectionName);
 
 
-    List<Map<String, Object>> fieldsMap =
+    Map<Long, Map<String, Object>> entities =
         client.getEntityByID(randomCollectionName, entityIds.subList(0, 100));
         client.getEntityByID(randomCollectionName, entityIds.subList(0, 100));
-    int vecIndex = 0;
-    assertTrue(fieldsMap.get(vecIndex).get("float_vec") instanceof List);
-    List<Float> first = (List<Float>) (fieldsMap.get(vecIndex).get("float_vec"));
-
-    assertArrayEquals(first.toArray(), vectors.get(0).toArray());
+    for (int i = 0; i < 100; i++) {
+      if (i >= 10 && i < 20) {
+        assertFalse(entities.containsKey(entityIds.get(i)));
+      } else {
+        assertEquals(vectors.get(i), entities.get(entityIds.get(i)).get("float_vec"));
+      }
+    }
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
@@ -711,14 +714,19 @@ class MilvusClientTest {
         .setEntityIds(entityIds);
         .setEntityIds(entityIds);
     assertEquals(size, client.insert(insertParam).size());
     assertEquals(size, client.insert(insertParam).size());
 
 
+    client.deleteEntityByID(binaryCollectionName, entityIds.subList(10, 20));
+
     client.flush(binaryCollectionName);
     client.flush(binaryCollectionName);
 
 
-    List<Map<String, Object>> fieldsMap =
+    Map<Long, Map<String, Object>> entities =
         client.getEntityByID(binaryCollectionName, entityIds.subList(0, 100));
         client.getEntityByID(binaryCollectionName, entityIds.subList(0, 100));
-    assertEquals(100, fieldsMap.size());
-    assertTrue(fieldsMap.get(0).get("binary_vec") instanceof ByteBuffer);
-    ByteBuffer first = (ByteBuffer) (fieldsMap.get(0).get("binary_vec"));
-    assertEquals(vectors.get(0), first);
+    for (int i = 0; i < 100; i++) {
+      if (i >= 10 && i < 20) {
+        assertFalse(entities.containsKey(entityIds.get(i)));
+      } else {
+        assertEquals(vectors.get(i), entities.get(entityIds.get(i)).get("binary_vec"));
+      }
+    }
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test