|
@@ -19,13 +19,16 @@
|
|
|
|
|
|
package io.milvus.client;
|
|
package io.milvus.client;
|
|
|
|
|
|
-import com.google.protobuf.ByteString;
|
|
|
|
|
|
+import com.google.common.util.concurrent.FutureCallback;
|
|
|
|
+import com.google.common.util.concurrent.Futures;
|
|
|
|
+import com.google.common.util.concurrent.ListenableFuture;
|
|
|
|
+import com.google.common.util.concurrent.MoreExecutors;
|
|
import io.grpc.StatusRuntimeException;
|
|
import io.grpc.StatusRuntimeException;
|
|
import io.milvus.exception.ClientNotConnectedException;
|
|
import io.milvus.exception.ClientNotConnectedException;
|
|
import io.milvus.exception.IllegalResponseException;
|
|
import io.milvus.exception.IllegalResponseException;
|
|
import io.milvus.exception.ParamException;
|
|
import io.milvus.exception.ParamException;
|
|
import io.milvus.grpc.*;
|
|
import io.milvus.grpc.*;
|
|
-import io.milvus.param.Constant;
|
|
|
|
|
|
+import io.milvus.param.ParamUtils;
|
|
import io.milvus.param.R;
|
|
import io.milvus.param.R;
|
|
import io.milvus.param.RpcStatus;
|
|
import io.milvus.param.RpcStatus;
|
|
import io.milvus.param.alias.AlterAliasParam;
|
|
import io.milvus.param.alias.AlterAliasParam;
|
|
@@ -42,11 +45,10 @@ import org.apache.commons.collections4.MapUtils;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
|
|
-import java.nio.ByteBuffer;
|
|
|
|
-import java.nio.ByteOrder;
|
|
|
|
|
|
+import javax.annotation.Nonnull;
|
|
import java.util.*;
|
|
import java.util.*;
|
|
import java.util.concurrent.TimeUnit;
|
|
import java.util.concurrent.TimeUnit;
|
|
-import java.util.stream.Collectors;
|
|
|
|
|
|
+import java.util.function.Function;
|
|
|
|
|
|
public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
|
|
|
@@ -72,97 +74,6 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
return result;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
|
- private FieldData genFieldData(String fieldName, DataType dataType, List<?> objects) {
|
|
|
|
- if (objects == null) {
|
|
|
|
- throw new ParamException("Cannot generate FieldData from null object");
|
|
|
|
- }
|
|
|
|
- FieldData.Builder builder = FieldData.newBuilder();
|
|
|
|
- if (vectorDataType.contains(dataType)) {
|
|
|
|
- if (dataType == DataType.FloatVector) {
|
|
|
|
- List<Float> floats = new ArrayList<>();
|
|
|
|
- // each object is List<Float>
|
|
|
|
- for (Object object : objects) {
|
|
|
|
- if (object instanceof List) {
|
|
|
|
- List<Float> list = (List<Float>) object;
|
|
|
|
- floats.addAll(list);
|
|
|
|
- } else {
|
|
|
|
- throw new ParamException("The type of FloatVector must be List<Float>");
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- int dim = floats.size() / objects.size();
|
|
|
|
- FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
|
|
- VectorField vectorField = VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(DataType.FloatVector).setVectors(vectorField).build();
|
|
|
|
- } else if (dataType == DataType.BinaryVector) {
|
|
|
|
- ByteBuffer totalBuf = null;
|
|
|
|
- int dim = 0;
|
|
|
|
- // each object is ByteBuffer
|
|
|
|
- for (Object object : objects) {
|
|
|
|
- ByteBuffer buf = (ByteBuffer) object;
|
|
|
|
- if (totalBuf == null){
|
|
|
|
- totalBuf = ByteBuffer.allocate(buf.position() * objects.size());
|
|
|
|
- totalBuf.put(buf.array());
|
|
|
|
- dim = buf.position() * 8;
|
|
|
|
- } else {
|
|
|
|
- totalBuf.put(buf.array());
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- assert totalBuf != null;
|
|
|
|
- ByteString byteString = ByteString.copyFrom(totalBuf.array());
|
|
|
|
- VectorField vectorField = VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(DataType.BinaryVector).setVectors(vectorField).build();
|
|
|
|
- }
|
|
|
|
- } else {
|
|
|
|
- switch (dataType) {
|
|
|
|
- case None:
|
|
|
|
- case UNRECOGNIZED:
|
|
|
|
- throw new ParamException("Cannot support this dataType:" + dataType);
|
|
|
|
- case Int64:
|
|
|
|
- List<Long> longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
|
|
|
|
- LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
|
|
|
|
- ScalarField scalarField1 = ScalarField.newBuilder().setLongData(longArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField1).build();
|
|
|
|
- case Int32:
|
|
|
|
- case Int16:
|
|
|
|
- case Int8:
|
|
|
|
- List<Integer> integers = objects.stream().map(p -> p instanceof Short ? ((Short)p).intValue() :(Integer) p).collect(Collectors.toList());
|
|
|
|
- IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
|
|
|
|
- ScalarField scalarField2 = ScalarField.newBuilder().setIntData(intArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField2).build();
|
|
|
|
- case Bool:
|
|
|
|
- List<Boolean> booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
|
|
|
|
- BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
|
|
|
|
- ScalarField scalarField3 = ScalarField.newBuilder().setBoolData(boolArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField3).build();
|
|
|
|
- case Float:
|
|
|
|
- List<Float> floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
|
|
|
|
- FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
|
|
- ScalarField scalarField4 = ScalarField.newBuilder().setFloatData(floatArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField4).build();
|
|
|
|
- case Double:
|
|
|
|
- List<Double> doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
|
|
|
|
- DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
|
|
|
|
- ScalarField scalarField5 = ScalarField.newBuilder().setDoubleData(doubleArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField5).build();
|
|
|
|
- case String:
|
|
|
|
- List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
|
|
|
|
- StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
|
|
|
|
- ScalarField scalarField6 = ScalarField.newBuilder().setStringData(stringArray).build();
|
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField6).build();
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return null;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private static final Set<DataType> vectorDataType = new HashSet<DataType>() {{
|
|
|
|
- add(DataType.FloatVector);
|
|
|
|
- add(DataType.BinaryVector);
|
|
|
|
- }};
|
|
|
|
-
|
|
|
|
private void waitForLoadingCollection(String collectionName, List<String> partitionNames,
|
|
private void waitForLoadingCollection(String collectionName, List<String> partitionNames,
|
|
long waitingInterval, long timeout) throws IllegalResponseException {
|
|
long waitingInterval, long timeout) throws IllegalResponseException {
|
|
long tsBegin = System.currentTimeMillis();
|
|
long tsBegin = System.currentTimeMillis();
|
|
@@ -178,7 +89,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
// If waiting time exceed timeout, exist the circle
|
|
// If waiting time exceed timeout, exist the circle
|
|
while (true) {
|
|
while (true) {
|
|
long tsNow = System.currentTimeMillis();
|
|
long tsNow = System.currentTimeMillis();
|
|
- if ((tsNow - tsBegin) >= timeout*1000) {
|
|
|
|
|
|
+ if ((tsNow - tsBegin) >= timeout * 1000) {
|
|
logWarning("Waiting load thread is timeout, loading process may not be finished");
|
|
logWarning("Waiting load thread is timeout, loading process may not be finished");
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
@@ -222,9 +133,9 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
// If each partition's inMemory percentage is 100, that means all the partitions have finished loading.
|
|
// If each partition's inMemory percentage is 100, that means all the partitions have finished loading.
|
|
// Otherwise, this thread will sleep a small interval and check again.
|
|
// Otherwise, this thread will sleep a small interval and check again.
|
|
// If waiting time exceed timeout, exist the circle
|
|
// If waiting time exceed timeout, exist the circle
|
|
- while(true) {
|
|
|
|
|
|
+ while (true) {
|
|
long tsNow = System.currentTimeMillis();
|
|
long tsNow = System.currentTimeMillis();
|
|
- if ((tsNow - tsBegin) >= timeout*1000) {
|
|
|
|
|
|
+ if ((tsNow - tsBegin) >= timeout * 1000) {
|
|
logWarning("Waiting load thread is timeout, loading process may not be finished");
|
|
logWarning("Waiting load thread is timeout, loading process may not be finished");
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
@@ -292,7 +203,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
collectionSegIDs.forEach((collectionName, segmentIDs) -> {
|
|
collectionSegIDs.forEach((collectionName, segmentIDs) -> {
|
|
while (segmentIDs.getDataCount() > 0) {
|
|
while (segmentIDs.getDataCount() > 0) {
|
|
long tsNow = System.currentTimeMillis();
|
|
long tsNow = System.currentTimeMillis();
|
|
- if ((tsNow - tsBegin) >= timeout*1000) {
|
|
|
|
|
|
+ if ((tsNow - tsBegin) >= timeout * 1000) {
|
|
logWarning("Waiting flush thread is timeout, flush process may not be finished");
|
|
logWarning("Waiting flush thread is timeout, flush process may not be finished");
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
@@ -301,7 +212,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
.addAllSegmentIDs(segmentIDs.getDataList())
|
|
.addAllSegmentIDs(segmentIDs.getDataList())
|
|
.build();
|
|
.build();
|
|
GetFlushStateResponse response = blockingStub().getFlushState(getFlushStateRequest);
|
|
GetFlushStateResponse response = blockingStub().getFlushState(getFlushStateRequest);
|
|
- if(response.getFlushed()) {
|
|
|
|
|
|
+ if (response.getFlushed()) {
|
|
// if all segment of this collection has been flushed, break this circle and check next collection
|
|
// if all segment of this collection has been flushed, break this circle and check next collection
|
|
String msg = segmentIDs.getDataCount() + " segments of " + collectionName + " has been flushed.";
|
|
String msg = segmentIDs.getDataCount() + " segments of " + collectionName + " has been flushed.";
|
|
logInfo(msg);
|
|
logInfo(msg);
|
|
@@ -327,7 +238,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
long tsBegin = System.currentTimeMillis();
|
|
long tsBegin = System.currentTimeMillis();
|
|
while (true) {
|
|
while (true) {
|
|
long tsNow = System.currentTimeMillis();
|
|
long tsNow = System.currentTimeMillis();
|
|
- if ((tsNow - tsBegin) >= timeout*1000) {
|
|
|
|
|
|
+ if ((tsNow - tsBegin) >= timeout * 1000) {
|
|
String msg = "Waiting index thread is timeout, index process may not be finished";
|
|
String msg = "Waiting index thread is timeout, index process may not be finished";
|
|
logWarning(msg);
|
|
logWarning(msg);
|
|
return R.failed(R.Status.Success, msg);
|
|
return R.failed(R.Status.Success, msg);
|
|
@@ -680,11 +591,11 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
- * Currently we do not support this method on client since compaction is not supported on server.
|
|
|
|
- * Now it is only for internal use of getCollectionStatistics().
|
|
|
|
|
|
+ * Flush insert buffer into storage. To make sure the buffer persisted successfully, it calls
|
|
|
|
+ * GetFlushState() to check related segments state.
|
|
*/
|
|
*/
|
|
-// @Override
|
|
|
|
- private R<FlushResponse> flush(@NonNull FlushParam requestParam) {
|
|
|
|
|
|
+ @Override
|
|
|
|
+ public R<FlushResponse> flush(@NonNull FlushParam requestParam) {
|
|
if (!clientIsReady()) {
|
|
if (!clientIsReady()) {
|
|
return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
|
|
return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
|
|
}
|
|
}
|
|
@@ -1320,26 +1231,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
logInfo(requestParam.toString());
|
|
logInfo(requestParam.toString());
|
|
|
|
|
|
try {
|
|
try {
|
|
- String collectionName = requestParam.getCollectionName();
|
|
|
|
- String partitionName = requestParam.getPartitionName();
|
|
|
|
- List<InsertParam.Field> fields = requestParam.getFields();
|
|
|
|
-
|
|
|
|
- //1. gen insert request
|
|
|
|
- MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
|
|
|
|
- InsertRequest.Builder insertBuilder = InsertRequest.newBuilder()
|
|
|
|
- .setCollectionName(collectionName)
|
|
|
|
- .setPartitionName(partitionName)
|
|
|
|
- .setBase(msgBase)
|
|
|
|
- .setNumRows(requestParam.getRowCount());
|
|
|
|
-
|
|
|
|
- //2. gen fieldData
|
|
|
|
- // TODO: check field type(use DescribeCollection get schema to compare)
|
|
|
|
- for (InsertParam.Field field : fields) {
|
|
|
|
- insertBuilder.addFieldsData(genFieldData(field.getName(), field.getType(), field.getValues()));
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- //3. call insert
|
|
|
|
- InsertRequest insertRequest = insertBuilder.build();
|
|
|
|
|
|
+ InsertRequest insertRequest = ParamUtils.ConvertInsertParam(requestParam);
|
|
MutationResult response = blockingStub().insert(insertRequest);
|
|
MutationResult response = blockingStub().insert(insertRequest);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
if (response.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
@@ -1364,108 +1256,62 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
}
|
|
}
|
|
|
|
|
|
@Override
|
|
@Override
|
|
- @SuppressWarnings("unchecked")
|
|
|
|
- public R<SearchResults> search(@NonNull SearchParam requestParam) {
|
|
|
|
|
|
+ @SuppressWarnings("UnstableApiUsage")
|
|
|
|
+ public ListenableFuture<R<MutationResult>> insertAsync(InsertParam requestParam) {
|
|
if (!clientIsReady()) {
|
|
if (!clientIsReady()) {
|
|
- return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
|
|
|
|
|
|
+ return Futures.immediateFuture(
|
|
|
|
+ R.failed(new ClientNotConnectedException("Client rpc channel is not ready")));
|
|
}
|
|
}
|
|
|
|
|
|
logInfo(requestParam.toString());
|
|
logInfo(requestParam.toString());
|
|
|
|
|
|
- try {
|
|
|
|
- SearchRequest.Builder builder = SearchRequest.newBuilder()
|
|
|
|
- .setDbName("")
|
|
|
|
- .setCollectionName(requestParam.getCollectionName());
|
|
|
|
- if (!requestParam.getPartitionNames().isEmpty()) {
|
|
|
|
- requestParam.getPartitionNames().forEach(builder::addPartitionNames);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- // prepare target vectors
|
|
|
|
- // TODO: check target vector dimension(use DescribeCollection get schema to compare)
|
|
|
|
- PlaceholderType plType = PlaceholderType.None;
|
|
|
|
- List<?> vectors = requestParam.getVectors();
|
|
|
|
- List<ByteString> byteStrings = new ArrayList<>();
|
|
|
|
- for (Object vector : vectors) {
|
|
|
|
- if (vector instanceof List) {
|
|
|
|
- plType = PlaceholderType.FloatVector;
|
|
|
|
- List<Float> list = (List<Float>) vector;
|
|
|
|
- ByteBuffer buf = ByteBuffer.allocate(Float.BYTES * list.size());
|
|
|
|
- buf.order(ByteOrder.LITTLE_ENDIAN);
|
|
|
|
- list.forEach(buf::putFloat);
|
|
|
|
-
|
|
|
|
- byte[] array = buf.array();
|
|
|
|
- ByteString bs = ByteString.copyFrom(array);
|
|
|
|
- byteStrings.add(bs);
|
|
|
|
- } else if (vector instanceof ByteBuffer) {
|
|
|
|
- plType = PlaceholderType.BinaryVector;
|
|
|
|
- ByteBuffer buf = (ByteBuffer) vector;
|
|
|
|
- byte[] array = buf.array();
|
|
|
|
- ByteString bs = ByteString.copyFrom(array);
|
|
|
|
- byteStrings.add(bs);
|
|
|
|
- } else {
|
|
|
|
- String msg = "Search target vector type is illegal(Only allow List<Float> or ByteBuffer)";
|
|
|
|
- logError(msg);
|
|
|
|
- return R.failed(R.Status.UnexpectedError, msg);
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- PlaceholderValue.Builder pldBuilder = PlaceholderValue.newBuilder()
|
|
|
|
- .setTag(Constant.VECTOR_TAG)
|
|
|
|
- .setType(plType);
|
|
|
|
- byteStrings.forEach(pldBuilder::addValues);
|
|
|
|
-
|
|
|
|
- PlaceholderValue plv = pldBuilder.build();
|
|
|
|
- PlaceholderGroup placeholderGroup = PlaceholderGroup.newBuilder()
|
|
|
|
- .addPlaceholders(plv)
|
|
|
|
- .build();
|
|
|
|
|
|
+ InsertRequest insertRequest = ParamUtils.ConvertInsertParam(requestParam);
|
|
|
|
+ ListenableFuture<MutationResult> response = futureStub().insert(insertRequest);
|
|
|
|
+
|
|
|
|
+ Futures.addCallback(
|
|
|
|
+ response,
|
|
|
|
+ new FutureCallback<MutationResult>() {
|
|
|
|
+ @Override
|
|
|
|
+ public void onSuccess(MutationResult result) {
|
|
|
|
+ if (result.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
|
|
+ logInfo("insertAsync successfully! Collection name:{}",
|
|
|
|
+ requestParam.getCollectionName());
|
|
|
|
+ } else {
|
|
|
|
+ logError("insertAsync failed! Collection name:{}\n{}",
|
|
|
|
+ requestParam.getCollectionName(), result.getStatus().getReason());
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
|
|
- ByteString byteStr = placeholderGroup.toByteString();
|
|
|
|
- builder.setPlaceholderGroup(byteStr);
|
|
|
|
|
|
+ @Override
|
|
|
|
+ public void onFailure(@Nonnull Throwable t) {
|
|
|
|
+ logError("insertAsync failed:\n{}", t.getMessage());
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ MoreExecutors.directExecutor());
|
|
|
|
|
|
- // search parameters
|
|
|
|
- builder.addSearchParams(
|
|
|
|
- KeyValuePair.newBuilder()
|
|
|
|
- .setKey(Constant.VECTOR_FIELD)
|
|
|
|
- .setValue(requestParam.getVectorFieldName())
|
|
|
|
- .build())
|
|
|
|
- .addSearchParams(
|
|
|
|
- KeyValuePair.newBuilder()
|
|
|
|
- .setKey(Constant.TOP_K)
|
|
|
|
- .setValue(String.valueOf(requestParam.getTopK()))
|
|
|
|
- .build())
|
|
|
|
- .addSearchParams(
|
|
|
|
- KeyValuePair.newBuilder()
|
|
|
|
- .setKey(Constant.METRIC_TYPE)
|
|
|
|
- .setValue(requestParam.getMetricType())
|
|
|
|
- .build())
|
|
|
|
- .addSearchParams(
|
|
|
|
- KeyValuePair.newBuilder()
|
|
|
|
- .setKey(Constant.ROUND_DECIMAL)
|
|
|
|
- .setValue(String.valueOf(requestParam.getRoundDecimal()))
|
|
|
|
- .build());
|
|
|
|
-
|
|
|
|
- if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) {
|
|
|
|
- builder.addSearchParams(
|
|
|
|
- KeyValuePair.newBuilder()
|
|
|
|
- .setKey(Constant.PARAMS)
|
|
|
|
- .setValue(requestParam.getParams())
|
|
|
|
- .build());
|
|
|
|
- }
|
|
|
|
|
|
+ Function<MutationResult, R<MutationResult>> transformFunc =
|
|
|
|
+ results -> {
|
|
|
|
+ if (results.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
|
|
+ return R.success(results);
|
|
|
|
+ } else {
|
|
|
|
+ return R.failed(R.Status.valueOf(results.getStatus().getErrorCode().getNumber()),
|
|
|
|
+ results.getStatus().getReason());
|
|
|
|
+ }
|
|
|
|
+ };
|
|
|
|
|
|
- if (!requestParam.getOutFields().isEmpty()) {
|
|
|
|
- requestParam.getOutFields().forEach(builder::addOutputFields);
|
|
|
|
- }
|
|
|
|
|
|
+ return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
|
|
|
|
+ }
|
|
|
|
|
|
- // always use expression since dsl is discarded
|
|
|
|
- builder.setDslType(DslType.BoolExprV1);
|
|
|
|
- if (requestParam.getExpr() != null && !requestParam.getExpr().isEmpty()) {
|
|
|
|
- builder.setDsl(requestParam.getExpr());
|
|
|
|
- }
|
|
|
|
|
|
+ @Override
|
|
|
|
+ public R<SearchResults> search(@NonNull SearchParam requestParam) {
|
|
|
|
+ if (!clientIsReady()) {
|
|
|
|
+ return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
|
|
|
|
+ }
|
|
|
|
|
|
- builder.setTravelTimestamp(requestParam.getTravelTimestamp());
|
|
|
|
- builder.setGuaranteeTimestamp(requestParam.getGuaranteeTimestamp());
|
|
|
|
|
|
+ logInfo(requestParam.toString());
|
|
|
|
|
|
- SearchRequest searchRequest = builder.build();
|
|
|
|
|
|
+ try {
|
|
|
|
+ SearchRequest searchRequest = ParamUtils.ConvertSearchParam(requestParam);
|
|
SearchResults response = this.blockingStub().search(searchRequest);
|
|
SearchResults response = this.blockingStub().search(searchRequest);
|
|
|
|
|
|
//TODO: truncate distance value by round decimal
|
|
//TODO: truncate distance value by round decimal
|
|
@@ -1481,12 +1327,59 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
} catch (StatusRuntimeException e) {
|
|
} catch (StatusRuntimeException e) {
|
|
logError("SearchRequest RPC failed:{}", e.getMessage());
|
|
logError("SearchRequest RPC failed:{}", e.getMessage());
|
|
return R.failed(e);
|
|
return R.failed(e);
|
|
- } catch (Exception e) {
|
|
|
|
|
|
+ } catch (ParamException e) {
|
|
logError("SearchRequest failed:\n{}", e.getMessage());
|
|
logError("SearchRequest failed:\n{}", e.getMessage());
|
|
return R.failed(e);
|
|
return R.failed(e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Override
|
|
|
|
+ @SuppressWarnings("UnstableApiUsage")
|
|
|
|
+ public ListenableFuture<R<SearchResults>> searchAsync(SearchParam requestParam) {
|
|
|
|
+ if (!clientIsReady()) {
|
|
|
|
+ return Futures.immediateFuture(
|
|
|
|
+ R.failed(new ClientNotConnectedException("Client rpc channel is not ready")));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ logInfo(requestParam.toString());
|
|
|
|
+
|
|
|
|
+ SearchRequest searchRequest = ParamUtils.ConvertSearchParam(requestParam);
|
|
|
|
+ ListenableFuture<SearchResults> response = this.futureStub().search(searchRequest);
|
|
|
|
+
|
|
|
|
+ Futures.addCallback(
|
|
|
|
+ response,
|
|
|
|
+ new FutureCallback<SearchResults>() {
|
|
|
|
+ @Override
|
|
|
|
+ public void onSuccess(SearchResults result) {
|
|
|
|
+ if (result.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
|
|
+ logInfo("searchAsync successfully! Collection name:{}",
|
|
|
|
+ requestParam.getCollectionName());
|
|
|
|
+ } else {
|
|
|
|
+ logError("searchAsync failed! Collection name:{}\n{}",
|
|
|
|
+ requestParam.getCollectionName(), result.getStatus().getReason());
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void onFailure(@Nonnull Throwable t) {
|
|
|
|
+ logError("searchAsync failed:\n{}", t.getMessage());
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ MoreExecutors.directExecutor());
|
|
|
|
+
|
|
|
|
+ Function<SearchResults, R<SearchResults>> transformFunc =
|
|
|
|
+ results -> {
|
|
|
|
+ if (results.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
|
|
+ return R.success(results);
|
|
|
|
+ } else {
|
|
|
|
+ return R.failed(R.Status.valueOf(results.getStatus().getErrorCode().getNumber()),
|
|
|
|
+ results.getStatus().getReason());
|
|
|
|
+ }
|
|
|
|
+ };
|
|
|
|
+
|
|
|
|
+ return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
public R<QueryResults> query(@NonNull QueryParam requestParam) {
|
|
public R<QueryResults> query(@NonNull QueryParam requestParam) {
|
|
if (!clientIsReady()) {
|
|
if (!clientIsReady()) {
|
|
@@ -1496,16 +1389,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
logInfo(requestParam.toString());
|
|
logInfo(requestParam.toString());
|
|
|
|
|
|
try {
|
|
try {
|
|
- QueryRequest queryRequest = QueryRequest.newBuilder()
|
|
|
|
- .setDbName("")
|
|
|
|
- .setCollectionName(requestParam.getCollectionName())
|
|
|
|
- .addAllPartitionNames(requestParam.getPartitionNames())
|
|
|
|
- .addAllOutputFields(requestParam.getOutFields())
|
|
|
|
- .setExpr(requestParam.getExpr())
|
|
|
|
- .setTravelTimestamp(requestParam.getTravelTimestamp())
|
|
|
|
- .setGuaranteeTimestamp(requestParam.getGuaranteeTimestamp())
|
|
|
|
- .build();
|
|
|
|
-
|
|
|
|
|
|
+ QueryRequest queryRequest = ParamUtils.ConvertQueryParam(requestParam);
|
|
QueryResults response = this.blockingStub().query(queryRequest);
|
|
QueryResults response = this.blockingStub().query(queryRequest);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
if (response.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
@@ -1526,6 +1410,53 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Override
|
|
|
|
+ @SuppressWarnings("UnstableApiUsage")
|
|
|
|
+ public ListenableFuture<R<QueryResults>> queryAsync(QueryParam requestParam) {
|
|
|
|
+ if (!clientIsReady()) {
|
|
|
|
+ return Futures.immediateFuture(
|
|
|
|
+ R.failed(new ClientNotConnectedException("Client rpc channel is not ready")));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ logInfo(requestParam.toString());
|
|
|
|
+
|
|
|
|
+ QueryRequest queryRequest = ParamUtils.ConvertQueryParam(requestParam);
|
|
|
|
+ ListenableFuture<QueryResults> response = this.futureStub().query(queryRequest);
|
|
|
|
+
|
|
|
|
+ Futures.addCallback(
|
|
|
|
+ response,
|
|
|
|
+ new FutureCallback<QueryResults>() {
|
|
|
|
+ @Override
|
|
|
|
+ public void onSuccess(QueryResults result) {
|
|
|
|
+ if (result.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
|
|
+ logInfo("queryAsync successfully! Collection name:{}",
|
|
|
|
+ requestParam.getCollectionName());
|
|
|
|
+ } else {
|
|
|
|
+ logError("queryAsync failed! Collection name:{}\n{}",
|
|
|
|
+ requestParam.getCollectionName(), result.getStatus().getReason());
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void onFailure(@Nonnull Throwable t) {
|
|
|
|
+ logError("queryAsync failed:\n{}", t.getMessage());
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ MoreExecutors.directExecutor());
|
|
|
|
+
|
|
|
|
+ Function<QueryResults, R<QueryResults>> transformFunc =
|
|
|
|
+ results -> {
|
|
|
|
+ if (results.getStatus().getErrorCode() == ErrorCode.Success) {
|
|
|
|
+ return R.success(results);
|
|
|
|
+ } else {
|
|
|
|
+ return R.failed(R.Status.valueOf(results.getStatus().getErrorCode().getNumber()),
|
|
|
|
+ results.getStatus().getReason());
|
|
|
|
+ }
|
|
|
|
+ };
|
|
|
|
+
|
|
|
|
+ return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
public R<CalcDistanceResults> calcDistance(@NonNull CalcDistanceParam requestParam) {
|
|
public R<CalcDistanceResults> calcDistance(@NonNull CalcDistanceParam requestParam) {
|
|
if (!clientIsReady()) {
|
|
if (!clientIsReady()) {
|