|
@@ -47,11 +47,13 @@ import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
+import java.util.Iterator;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
import java.util.function.Function;
|
|
|
import java.util.function.Supplier;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
/** Actual implementation of interface <code>MilvusClient</code> */
|
|
|
public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
@@ -429,90 +431,54 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
return response.getJsonInfo();
|
|
|
});
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
@Override
|
|
|
- public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
|
|
|
- if (!maybeAvailable()) {
|
|
|
- logWarning("You are not connected to Milvus server");
|
|
|
- return new GetEntityByIDResponse(
|
|
|
- new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
|
|
|
- }
|
|
|
-
|
|
|
- EntityIdentity request =
|
|
|
- EntityIdentity.newBuilder()
|
|
|
- .setCollectionName(collectionName)
|
|
|
- .addAllIdArray(ids)
|
|
|
- .addAllFieldNames(fieldNames)
|
|
|
- .build();
|
|
|
- Entities response;
|
|
|
-
|
|
|
- try {
|
|
|
- response = blockingStub().getEntityByID(request);
|
|
|
-
|
|
|
- if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
|
|
|
-
|
|
|
- logInfo("getEntityByID in collection `{}` returned successfully!", collectionName);
|
|
|
-
|
|
|
- List<Map<String, Object>> fieldsMap = new ArrayList<>();
|
|
|
- List<Boolean> isValid = response.getValidRowList();
|
|
|
- for (int i = 0; i < isValid.size(); i++) {
|
|
|
- fieldsMap.add(new HashMap<>());
|
|
|
- }
|
|
|
- List<FieldValue> fieldValueList = response.getFieldsList();
|
|
|
- for (FieldValue fieldValue : fieldValueList) {
|
|
|
- String fieldName = fieldValue.getFieldName();
|
|
|
- for (int j = 0; j < isValid.size(); j++) {
|
|
|
- if (!isValid.get(j)) continue;
|
|
|
- if (fieldValue.getAttrRecord().getInt32ValueCount() > 0) {
|
|
|
- fieldsMap.get(j)
|
|
|
- .put(fieldName, fieldValue.getAttrRecord().getInt32ValueList().get(j));
|
|
|
- } else if (fieldValue.getAttrRecord().getInt64ValueCount() > 0) {
|
|
|
- fieldsMap.get(j)
|
|
|
- .put(fieldName, fieldValue.getAttrRecord().getInt64ValueList().get(j));
|
|
|
- } else if (fieldValue.getAttrRecord().getDoubleValueCount() > 0) {
|
|
|
- fieldsMap.get(j)
|
|
|
- .put(fieldName, fieldValue.getAttrRecord().getDoubleValueList().get(j));
|
|
|
- } else if (fieldValue.getAttrRecord().getFloatValueCount() > 0) {
|
|
|
- fieldsMap.get(j)
|
|
|
- .put(fieldName, fieldValue.getAttrRecord().getFloatValueList().get(j));
|
|
|
- } else {
|
|
|
- // the object is vector
|
|
|
- List<VectorRowRecord> vectorRowRecordList =
|
|
|
- fieldValue.getVectorRecord().getRecordsList();
|
|
|
- if (vectorRowRecordList.get(j).getFloatDataCount() > 0) {
|
|
|
- fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getFloatDataList());
|
|
|
- } else {
|
|
|
- ByteBuffer bb = vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer();
|
|
|
- byte[] b = new byte[bb.remaining()];
|
|
|
- bb.get(b);
|
|
|
- fieldsMap.get(j).put(fieldName, Arrays.asList(ArrayUtils.toObject(b)));
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return new GetEntityByIDResponse(
|
|
|
- new Response(Response.Status.SUCCESS), fieldsMap);
|
|
|
-
|
|
|
- } else {
|
|
|
- logError(
|
|
|
- "getEntityByID in collection `{}` failed:\n{}",
|
|
|
- collectionName,
|
|
|
- response.getStatus().toString());
|
|
|
- return new GetEntityByIDResponse(
|
|
|
- new Response(
|
|
|
- Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
|
|
|
- response.getStatus().getReason()),
|
|
|
- Collections.emptyList());
|
|
|
+ public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
|
|
|
+ return translateExceptions(() -> {
|
|
|
+ EntityIdentity request = EntityIdentity.newBuilder()
|
|
|
+ .setCollectionName(collectionName)
|
|
|
+ .addAllIdArray(ids)
|
|
|
+ .addAllFieldNames(fieldNames)
|
|
|
+ .build();
|
|
|
+ Entities response = blockingStub().getEntityByID(request);
|
|
|
+ checkResponseStatus(response.getStatus());
|
|
|
+ Map<String, Iterator<?>> fieldIterators = response.getFieldsList()
|
|
|
+ .stream()
|
|
|
+ .collect(Collectors.toMap(FieldValue::getFieldName, this::fieldValueIterator));
|
|
|
+ return response.getValidRowList().stream()
|
|
|
+ .map(valid -> valid ? toMap(fieldIterators) : Collections.<String, Object>emptyMap())
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<String, Object> toMap(Map<String, Iterator<?>> fieldIterators) {
|
|
|
+ return fieldIterators.entrySet().stream()
|
|
|
+ .collect(Collectors.toMap(
|
|
|
+ entry -> entry.getKey(),
|
|
|
+ entry -> entry.getValue().next()));
|
|
|
+ }
|
|
|
+
|
|
|
+ private Iterator<?> fieldValueIterator(FieldValue fieldValue) {
|
|
|
+ if (fieldValue.hasAttrRecord()) {
|
|
|
+ AttrRecord record = fieldValue.getAttrRecord();
|
|
|
+ if (record.getInt32ValueCount() > 0) {
|
|
|
+ return record.getInt32ValueList().iterator();
|
|
|
+ } else if (record.getInt64ValueCount() > 0) {
|
|
|
+ return record.getInt64ValueList().iterator();
|
|
|
+ } else if (record.getFloatValueCount() > 0) {
|
|
|
+ return record.getFloatValueList().iterator();
|
|
|
+ } else if (record.getDoubleValueCount() > 0) {
|
|
|
+ return record.getDoubleValueList().iterator();
|
|
|
}
|
|
|
- } catch (StatusRuntimeException e) {
|
|
|
- logError("getEntityByID RPC failed:\n{}", e.getStatus().toString());
|
|
|
- return new GetEntityByIDResponse(
|
|
|
- new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
|
|
|
}
|
|
|
+ VectorRecord record = fieldValue.getVectorRecord();
|
|
|
+ return record.getRecordsList().stream()
|
|
|
+ .map(row -> row.getFloatDataCount() > 0 ? row.getFloatDataList() : row.getBinaryData().asReadOnlyByteBuffer())
|
|
|
+ .iterator();
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids) {
|
|
|
+ public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids) {
|
|
|
return getEntityByID(collectionName, ids, Collections.emptyList());
|
|
|
}
|
|
|
|