|
@@ -32,7 +32,10 @@ import io.grpc.ManagedChannel;
|
|
import io.grpc.ManagedChannelBuilder;
|
|
import io.grpc.ManagedChannelBuilder;
|
|
import io.grpc.MethodDescriptor;
|
|
import io.grpc.MethodDescriptor;
|
|
import io.grpc.StatusRuntimeException;
|
|
import io.grpc.StatusRuntimeException;
|
|
|
|
+import io.milvus.client.exception.ClientSideMilvusException;
|
|
import io.milvus.client.exception.InitializationException;
|
|
import io.milvus.client.exception.InitializationException;
|
|
|
|
+import io.milvus.client.exception.MilvusException;
|
|
|
|
+import io.milvus.client.exception.ServerSideMilvusException;
|
|
import io.milvus.client.exception.UnsupportedServerVersion;
|
|
import io.milvus.client.exception.UnsupportedServerVersion;
|
|
import io.milvus.grpc.*;
|
|
import io.milvus.grpc.*;
|
|
import org.apache.commons.lang3.ArrayUtils;
|
|
import org.apache.commons.lang3.ArrayUtils;
|
|
@@ -53,6 +56,7 @@ import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
import java.util.concurrent.TimeUnit;
|
|
import java.util.concurrent.TimeUnit;
|
|
import java.util.function.Function;
|
|
import java.util.function.Function;
|
|
|
|
+import java.util.function.Supplier;
|
|
|
|
|
|
/** Actual implementation of interface <code>MilvusClient</code> */
|
|
/** Actual implementation of interface <code>MilvusClient</code> */
|
|
public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
@@ -60,16 +64,16 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
|
|
private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
|
|
private static final String SUPPORTED_SERVER_VERSION = "0.11";
|
|
private static final String SUPPORTED_SERVER_VERSION = "0.11";
|
|
|
|
|
|
|
|
+ private final String target;
|
|
private final ManagedChannel channel;
|
|
private final ManagedChannel channel;
|
|
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
|
|
private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
|
|
private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
|
|
private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
|
|
|
|
|
|
public MilvusGrpcClient(ConnectParam connectParam) {
|
|
public MilvusGrpcClient(ConnectParam connectParam) {
|
|
- ManagedChannelBuilder builder = connectParam.getTarget() != null
|
|
|
|
- ? ManagedChannelBuilder.forTarget(connectParam.getTarget())
|
|
|
|
- : ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort());
|
|
|
|
-
|
|
|
|
- channel = builder.usePlaintext()
|
|
|
|
|
|
+ target = connectParam.getTarget();
|
|
|
|
+ channel = ManagedChannelBuilder
|
|
|
|
+ .forTarget(connectParam.getTarget())
|
|
|
|
+ .usePlaintext()
|
|
.maxInboundMessageSize(Integer.MAX_VALUE)
|
|
.maxInboundMessageSize(Integer.MAX_VALUE)
|
|
.defaultLoadBalancingPolicy(connectParam.getDefaultLoadBalancingPolicy())
|
|
.defaultLoadBalancingPolicy(connectParam.getDefaultLoadBalancingPolicy())
|
|
.keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
|
|
.keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
|
|
@@ -84,10 +88,10 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
if (response.ok()) {
|
|
if (response.ok()) {
|
|
String serverVersion = response.getMessage();
|
|
String serverVersion = response.getMessage();
|
|
if (!serverVersion.matches("^" + SUPPORTED_SERVER_VERSION + "(\\..*)?$")) {
|
|
if (!serverVersion.matches("^" + SUPPORTED_SERVER_VERSION + "(\\..*)?$")) {
|
|
- throw new UnsupportedServerVersion(connectParam.getHost(), SUPPORTED_SERVER_VERSION, serverVersion);
|
|
|
|
|
|
+ throw new UnsupportedServerVersion(connectParam.getTarget(), SUPPORTED_SERVER_VERSION, serverVersion);
|
|
}
|
|
}
|
|
} else {
|
|
} else {
|
|
- throw new InitializationException(connectParam.getHost(), response.getMessage());
|
|
|
|
|
|
+ throw new InitializationException(connectParam.getTarget(), response.getMessage());
|
|
}
|
|
}
|
|
} catch (Throwable t) {
|
|
} catch (Throwable t) {
|
|
channel.shutdownNow();
|
|
channel.shutdownNow();
|
|
@@ -95,6 +99,11 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Override
|
|
|
|
+ public String target() {
|
|
|
|
+ return target;
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
|
|
protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
|
|
return blockingStub;
|
|
return blockingStub;
|
|
@@ -138,6 +147,10 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
this.futureStub.withInterceptors(timeoutInterceptor);
|
|
this.futureStub.withInterceptors(timeoutInterceptor);
|
|
|
|
|
|
return new AbstractMilvusGrpcClient() {
|
|
return new AbstractMilvusGrpcClient() {
|
|
|
|
+ @Override
|
|
|
|
+ public String target() {
|
|
|
|
+ return MilvusGrpcClient.this.target();
|
|
|
|
+ }
|
|
|
|
|
|
@Override
|
|
@Override
|
|
protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
|
|
protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
|
|
@@ -184,81 +197,55 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
|
|
abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
private static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
|
|
private static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
|
|
|
|
|
|
- private final String extraParamKey = "params";
|
|
|
|
-
|
|
|
|
protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();
|
|
protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();
|
|
|
|
|
|
protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();
|
|
protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();
|
|
|
|
|
|
protected abstract boolean maybeAvailable();
|
|
protected abstract boolean maybeAvailable();
|
|
|
|
|
|
- @Override
|
|
|
|
- public Response createCollection(@Nonnull CollectionMapping collectionMapping) {
|
|
|
|
-
|
|
|
|
- if (!maybeAvailable()) {
|
|
|
|
- logWarning("You are not connected to Milvus server");
|
|
|
|
- return new Response(Response.Status.CLIENT_NOT_CONNECTED);
|
|
|
|
- }
|
|
|
|
|
|
+ private void translateExceptions(Runnable body) {
|
|
|
|
+ translateExceptions(() -> {
|
|
|
|
+ body.run();
|
|
|
|
+ return null;
|
|
|
|
+ });
|
|
|
|
+ }
|
|
|
|
|
|
- List<FieldParam> fields = new ArrayList<>();
|
|
|
|
- if (collectionMapping.getFields().size() == 0) {
|
|
|
|
- logError("Param fields must not be empty.");
|
|
|
|
- return new Response(Response.Status.ILLEGAL_ARGUMENT);
|
|
|
|
- }
|
|
|
|
- for (Map<String, Object> map : collectionMapping.getFields()) {
|
|
|
|
- if (!map.containsKey("field") || !(map.get("field") instanceof String)) {
|
|
|
|
- logError("Param fields must contain key 'field' of String.");
|
|
|
|
- return new Response(Response.Status.ILLEGAL_ARGUMENT);
|
|
|
|
- }
|
|
|
|
- if (!map.containsKey("type") || !(map.get("type") instanceof DataType)) {
|
|
|
|
- logError("Param fields must contain key 'type' of DataType.");
|
|
|
|
- return new Response(Response.Status.ILLEGAL_ARGUMENT);
|
|
|
|
- }
|
|
|
|
- io.milvus.grpc.FieldParam.Builder fieldParamBuilder = FieldParam.newBuilder()
|
|
|
|
- .setName(map.get("field").toString())
|
|
|
|
- .setTypeValue(((DataType) map.get("type")).getVal());
|
|
|
|
- if (map.containsKey(extraParamKey)) {
|
|
|
|
- KeyValuePair extraFieldParam = KeyValuePair.newBuilder()
|
|
|
|
- .setKey(extraParamKey)
|
|
|
|
- .setValue(map.get(extraParamKey).toString())
|
|
|
|
- .build();
|
|
|
|
- fieldParamBuilder.addExtraParams(extraFieldParam);
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
|
+ private <T> T translateExceptions(Supplier<T> body) {
|
|
|
|
+ try {
|
|
|
|
+ T result = body.get();
|
|
|
|
+ if (result instanceof ListenableFuture) {
|
|
|
|
+ ListenableFuture futureResult = (ListenableFuture) result;
|
|
|
|
+ result = (T) Futures.catching(
|
|
|
|
+ futureResult, Throwable.class, this::translate, MoreExecutors.directExecutor());
|
|
}
|
|
}
|
|
- fields.add(fieldParamBuilder.build());
|
|
|
|
|
|
+ return result;
|
|
|
|
+ } catch (Throwable e) {
|
|
|
|
+ return translate(e);
|
|
}
|
|
}
|
|
|
|
+ }
|
|
|
|
|
|
- Mapping request =
|
|
|
|
- Mapping.newBuilder()
|
|
|
|
- .setCollectionName(collectionMapping.getCollectionName())
|
|
|
|
- .addAllFields(fields)
|
|
|
|
- .addExtraParams(KeyValuePair.newBuilder()
|
|
|
|
- .setKey(extraParamKey)
|
|
|
|
- .setValue(collectionMapping.getParamsInJson())
|
|
|
|
- .build())
|
|
|
|
- .build();
|
|
|
|
-
|
|
|
|
- Status response;
|
|
|
|
-
|
|
|
|
- try {
|
|
|
|
- response = blockingStub().createCollection(request);
|
|
|
|
|
|
+ private <R> R translate(Throwable e) {
|
|
|
|
+ if (e instanceof MilvusException) {
|
|
|
|
+ throw (MilvusException) e;
|
|
|
|
+ } else {
|
|
|
|
+ throw new ClientSideMilvusException(target(), e);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
|
|
- if (response.getErrorCode() == ErrorCode.SUCCESS) {
|
|
|
|
- logInfo("Created collection successfully!\n{}", collectionMapping.toString());
|
|
|
|
- return new Response(Response.Status.SUCCESS);
|
|
|
|
- } else if (response.getReason().contentEquals("Collection already exists")) {
|
|
|
|
- logWarning("Collection `{}` already exists", collectionMapping.getCollectionName());
|
|
|
|
- return new Response(
|
|
|
|
- Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
|
|
|
|
- } else {
|
|
|
|
- logError(
|
|
|
|
- "Create collection failed\n{}\n{}", collectionMapping.toString(), response.toString());
|
|
|
|
- return new Response(
|
|
|
|
- Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
|
|
|
|
- }
|
|
|
|
- } catch (StatusRuntimeException e) {
|
|
|
|
- logError("createCollection RPC failed:\n{}", e.getStatus().toString());
|
|
|
|
- return new Response(Response.Status.RPC_ERROR, e.toString());
|
|
|
|
|
|
+ private Void checkResponseStatus(Status status) {
|
|
|
|
+ if (status.getErrorCode() != ErrorCode.SUCCESS) {
|
|
|
|
+ throw new ServerSideMilvusException(target(), status);
|
|
}
|
|
}
|
|
|
|
+ return null;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void createCollection(@Nonnull CollectionMapping collectionMapping) {
|
|
|
|
+ translateExceptions(() -> {
|
|
|
|
+ Status response = blockingStub().createCollection(collectionMapping.grpc());
|
|
|
|
+ checkResponseStatus(response);
|
|
|
|
+ });
|
|
}
|
|
}
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -1130,29 +1117,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
response = blockingStub().describeCollection(request);
|
|
response = blockingStub().describeCollection(request);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
|
|
if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
|
|
- String extraParam = "";
|
|
|
|
- for (KeyValuePair kv : response.getExtraParamsList()) {
|
|
|
|
- if (kv.getKey().contentEquals(extraParamKey)) {
|
|
|
|
- extraParam = kv.getValue();
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- // convert fields to list of hashmap
|
|
|
|
- List<FieldParam> fields = response.getFieldsList();
|
|
|
|
- List<Map<String, Object>> fieldsCollection = new ArrayList<>(fields.size());
|
|
|
|
- for (FieldParam fieldParam : fields) {
|
|
|
|
- Map<String, Object> map = new HashMap<>();
|
|
|
|
- // copy from fieldParam to map
|
|
|
|
- map.put("field", fieldParam.getName());
|
|
|
|
- map.put("type", fieldParam.getType());
|
|
|
|
- map.put("indexParams", kvListToString(fieldParam.getIndexParamsList()));
|
|
|
|
- map.put("params", kvListToString(fieldParam.getExtraParamsList()));
|
|
|
|
- fieldsCollection.add(map);
|
|
|
|
- }
|
|
|
|
- CollectionMapping collectionMapping =
|
|
|
|
- new CollectionMapping.Builder(response.getCollectionName())
|
|
|
|
- .withFields(fieldsCollection)
|
|
|
|
- .withParamsInJson(extraParam)
|
|
|
|
- .build();
|
|
|
|
|
|
+ CollectionMapping collectionMapping = new CollectionMapping(response);
|
|
logInfo("Get Collection Info `{}` returned:\n{}", collectionName, collectionMapping);
|
|
logInfo("Get Collection Info `{}` returned:\n{}", collectionName, collectionMapping);
|
|
return new GetCollectionInfoResponse(
|
|
return new GetCollectionInfoResponse(
|
|
new Response(Response.Status.SUCCESS), collectionMapping);
|
|
new Response(Response.Status.SUCCESS), collectionMapping);
|