/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.milvus.client;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import io.milvus.grpc.*;
import javax.annotation.Nonnull;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * gRPC-based implementation of the {@link MilvusClient} interface.
 */
public class MilvusGrpcClient implements MilvusClient {
/** SLF4J logger bound to this class. */
private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
// ANSI terminal escape sequences (reset plus three foreground colors).
// NOTE(review): presumably consumed by the log helpers defined elsewhere in this class -- confirm.
private static final String ANSI_RESET = "\u001B[0m";
private static final String ANSI_YELLOW = "\u001B[33m";
private static final String ANSI_PURPLE = "\u001B[35m";
private static final String ANSI_BRIGHT_PURPLE = "\u001B[95m";
/** Key of the KeyValuePair that carries JSON-encoded extra parameters in index/search requests. */
private final String extraParamKey = "params";
/** gRPC channel to the server; stays null until connect() builds one. */
private ManagedChannel channel = null;
/** Stub used for synchronous calls; assigned by connect() after the channel becomes READY. */
private MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub = null;
/** Stub used for asynchronous calls; assigned by connect() after the channel becomes READY. */
private MilvusServiceGrpc.MilvusServiceFutureStub futureStub = null;
////////////////////// Constructor //////////////////////
/**
 * Creates a client with no channel and no stubs; {@link #connect(ConnectParam)} sets them up.
 */
public MilvusGrpcClient() {
  // Intentionally empty: all fields keep their declared initial values.
}
/////////////////////// Client Calls///////////////////////
/**
 * Opens a plaintext gRPC channel to the server described by {@code connectParam} and blocks
 * until the channel reports {@code READY} or the connect timeout elapses.
 *
 * <p>The channel state is polled every 100 ms. If the attempt fails for any reason, the channel
 * created by this call is shut down so that it is not leaked and so that a later call to this
 * method is not rejected by the "channel still open" guard.
 *
 * @param connectParam host, port, keep-alive, idle-timeout and connect-timeout settings
 * @return a {@code SUCCESS} response once the channel is ready and both stubs are created
 * @throws ConnectFailedException if an open channel already exists, the timeout elapses, the
 *     waiting thread is interrupted, or building the channel throws
 */
@Override
public Response connect(ConnectParam connectParam) throws ConnectFailedException {
  if (channel != null && !(channel.isShutdown() || channel.isTerminated())) {
    logWarning("Channel is not shutdown or terminated");
    throw new ConnectFailedException("Channel is not shutdown or terminated");
  }

  try {
    channel =
        ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
            .usePlaintext()
            .maxInboundMessageSize(Integer.MAX_VALUE)
            .keepAliveTime(
                connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
            .keepAliveTimeout(
                connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
            .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
            .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
            .build();

    // Passing true asks an idle channel to start connecting now rather than on the first RPC.
    channel.getState(true);

    long timeout = connectParam.getConnectTimeout(TimeUnit.MILLISECONDS);
    logInfo("Trying to connect...Timeout in {} ms", timeout);

    final long checkFrequency = 100; // ms
    while (channel.getState(false) != ConnectivityState.READY) {
      if (timeout <= 0) {
        logError("Connect timeout!");
        throw new ConnectFailedException("Connect timeout");
      }
      TimeUnit.MILLISECONDS.sleep(checkFrequency);
      timeout -= checkFrequency;
    }

    blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
    futureStub = MilvusServiceGrpc.newFutureStub(channel);
  } catch (Exception e) {
    if (e instanceof InterruptedException) {
      // The interruption is reported as a ConnectFailedException; keep the flag set so callers
      // further up the stack can still observe it.
      Thread.currentThread().interrupt();
    }
    if (!(e instanceof ConnectFailedException)) {
      logError("Connect failed! {}", e.toString());
    }
    // Release the channel that never became usable. Without this the channel leaks and the
    // guard at the top of this method rejects every subsequent connect attempt.
    if (channel != null && !channel.isShutdown()) {
      channel.shutdownNow();
    }
    throw new ConnectFailedException("Exception occurred: " + e.toString());
  }

  logInfo(
      "Connection established successfully to host={}, port={}",
      connectParam.getHost(), String.valueOf(connectParam.getPort()));
  return new Response(Response.Status.SUCCESS);
}
/**
 * Reports whether a channel exists and is currently in the {@code READY} state.
 *
 * <p>The state is queried without requesting a connection.
 *
 * @return {@code true} only if the channel is non-null and {@code READY}
 */
@Override
public boolean isConnected() {
  return channel != null && channel.getState(false) == ConnectivityState.READY;
}
/**
 * Shuts the channel down and waits up to 60 seconds for it to terminate.
 *
 * @return {@code CLIENT_NOT_CONNECTED} if the channel is neither ready nor idle, {@code
 *     RPC_ERROR} if termination does not complete within the wait, {@code SUCCESS} otherwise
 * @throws InterruptedException if the thread is interrupted while waiting for termination
 */
@Override
public Response disconnect() throws InterruptedException {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final boolean terminated;
  try {
    terminated = channel.shutdown().awaitTermination(60, TimeUnit.SECONDS);
  } catch (InterruptedException interrupted) {
    logError("Exception thrown when terminating channel: {}", interrupted.toString());
    throw interrupted;
  }

  if (!terminated) {
    logError("Encountered error when terminating channel");
    return new Response(Response.Status.RPC_ERROR);
  }
  logInfo("Channel terminated");
  return new Response(Response.Status.SUCCESS);
}
/**
 * Asks the server to create a collection described by {@code collectionMapping}.
 *
 * @param collectionMapping supplies the collection name, dimension, index file size and metric
 *     type sent in the request
 * @return {@code SUCCESS} on success; {@code CLIENT_NOT_CONNECTED} if the channel is neither
 *     ready nor idle; {@code RPC_ERROR} if the gRPC call throws; otherwise the server's error
 *     code and reason
 */
@Override
public Response createCollection(@Nonnull CollectionMapping collectionMapping) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final CollectionSchema schemaRequest =
      CollectionSchema.newBuilder()
          .setCollectionName(collectionMapping.getCollectionName())
          .setDimension(collectionMapping.getDimension())
          .setIndexFileSize(collectionMapping.getIndexFileSize())
          .setMetricType(collectionMapping.getMetricType().getVal())
          .build();

  try {
    final Status status = blockingStub.createCollection(schemaRequest);
    if (status.getErrorCode() == ErrorCode.SUCCESS) {
      logInfo("Created collection successfully!\n{}", collectionMapping.toString());
      return new Response(Response.Status.SUCCESS);
    }

    // Both failure flavours return the same payload; only the log severity differs.
    if (status.getReason().contentEquals("Collection already exists")) {
      logWarning("Collection `{}` already exists", collectionMapping.getCollectionName());
    } else {
      logError(
          "Create collection failed\n{}\n{}", collectionMapping.toString(), status.toString());
    }
    return new Response(Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
  } catch (StatusRuntimeException rpcFailure) {
    logError("createCollection RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asks the server whether a collection with the given name exists.
 *
 * @param collectionName name of the collection to look up
 * @return the server's boolean reply wrapped with {@code SUCCESS}; on any failure (not connected,
 *     server-side error, RPC exception) the flag is {@code false} and the wrapped response
 *     carries the corresponding status
 */
@Override
public HasCollectionResponse hasCollection(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new HasCollectionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final BoolReply reply = blockingStub.hasCollection(nameRequest);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("hasCollection `{}` failed:\n{}", collectionName, reply.toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
              reply.getStatus().getReason());
      return new HasCollectionResponse(failure, false);
    }
    logInfo("hasCollection `{}` = {}", collectionName, reply.getBoolReply());
    return new HasCollectionResponse(
        new Response(Response.Status.SUCCESS), reply.getBoolReply());
  } catch (StatusRuntimeException rpcFailure) {
    logError("hasCollection RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new HasCollectionResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), false);
  }
}
/**
 * Asks the server to drop the named collection.
 *
 * @param collectionName name of the collection to drop
 * @return {@code SUCCESS} on success; {@code CLIENT_NOT_CONNECTED} if the channel is neither
 *     ready nor idle; {@code RPC_ERROR} if the gRPC call throws; otherwise the server's error
 *     code and reason
 */
@Override
public Response dropCollection(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final Status status = blockingStub.dropCollection(nameRequest);
    if (status.getErrorCode() != ErrorCode.SUCCESS) {
      logError("Drop collection `{}` failed:\n{}", collectionName, status.toString());
      return new Response(
          Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
    }
    logInfo("Dropped collection `{}` successfully!", collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("dropCollection RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asks the server to build an index and blocks until the call returns.
 *
 * <p>The index parameters are sent as a single JSON string under the extra-parameter key.
 *
 * @param index supplies the collection name, index type and JSON parameters
 * @return {@code SUCCESS} on success; {@code CLIENT_NOT_CONNECTED} if the channel is neither
 *     ready nor idle; {@code RPC_ERROR} if the gRPC call throws; otherwise the server's error
 *     code and reason
 */
@Override
public Response createIndex(@Nonnull Index index) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final KeyValuePair jsonParams =
      KeyValuePair.newBuilder().setKey(extraParamKey).setValue(index.getParamsInJson()).build();
  final IndexParam indexRequest =
      IndexParam.newBuilder()
          .setCollectionName(index.getCollectionName())
          .setIndexType(index.getIndexType().getVal())
          .addExtraParams(jsonParams)
          .build();

  try {
    final Status status = blockingStub.createIndex(indexRequest);
    if (status.getErrorCode() != ErrorCode.SUCCESS) {
      logError("Create index failed:\n{}\n{}", index.toString(), status.toString());
      return new Response(
          Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
    }
    logInfo("Created index successfully!\n{}", index.toString());
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("createIndex RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asynchronous variant of {@code createIndex}: submits the request through the future stub and
 * returns immediately.
 *
 * <p>A logging callback is attached to the RPC future, and the returned future is the RPC future
 * mapped through {@code transformStatusToResponseFunc} (defined elsewhere in this class). Both
 * run on the direct executor, i.e. on whichever thread completes the RPC future.
 *
 * @param index supplies the collection name, index type and JSON parameters
 * @return a future of the call's outcome; already completed with {@code CLIENT_NOT_CONNECTED}
 *     when the channel is neither ready nor idle
 */
@Override
public ListenableFuture createIndexAsync(@Nonnull Index index) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
  }

  KeyValuePair extraParam =
      KeyValuePair.newBuilder().setKey(extraParamKey).setValue(index.getParamsInJson()).build();
  IndexParam request =
      IndexParam.newBuilder()
          .setCollectionName(index.getCollectionName())
          .setIndexType(index.getIndexType().getVal())
          .addExtraParams(extraParam)
          .build();

  // Parameterized (not raw) so that onSuccess(Status) below really overrides the callback method.
  ListenableFuture<Status> response = futureStub.createIndex(request);

  Futures.addCallback(
      response,
      new FutureCallback<Status>() {
        @Override
        public void onSuccess(Status result) {
          if (result.getErrorCode() == ErrorCode.SUCCESS) {
            logInfo("Created index successfully!\n{}", index.toString());
          } else {
            logError("CreateIndexAsync failed:\n{}\n{}", index.toString(), result.toString());
          }
        }

        @Override
        public void onFailure(Throwable t) {
          logError("CreateIndexAsync failed:\n{}", t.getMessage());
        }
      },
      MoreExecutors.directExecutor());

  return Futures.transform(
      response, transformStatusToResponseFunc::apply, MoreExecutors.directExecutor());
}
/**
 * Asks the server to create a partition with the given tag inside a collection.
 *
 * @param collectionName collection that will own the partition
 * @param tag tag identifying the new partition
 * @return {@code SUCCESS} on success; {@code CLIENT_NOT_CONNECTED} if the channel is neither
 *     ready nor idle; {@code RPC_ERROR} if the gRPC call throws; otherwise the server's error
 *     code and reason
 */
@Override
public Response createPartition(String collectionName, String tag) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final PartitionParam partitionRequest =
      PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();

  try {
    final Status status = blockingStub.createPartition(partitionRequest);
    if (status.getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "Create partition `{}` in collection `{}` failed: {}",
          tag, collectionName, status.toString());
      return new Response(
          Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
    }
    logInfo("Created partition `{}` in collection `{}` successfully!", tag, collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("createPartition RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asks the server whether a partition with the given tag exists in a collection.
 *
 * @param collectionName collection to inspect
 * @param tag partition tag to look up
 * @return the server's boolean reply wrapped with {@code SUCCESS}; on any failure (not connected,
 *     server-side error, RPC exception) the flag is {@code false} and the wrapped response
 *     carries the corresponding status
 */
@Override
public HasPartitionResponse hasPartition(String collectionName, String tag) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new HasPartitionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
  }

  final PartitionParam partitionRequest =
      PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();

  try {
    final BoolReply reply = blockingStub.hasPartition(partitionRequest);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("hasPartition with tag `{}` in `{}` failed:\n{}", tag, collectionName, reply.toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
              reply.getStatus().getReason());
      return new HasPartitionResponse(failure, false);
    }
    logInfo("hasPartition with tag `{}` in `{}` = {}", tag, collectionName, reply.getBoolReply());
    return new HasPartitionResponse(
        new Response(Response.Status.SUCCESS), reply.getBoolReply());
  } catch (StatusRuntimeException rpcFailure) {
    logError("hasPartition RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new HasPartitionResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), false);
  }
}
/**
 * Fetches the partition tags of a collection from the server.
 *
 * @param collectionName collection whose partitions are listed
 * @return the server's partition tag list wrapped with {@code SUCCESS}; on any failure (not
 *     connected, server-side error, RPC exception) an empty list and a response carrying the
 *     corresponding status
 */
@Override
public ShowPartitionsResponse showPartitions(String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new ShowPartitionsResponse(
        new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final PartitionList partitions = blockingStub.showPartitions(nameRequest);
    if (partitions.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Show partitions failed:\n{}", partitions.toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(partitions.getStatus().getErrorCodeValue()),
              partitions.getStatus().getReason());
      return new ShowPartitionsResponse(failure, new ArrayList<>());
    }
    logInfo(
        "Current partitions of collection {}: {}",
        collectionName, partitions.getPartitionTagArrayList());
    return new ShowPartitionsResponse(
        new Response(Response.Status.SUCCESS), partitions.getPartitionTagArrayList());
  } catch (StatusRuntimeException rpcFailure) {
    logError("showPartitions RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new ShowPartitionsResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), new ArrayList<>());
  }
}
/**
 * Asks the server to drop the partition with the given tag from a collection.
 *
 * @param collectionName collection that owns the partition
 * @param tag tag of the partition to drop
 * @return {@code SUCCESS} on success; {@code CLIENT_NOT_CONNECTED} if the channel is neither
 *     ready nor idle; {@code RPC_ERROR} if the gRPC call throws; otherwise the server's error
 *     code and reason
 */
@Override
public Response dropPartition(String collectionName, String tag) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final PartitionParam partitionRequest =
      PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();

  try {
    final Status status = blockingStub.dropPartition(partitionRequest);
    if (status.getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "Drop partition `{}` in collection `{}` failed:\n{}",
          tag, collectionName, status.toString());
      return new Response(
          Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
    }
    logInfo("Dropped partition `{}` in collection `{}` successfully!", tag, collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("dropPartition RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Inserts vectors into a collection and blocks until the server replies.
 *
 * <p>Row records are built from the float and binary vectors of {@code insertParam} by {@code
 * buildRowRecordList}; the vector ids and partition tag are forwarded as supplied.
 *
 * @param insertParam collection name, vectors, vector ids and partition tag
 * @return the vector ids reported by the server wrapped with {@code SUCCESS}; on any failure
 *     (not connected, server-side error, RPC exception) an empty id list and a response carrying
 *     the corresponding status
 */
@Override
public InsertResponse insert(@Nonnull InsertParam insertParam) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new InsertResponse(
        new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
  }

  List records =
      buildRowRecordList(insertParam.getFloatVectors(), insertParam.getBinaryVectors());
  final io.milvus.grpc.InsertParam insertRequest =
      io.milvus.grpc.InsertParam.newBuilder()
          .setCollectionName(insertParam.getCollectionName())
          .addAllRowRecordArray(records)
          .addAllRowIdArray(insertParam.getVectorIds())
          .setPartitionTag(insertParam.getPartitionTag())
          .build();

  try {
    final VectorIds inserted = blockingStub.insert(insertRequest);
    if (inserted.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Insert vectors failed:\n{}", inserted.getStatus().toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(inserted.getStatus().getErrorCodeValue()),
              inserted.getStatus().getReason());
      return new InsertResponse(failure, new ArrayList<>());
    }
    logInfo(
        "Inserted {} vectors to collection `{}` successfully!",
        inserted.getVectorIdArrayCount(), insertParam.getCollectionName());
    return new InsertResponse(
        new Response(Response.Status.SUCCESS), inserted.getVectorIdArrayList());
  } catch (StatusRuntimeException rpcFailure) {
    logError("insert RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new InsertResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), new ArrayList<>());
  }
}
/**
 * Asynchronous variant of {@code insert}: submits the request through the future stub and
 * returns immediately.
 *
 * <p>A logging callback is attached to the RPC future, and the returned future maps the server
 * reply to an {@code InsertResponse}: the returned vector ids with {@code SUCCESS}, or an empty
 * id list with the server's error code and reason. Both run on the direct executor, i.e. on
 * whichever thread completes the RPC future.
 *
 * @param insertParam collection name, vectors, vector ids and partition tag
 * @return a future of the {@code InsertResponse}; already completed with {@code
 *     CLIENT_NOT_CONNECTED} and an empty id list when the channel is neither ready nor idle
 */
@Override
public ListenableFuture insertAsync(@Nonnull InsertParam insertParam) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return Futures.immediateFuture(
        new InsertResponse(
            new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>()));
  }

  // Left raw: the element type is declared by buildRowRecordList, which is outside this method.
  List rowRecordList =
      buildRowRecordList(insertParam.getFloatVectors(), insertParam.getBinaryVectors());
  io.milvus.grpc.InsertParam request =
      io.milvus.grpc.InsertParam.newBuilder()
          .setCollectionName(insertParam.getCollectionName())
          .addAllRowRecordArray(rowRecordList)
          .addAllRowIdArray(insertParam.getVectorIds())
          .setPartitionTag(insertParam.getPartitionTag())
          .build();

  // Parameterized (not raw) so the callback and the mapping function below are type-checked.
  ListenableFuture<VectorIds> response = futureStub.insert(request);

  Futures.addCallback(
      response,
      new FutureCallback<VectorIds>() {
        @Override
        public void onSuccess(VectorIds result) {
          if (result.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
            logInfo(
                "Inserted {} vectors to collection `{}` successfully!",
                result.getVectorIdArrayCount(), insertParam.getCollectionName());
          } else {
            logError("InsertAsync failed:\n{}", result.getStatus().toString());
          }
        }

        @Override
        public void onFailure(Throwable t) {
          logError("InsertAsync failed:\n{}", t.getMessage());
        }
      },
      MoreExecutors.directExecutor());

  Function<VectorIds, InsertResponse> transformFunc =
      vectorIds -> {
        if (vectorIds.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
          return new InsertResponse(
              new Response(Response.Status.SUCCESS), vectorIds.getVectorIdArrayList());
        } else {
          return new InsertResponse(
              new Response(
                  Response.Status.valueOf(vectorIds.getStatus().getErrorCodeValue()),
                  vectorIds.getStatus().getReason()),
              new ArrayList<>());
        }
      };

  return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
}
/**
 * Runs a top-k vector search and blocks until the server replies.
 *
 * <p>Query records are built from the float and binary vectors of {@code searchParam} by {@code
 * buildRowRecordList}; the search parameters are sent as one JSON string under the
 * extra-parameter key.
 *
 * @param searchParam collection name, query vectors, partition tags, top-k and JSON parameters
 * @return on success, the result built by {@code buildSearchResponse} with its response set to
 *     {@code SUCCESS}; otherwise a fresh {@code SearchResponse} whose response carries {@code
 *     CLIENT_NOT_CONNECTED}, the server's error code and reason, or {@code RPC_ERROR}
 */
@Override
public SearchResponse search(@Nonnull SearchParam searchParam) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    final SearchResponse notConnected = new SearchResponse();
    notConnected.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
    return notConnected;
  }

  List queryRecords =
      buildRowRecordList(searchParam.getFloatVectors(), searchParam.getBinaryVectors());
  final KeyValuePair jsonParams =
      KeyValuePair.newBuilder()
          .setKey(extraParamKey)
          .setValue(searchParam.getParamsInJson())
          .build();
  final io.milvus.grpc.SearchParam searchRequest =
      io.milvus.grpc.SearchParam.newBuilder()
          .setCollectionName(searchParam.getCollectionName())
          .addAllQueryRecordArray(queryRecords)
          .addAllPartitionTagArray(searchParam.getPartitionTags())
          .setTopk(searchParam.getTopK())
          .addExtraParams(jsonParams)
          .build();

  try {
    final TopKQueryResult queryResult = blockingStub.search(searchRequest);
    if (queryResult.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Search failed:\n{}", queryResult.getStatus().toString());
      final SearchResponse failed = new SearchResponse();
      failed.setResponse(
          new Response(
              Response.Status.valueOf(queryResult.getStatus().getErrorCodeValue()),
              queryResult.getStatus().getReason()));
      return failed;
    }
    final SearchResponse found = buildSearchResponse(queryResult);
    found.setResponse(new Response(Response.Status.SUCCESS));
    logInfo(
        "Search completed successfully! Returned results for {} queries",
        found.getNumQueries());
    return found;
  } catch (StatusRuntimeException rpcFailure) {
    logError("search RPC failed:\n{}", rpcFailure.getStatus().toString());
    final SearchResponse failed = new SearchResponse();
    failed.setResponse(new Response(Response.Status.RPC_ERROR, rpcFailure.toString()));
    return failed;
  }
}
/**
 * Runs a top-k search keyed by ids rather than by query vectors, blocking until the server
 * replies.
 *
 * <p>The search parameters are sent as one JSON string under the extra-parameter key.
 *
 * @param searchByIdsParam collection name, ids, partition tags, top-k and JSON parameters
 * @return on success, the result built by {@code buildSearchResponse} with its response set to
 *     {@code SUCCESS}; otherwise a fresh {@code SearchResponse} whose response carries {@code
 *     CLIENT_NOT_CONNECTED}, the server's error code and reason, or {@code RPC_ERROR}
 */
@Override
public SearchResponse searchByIds(@Nonnull SearchByIdsParam searchByIdsParam) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    final SearchResponse notConnected = new SearchResponse();
    notConnected.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
    return notConnected;
  }

  List ids = searchByIdsParam.getIds();
  final KeyValuePair jsonParams =
      KeyValuePair.newBuilder()
          .setKey(extraParamKey)
          .setValue(searchByIdsParam.getParamsInJson())
          .build();
  final io.milvus.grpc.SearchByIDParam searchRequest =
      io.milvus.grpc.SearchByIDParam.newBuilder()
          .setCollectionName(searchByIdsParam.getCollectionName())
          .addAllIdArray(ids)
          .addAllPartitionTagArray(searchByIdsParam.getPartitionTags())
          .setTopk(searchByIdsParam.getTopK())
          .addExtraParams(jsonParams)
          .build();

  try {
    final TopKQueryResult queryResult = blockingStub.searchByID(searchRequest);
    if (queryResult.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Search by ids failed:\n{}", queryResult.getStatus().toString());
      final SearchResponse failed = new SearchResponse();
      failed.setResponse(
          new Response(
              Response.Status.valueOf(queryResult.getStatus().getErrorCodeValue()),
              queryResult.getStatus().getReason()));
      return failed;
    }
    final SearchResponse found = buildSearchResponse(queryResult);
    found.setResponse(new Response(Response.Status.SUCCESS));
    logInfo(
        "Search by ids completed successfully! Returned results for {} queries",
        found.getNumQueries());
    return found;
  } catch (StatusRuntimeException rpcFailure) {
    logError("search by ids RPC failed:\n{}", rpcFailure.getStatus().toString());
    final SearchResponse failed = new SearchResponse();
    failed.setResponse(new Response(Response.Status.RPC_ERROR, rpcFailure.toString()));
    return failed;
  }
}
/**
 * Asynchronous variant of {@code search}: submits the request through the future stub and
 * returns immediately.
 *
 * <p>A logging callback is attached to the RPC future, and the returned future maps the server
 * reply to a {@code SearchResponse}: the result of {@code buildSearchResponse} with {@code
 * SUCCESS}, or a fresh response carrying the server's error code and reason. Both run on the
 * direct executor, i.e. on whichever thread completes the RPC future.
 *
 * @param searchParam collection name, query vectors, partition tags, top-k and JSON parameters
 * @return a future of the {@code SearchResponse}; already completed with {@code
 *     CLIENT_NOT_CONNECTED} when the channel is neither ready nor idle
 */
@Override
public ListenableFuture searchAsync(@Nonnull SearchParam searchParam) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    SearchResponse searchResponse = new SearchResponse();
    searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
    return Futures.immediateFuture(searchResponse);
  }

  // Left raw: the element type is declared by buildRowRecordList, which is outside this method.
  List rowRecordList =
      buildRowRecordList(searchParam.getFloatVectors(), searchParam.getBinaryVectors());
  KeyValuePair extraParam =
      KeyValuePair.newBuilder()
          .setKey(extraParamKey)
          .setValue(searchParam.getParamsInJson())
          .build();
  io.milvus.grpc.SearchParam request =
      io.milvus.grpc.SearchParam.newBuilder()
          .setCollectionName(searchParam.getCollectionName())
          .addAllQueryRecordArray(rowRecordList)
          .addAllPartitionTagArray(searchParam.getPartitionTags())
          .setTopk(searchParam.getTopK())
          .addExtraParams(extraParam)
          .build();

  // Parameterized (not raw) so the callback and the mapping function below are type-checked.
  ListenableFuture<TopKQueryResult> response = futureStub.search(request);

  Futures.addCallback(
      response,
      new FutureCallback<TopKQueryResult>() {
        @Override
        public void onSuccess(TopKQueryResult result) {
          if (result.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
            logInfo(
                "SearchAsync completed successfully! Returned results for {} queries",
                result.getRowNum());
          } else {
            logError("SearchAsync failed:\n{}", result.getStatus().toString());
          }
        }

        @Override
        public void onFailure(Throwable t) {
          logError("SearchAsync failed:\n{}", t.getMessage());
        }
      },
      MoreExecutors.directExecutor());

  Function<TopKQueryResult, SearchResponse> transformFunc =
      topKQueryResult -> {
        if (topKQueryResult.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
          SearchResponse searchResponse = buildSearchResponse(topKQueryResult);
          searchResponse.setResponse(new Response(Response.Status.SUCCESS));
          return searchResponse;
        } else {
          SearchResponse searchResponse = new SearchResponse();
          searchResponse.setResponse(
              new Response(
                  Response.Status.valueOf(topKQueryResult.getStatus().getErrorCodeValue()),
                  topKQueryResult.getStatus().getReason()));
          return searchResponse;
        }
      };

  return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
}
/**
 * Runs a top-k vector search restricted to the given file ids, blocking until the server
 * replies.
 *
 * <p>An ordinary search request is built from {@code searchParam} and embedded, together with
 * {@code fileIds}, in a {@code SearchInFilesParam}.
 *
 * @param fileIds ids of the files to search in
 * @param searchParam collection name, query vectors, partition tags, top-k and JSON parameters
 * @return on success, the result built by {@code buildSearchResponse} with its response set to
 *     {@code SUCCESS}; otherwise a fresh {@code SearchResponse} whose response carries {@code
 *     CLIENT_NOT_CONNECTED}, the server's error code and reason, or {@code RPC_ERROR}
 */
@Override
public SearchResponse searchInFiles(
    @Nonnull List fileIds, @Nonnull SearchParam searchParam) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    final SearchResponse notConnected = new SearchResponse();
    notConnected.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
    return notConnected;
  }

  List queryRecords =
      buildRowRecordList(searchParam.getFloatVectors(), searchParam.getBinaryVectors());
  final KeyValuePair jsonParams =
      KeyValuePair.newBuilder()
          .setKey(extraParamKey)
          .setValue(searchParam.getParamsInJson())
          .build();
  final io.milvus.grpc.SearchParam innerSearch =
      io.milvus.grpc.SearchParam.newBuilder()
          .setCollectionName(searchParam.getCollectionName())
          .addAllQueryRecordArray(queryRecords)
          .addAllPartitionTagArray(searchParam.getPartitionTags())
          .setTopk(searchParam.getTopK())
          .addExtraParams(jsonParams)
          .build();
  final SearchInFilesParam filesRequest =
      SearchInFilesParam.newBuilder()
          .addAllFileIdArray(fileIds)
          .setSearchParam(innerSearch)
          .build();

  try {
    final TopKQueryResult queryResult = blockingStub.searchInFiles(filesRequest);
    if (queryResult.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Search in files failed: {}", queryResult.getStatus().toString());
      final SearchResponse failed = new SearchResponse();
      failed.setResponse(
          new Response(
              Response.Status.valueOf(queryResult.getStatus().getErrorCodeValue()),
              queryResult.getStatus().getReason()));
      return failed;
    }
    final SearchResponse found = buildSearchResponse(queryResult);
    found.setResponse(new Response(Response.Status.SUCCESS));
    logInfo(
        "Search in files completed successfully! Returned results for {} queries",
        found.getNumQueries());
    return found;
  } catch (StatusRuntimeException rpcFailure) {
    logError("searchInFiles RPC failed:\n{}", rpcFailure.getStatus().toString());
    final SearchResponse failed = new SearchResponse();
    failed.setResponse(new Response(Response.Status.RPC_ERROR, rpcFailure.toString()));
    return failed;
  }
}
/**
 * Fetches the schema of a collection and converts it to a {@code CollectionMapping}.
 *
 * @param collectionName name of the collection to describe
 * @return the mapping (name, dimension, index file size, metric type) wrapped with {@code
 *     SUCCESS}; on any failure (not connected, server-side error, RPC exception) a {@code null}
 *     mapping and a response carrying the corresponding status
 */
@Override
public DescribeCollectionResponse describeCollection(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new DescribeCollectionResponse(
        new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final CollectionSchema schema = blockingStub.describeCollection(nameRequest);
    if (schema.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "Describe Collection `{}` failed:\n{}",
          collectionName, schema.getStatus().toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(schema.getStatus().getErrorCodeValue()),
              schema.getStatus().getReason());
      return new DescribeCollectionResponse(failure, null);
    }
    final CollectionMapping mapping =
        new CollectionMapping.Builder(schema.getCollectionName(), schema.getDimension())
            .withIndexFileSize(schema.getIndexFileSize())
            .withMetricType(MetricType.valueOf(schema.getMetricType()))
            .build();
    logInfo("Describe Collection `{}` returned:\n{}", collectionName, mapping);
    return new DescribeCollectionResponse(new Response(Response.Status.SUCCESS), mapping);
  } catch (StatusRuntimeException rpcFailure) {
    logError("describeCollection RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new DescribeCollectionResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), null);
  }
}
/**
 * Fetches the names of the collections known to the server.
 *
 * <p>The request is a {@code Command} with an empty command string.
 *
 * @return the server's collection name list wrapped with {@code SUCCESS}; on any failure (not
 *     connected, server-side error, RPC exception) an empty list and a response carrying the
 *     corresponding status
 */
@Override
public ShowCollectionsResponse showCollections() {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new ShowCollectionsResponse(
        new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
  }

  final Command emptyCommand = Command.newBuilder().setCmd("").build();

  try {
    final CollectionNameList reply = blockingStub.showCollections(emptyCommand);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Show collections failed:\n{}", reply.getStatus().toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
              reply.getStatus().getReason());
      return new ShowCollectionsResponse(failure, new ArrayList<>());
    }
    List names = reply.getCollectionNamesList();
    logInfo("Current collections: {}", names.toString());
    return new ShowCollectionsResponse(new Response(Response.Status.SUCCESS), names);
  } catch (StatusRuntimeException rpcFailure) {
    logError("showCollections RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new ShowCollectionsResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), new ArrayList<>());
  }
}
/**
 * Fetches the number of rows in a collection.
 *
 * @param collectionName name of the collection to count
 * @return the row count reported by the server wrapped with {@code SUCCESS}; on any failure (not
 *     connected, server-side error, RPC exception) a count of 0 and a response carrying the
 *     corresponding status
 */
@Override
public GetCollectionRowCountResponse getCollectionRowCount(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new GetCollectionRowCountResponse(
        new Response(Response.Status.CLIENT_NOT_CONNECTED), 0);
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final CollectionRowCount reply = blockingStub.countCollection(nameRequest);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "Get collection `{}` row count failed:\n{}",
          collectionName, reply.getStatus().toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
              reply.getStatus().getReason());
      return new GetCollectionRowCountResponse(failure, 0);
    }
    final long rows = reply.getCollectionRowCount();
    logInfo("Collection `{}` has {} rows", collectionName, rows);
    return new GetCollectionRowCountResponse(new Response(Response.Status.SUCCESS), rows);
  } catch (StatusRuntimeException rpcFailure) {
    logError("countCollection RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new GetCollectionRowCountResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), 0);
  }
}
/**
 * Sends the {@code "status"} command to the server.
 *
 * @return whatever {@link #command(String)} returns for that command
 */
@Override
public Response getServerStatus() {
  final Response statusReply = command("status");
  return statusReply;
}
/**
 * Sends the {@code "version"} command to the server.
 *
 * @return whatever {@link #command(String)} returns for that command
 */
@Override
public Response getServerVersion() {
  final Response versionReply = command("version");
  return versionReply;
}
/**
 * Sends an arbitrary command string to the server and returns its string reply.
 *
 * @param command command text placed in the request
 * @return {@code SUCCESS} with the server's string reply as the message; {@code
 *     CLIENT_NOT_CONNECTED} if the channel is neither ready nor idle; {@code RPC_ERROR} if the
 *     gRPC call throws; otherwise the server's error code and reason
 */
public Response command(@Nonnull String command) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final Command commandRequest = Command.newBuilder().setCmd(command).build();

  try {
    final StringReply reply = blockingStub.cmd(commandRequest);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError("Command `{}` failed:\n{}", command, reply.toString());
      return new Response(
          Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
          reply.getStatus().getReason());
    }
    logInfo("Command `{}`: {}", command, reply.getStringReply());
    return new Response(Response.Status.SUCCESS, reply.getStringReply());
  } catch (StatusRuntimeException rpcFailure) {
    logError("Command RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asks the server to preload the named collection.
 *
 * @param collectionName name of the collection to preload
 * @return {@code SUCCESS} on success; {@code CLIENT_NOT_CONNECTED} if the channel is neither
 *     ready nor idle; {@code RPC_ERROR} if the gRPC call throws; otherwise the server's error
 *     code and reason
 */
@Override
public Response preloadCollection(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final Status status = blockingStub.preloadCollection(nameRequest);
    if (status.getErrorCode() != ErrorCode.SUCCESS) {
      logError("Preload collection `{}` failed:\n{}", collectionName, status.toString());
      return new Response(
          Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
    }
    logInfo("Preloaded collection `{}` successfully!", collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("preloadCollection RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Fetches the index description of a collection and converts it to an {@code Index}.
 *
 * <p>The JSON parameters are taken from the reply's extra-parameter entry whose key equals the
 * extra-parameter key; if several entries match, the last one wins, and if none match an empty
 * string is used.
 *
 * @param collectionName name of the collection whose index is described
 * @return the index wrapped with {@code SUCCESS}; on any failure (not connected, server-side
 *     error, RPC exception) a {@code null} index and a response carrying the corresponding status
 */
@Override
public DescribeIndexResponse describeIndex(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new DescribeIndexResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
  }

  final CollectionName nameRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();

  try {
    final IndexParam reply = blockingStub.describeIndex(nameRequest);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "Describe index for collection `{}` failed:\n{}",
          collectionName, reply.getStatus().toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
              reply.getStatus().getReason());
      return new DescribeIndexResponse(failure, null);
    }

    String paramsJson = "";
    for (KeyValuePair pair : reply.getExtraParamsList()) {
      // No early exit: a later entry with the same key overrides an earlier one.
      paramsJson = pair.getKey().contentEquals(extraParamKey) ? pair.getValue() : paramsJson;
    }

    final Index described =
        new Index.Builder(reply.getCollectionName(), IndexType.valueOf(reply.getIndexType()))
            .withParamsInJson(paramsJson)
            .build();
    logInfo(
        "Describe index for collection `{}` returned:\n{}", collectionName, described.toString());
    return new DescribeIndexResponse(new Response(Response.Status.SUCCESS), described);
  } catch (StatusRuntimeException rpcFailure) {
    logError("describeIndex RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new DescribeIndexResponse(
        new Response(Response.Status.RPC_ERROR, rpcFailure.toString()), null);
  }
}
/**
 * Asks the server to drop the index of the named collection through the blocking stub.
 *
 * @param collectionName name of the collection whose index is dropped
 * @return <code>CLIENT_NOT_CONNECTED</code> when the channel is neither ready nor idle;
 *     <code>SUCCESS</code> when the server reports success; a response built from the server's
 *     error code and reason otherwise; or <code>RPC_ERROR</code> when the call throws a {@link
 *     StatusRuntimeException}
 */
@Override
public Response dropIndex(@Nonnull String collectionName) {
  final boolean connected = channelIsReadyOrIdle();
  if (!connected) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final CollectionName dropRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();
  try {
    final Status serverStatus = blockingStub.dropIndex(dropRequest);
    if (serverStatus.getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "Drop index for collection `{}` failed:\n{}", collectionName, serverStatus.toString());
      final Response.Status mappedStatus =
          Response.Status.valueOf(serverStatus.getErrorCodeValue());
      return new Response(mappedStatus, serverStatus.getReason());
    }
    logInfo("Dropped index for collection `{}` successfully!", collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("dropIndex RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Requests information about a collection through the blocking stub.
 *
 * @param collectionName name of the collection to inspect
 * @return <code>CLIENT_NOT_CONNECTED</code> when the channel is neither ready nor idle; a
 *     <code>SUCCESS</code> response carrying the reply's JSON info string when the server reports
 *     success; a response built from the server's error code and reason otherwise; or
 *     <code>RPC_ERROR</code> when the call throws a {@link StatusRuntimeException}
 */
@Override
public Response showCollectionInfo(String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final CollectionName infoRequest =
      CollectionName.newBuilder().setCollectionName(collectionName).build();
  try {
    final io.milvus.grpc.CollectionInfo info = blockingStub.showCollectionInfo(infoRequest);
    final boolean succeeded = info.getStatus().getErrorCode() == ErrorCode.SUCCESS;
    if (!succeeded) {
      logError(
          "ShowCollectionInfo for `{}` failed:\n{}",
          collectionName, info.getStatus().toString());
      final int errorCodeValue = info.getStatus().getErrorCodeValue();
      final String reason = info.getStatus().getReason();
      return new Response(Response.Status.valueOf(errorCodeValue), reason);
    }
    logInfo("ShowCollectionInfo for `{}` returned successfully!", collectionName);
    return new Response(Response.Status.SUCCESS, info.getJsonInfo());
  } catch (StatusRuntimeException rpcFailure) {
    logError("showCollectionInfo RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Fetches the vectors stored under the given ids in a collection.
 *
 * <p>For every returned record both its float data and a read-only view of its binary data are
 * collected, in the order the server returned them. The number of records read is capped at the
 * number the server actually sent, so a short reply yields shorter lists (and a warning) instead
 * of an uncaught {@link IndexOutOfBoundsException}.
 *
 * @param collectionName name of the collection to read from
 * @param ids ids of the vectors to fetch; passed unchanged into the request's id array
 * @return on success a <code>SUCCESS</code> response with the float and binary vector lists;
 *     otherwise a response with an empty float-vector list, a <code>null</code> binary-vector
 *     list, and either <code>CLIENT_NOT_CONNECTED</code>, the server's error code and reason, or
 *     <code>RPC_ERROR</code>
 */
@Override
public GetVectorsByIdsResponse getVectorsByIds(String collectionName, List ids) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new GetVectorsByIdsResponse(
        new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null);
  }
  VectorsIdentity request =
      VectorsIdentity.newBuilder().setCollectionName(collectionName).addAllIdArray(ids).build();
  VectorsData response;
  try {
    response = blockingStub.getVectorsByID(request);
    if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
      logInfo(
          "getVectorsByIds in collection `{}` returned successfully!", collectionName);

      // Never index past what the server actually returned.
      final int requested = ids.size();
      final int returned = response.getVectorsDataCount();
      final int available = Math.min(requested, returned);
      if (returned != requested) {
        logWarning(
            "getVectorsByIds in collection `{}` returned {} vectors for {} ids",
            collectionName, returned, requested);
      }

      List<List<Float>> floatVectors = new ArrayList<>(available);
      List<ByteBuffer> binaryVectors = new ArrayList<>(available);
      for (int i = 0; i < available; i++) {
        floatVectors.add(response.getVectorsData(i).getFloatDataList());
        binaryVectors.add(response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer());
      }
      return new GetVectorsByIdsResponse(
          new Response(Response.Status.SUCCESS), floatVectors, binaryVectors);
    } else {
      logError(
          "getVectorsByIds in collection `{}` failed:\n{}",
          collectionName, response.getStatus().toString());
      return new GetVectorsByIdsResponse(
          new Response(
              Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
              response.getStatus().getReason()),
          new ArrayList<>(),
          null);
    }
  } catch (StatusRuntimeException e) {
    logError("getVectorsByIds RPC failed:\n{}", e.getStatus().toString());
    return new GetVectorsByIdsResponse(
        new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null);
  }
}
/**
 * Lists the vector ids stored in one segment of a collection.
 *
 * @param collectionName name of the collection to read from
 * @param segmentName name of the segment whose vector ids are requested
 * @return on success a <code>SUCCESS</code> response with the id list from the reply; otherwise a
 *     response with an empty id list and either <code>CLIENT_NOT_CONNECTED</code>, the server's
 *     error code and reason, or <code>RPC_ERROR</code>
 */
@Override
public GetVectorIdsResponse getVectorIds(String collectionName, String segmentName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    final Response notConnected = new Response(Response.Status.CLIENT_NOT_CONNECTED);
    return new GetVectorIdsResponse(notConnected, new ArrayList<>());
  }

  final GetVectorIDsParam.Builder requestBuilder = GetVectorIDsParam.newBuilder();
  requestBuilder.setCollectionName(collectionName);
  requestBuilder.setSegmentName(segmentName);
  final GetVectorIDsParam idsRequest = requestBuilder.build();

  try {
    final VectorIds reply = blockingStub.getVectorIDs(idsRequest);
    if (reply.getStatus().getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "getVectorIds in collection `{}`, segment `{}` failed:\n{}",
          collectionName, segmentName, reply.getStatus().toString());
      final Response failure =
          new Response(
              Response.Status.valueOf(reply.getStatus().getErrorCodeValue()),
              reply.getStatus().getReason());
      return new GetVectorIdsResponse(failure, new ArrayList<>());
    }
    logInfo(
        "getVectorIds in collection `{}`, segment `{}` returned successfully!",
        collectionName, segmentName);
    final Response success = new Response(Response.Status.SUCCESS);
    return new GetVectorIdsResponse(success, reply.getVectorIdArrayList());
  } catch (StatusRuntimeException rpcFailure) {
    logError("getVectorIds RPC failed:\n{}", rpcFailure.getStatus().toString());
    final Response rpcError = new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
    return new GetVectorIdsResponse(rpcError, new ArrayList<>());
  }
}
/**
 * Deletes the vectors with the given ids from a collection through the blocking stub.
 *
 * @param collectionName name of the collection to delete from
 * @param ids ids of the vectors to delete; passed unchanged into the request's id array
 * @return <code>CLIENT_NOT_CONNECTED</code> when the channel is neither ready nor idle;
 *     <code>SUCCESS</code> when the server reports success; a response built from the server's
 *     error code and reason otherwise; or <code>RPC_ERROR</code> when the call throws a {@link
 *     StatusRuntimeException}
 */
@Override
public Response deleteByIds(String collectionName, List ids) {
  final boolean connected = channelIsReadyOrIdle();
  if (!connected) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final DeleteByIDParam deleteRequest =
      DeleteByIDParam.newBuilder().setCollectionName(collectionName).addAllIdArray(ids).build();
  try {
    final Status serverStatus = blockingStub.deleteByID(deleteRequest);
    if (serverStatus.getErrorCode() != ErrorCode.SUCCESS) {
      logError(
          "deleteByIds in collection `{}` failed:\n{}", collectionName, serverStatus.toString());
      final int errorCodeValue = serverStatus.getErrorCodeValue();
      return new Response(Response.Status.valueOf(errorCodeValue), serverStatus.getReason());
    }
    logInfo("deleteByIds in collection `{}` completed successfully!", collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("deleteByIds RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Deletes a single vector by wrapping its id in a one-element list and delegating to {@link
 * #deleteByIds}.
 *
 * @param collectionName name of the collection to delete from
 * @param id id of the vector to delete
 * @return the response produced by <code>deleteByIds</code>
 */
@Override
public Response deleteById(String collectionName, Long id) {
  // A plain list replaces the former double-brace initializer, which created an anonymous
  // ArrayList subclass holding a hidden reference to this client.
  final List<Long> singleId = new ArrayList<>(1);
  singleId.add(id);
  return deleteByIds(collectionName, singleId);
}
/**
 * Asks the server to flush the given collections through the blocking stub.
 *
 * @param collectionNames names of the collections to flush; passed unchanged into the request
 * @return <code>CLIENT_NOT_CONNECTED</code> when the channel is neither ready nor idle;
 *     <code>SUCCESS</code> when the server reports success; a response built from the server's
 *     error code and reason otherwise; or <code>RPC_ERROR</code> when the call throws a {@link
 *     StatusRuntimeException}
 */
@Override
public Response flush(List collectionNames) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final FlushParam flushRequest =
      FlushParam.newBuilder().addAllCollectionNameArray(collectionNames).build();
  try {
    final Status serverStatus = blockingStub.flush(flushRequest);
    final boolean succeeded = serverStatus.getErrorCode() == ErrorCode.SUCCESS;
    if (!succeeded) {
      logError("Flush collection {} failed:\n{}", collectionNames, serverStatus.toString());
      final Response.Status mappedStatus =
          Response.Status.valueOf(serverStatus.getErrorCodeValue());
      return new Response(mappedStatus, serverStatus.getReason());
    }
    logInfo("Flushed collection {} successfully!", collectionNames);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("flush RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asynchronously asks the server to flush the given collections through the future stub.
 *
 * <p>A callback attached to the RPC future only logs the outcome. The returned future is the RPC
 * future transformed by {@code transformStatusToResponseFunc}.
 *
 * @param collectionNames names of the collections to flush; passed unchanged into the request
 * @return an already-completed future holding <code>CLIENT_NOT_CONNECTED</code> when the channel
 *     is neither ready nor idle; otherwise a future of the {@link Response} derived from the
 *     server's status
 */
@Override
public ListenableFuture flushAsync(@Nonnull List collectionNames) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
  }
  FlushParam request = FlushParam.newBuilder().addAllCollectionNameArray(collectionNames).build();
  // Explicit type arguments are required: with a raw FutureCallback, onSuccess(Status) would not
  // override onSuccess(Object) and the anonymous class would not compile.
  ListenableFuture<Status> response = futureStub.flush(request);
  Futures.addCallback(
      response,
      new FutureCallback<Status>() {
        @Override
        public void onSuccess(Status result) {
          if (result.getErrorCode() == ErrorCode.SUCCESS) {
            logInfo("Flushed collection {} successfully!", collectionNames);
          } else {
            logError("Flush collection {} failed:\n{}", collectionNames, result.toString());
          }
        }

        @Override
        public void onFailure(Throwable t) {
          logError("FlushAsync failed:\n{}", t.getMessage());
        }
      },
      MoreExecutors.directExecutor());
  return Futures.transform(
      response, transformStatusToResponseFunc::apply, MoreExecutors.directExecutor());
}
/**
 * Flushes a single collection by wrapping its name in a one-element list and delegating to the
 * list-based <code>flush</code> overload.
 *
 * @param collectionName name of the collection to flush
 * @return the response produced by the list-based <code>flush</code>
 */
@Override
public Response flush(String collectionName) {
  // A plain list replaces the former double-brace initializer, which created an anonymous
  // ArrayList subclass holding a hidden reference to this client.
  final List<String> singleCollection = new ArrayList<>(1);
  singleCollection.add(collectionName);
  return flush(singleCollection);
}
/**
 * Asynchronously flushes a single collection by wrapping its name in a one-element list and
 * delegating to the list-based <code>flushAsync</code> overload.
 *
 * @param collectionName name of the collection to flush
 * @return the future produced by the list-based <code>flushAsync</code>
 */
@Override
public ListenableFuture flushAsync(String collectionName) {
  // A plain list replaces the former double-brace initializer, which created an anonymous
  // ArrayList subclass holding a hidden reference to this client.
  final List<String> singleCollection = new ArrayList<>(1);
  singleCollection.add(collectionName);
  return flushAsync(singleCollection);
}
/**
 * Asks the server to compact the named collection through the blocking stub.
 *
 * @param collectionName name of the collection to compact
 * @return <code>CLIENT_NOT_CONNECTED</code> when the channel is neither ready nor idle;
 *     <code>SUCCESS</code> when the server reports success; a response built from the server's
 *     error code and reason otherwise; or <code>RPC_ERROR</code> when the call throws a {@link
 *     StatusRuntimeException}
 */
@Override
public Response compact(String collectionName) {
  final boolean connected = channelIsReadyOrIdle();
  if (!connected) {
    logWarning("You are not connected to Milvus server");
    return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  }

  final CollectionName.Builder requestBuilder = CollectionName.newBuilder();
  requestBuilder.setCollectionName(collectionName);
  final CollectionName compactRequest = requestBuilder.build();

  try {
    final Status serverStatus = blockingStub.compact(compactRequest);
    if (serverStatus.getErrorCode() != ErrorCode.SUCCESS) {
      logError("Compact collection `{}` failed:\n{}", collectionName, serverStatus.toString());
      final int errorCodeValue = serverStatus.getErrorCodeValue();
      return new Response(Response.Status.valueOf(errorCodeValue), serverStatus.getReason());
    }
    logInfo("Compacted collection `{}` successfully!", collectionName);
    return new Response(Response.Status.SUCCESS);
  } catch (StatusRuntimeException rpcFailure) {
    logError("compact RPC failed:\n{}", rpcFailure.getStatus().toString());
    return new Response(Response.Status.RPC_ERROR, rpcFailure.toString());
  }
}
/**
 * Asynchronously asks the server to compact the named collection through the future stub.
 *
 * <p>A callback attached to the RPC future only logs the outcome. The returned future is the RPC
 * future transformed by {@code transformStatusToResponseFunc}.
 *
 * @param collectionName name of the collection to compact
 * @return an already-completed future holding <code>CLIENT_NOT_CONNECTED</code> when the channel
 *     is neither ready nor idle; otherwise a future of the {@link Response} derived from the
 *     server's status
 */
@Override
public ListenableFuture compactAsync(@Nonnull String collectionName) {
  if (!channelIsReadyOrIdle()) {
    logWarning("You are not connected to Milvus server");
    return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
  }
  CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  // Explicit type arguments are required: with a raw FutureCallback, onSuccess(Status) would not
  // override onSuccess(Object) and the anonymous class would not compile.
  ListenableFuture<Status> response = futureStub.compact(request);
  Futures.addCallback(
      response,
      new FutureCallback<Status>() {
        @Override
        public void onSuccess(Status result) {
          if (result.getErrorCode() == ErrorCode.SUCCESS) {
            logInfo("Compacted collection `{}` successfully!", collectionName);
          } else {
            logError("Compact collection `{}` failed:\n{}", collectionName, result.toString());
          }
        }

        @Override
        public void onFailure(Throwable t) {
          logError("CompactAsync failed:\n{}", t.getMessage());
        }
      },
      MoreExecutors.directExecutor());
  return Futures.transform(
      response, transformStatusToResponseFunc::apply, MoreExecutors.directExecutor());
}
///////////////////// Util Functions/////////////////////
/**
 * Maps a server {@code Status} to a client {@link Response}: <code>SUCCESS</code> when the error
 * code is <code>ErrorCode.SUCCESS</code>, otherwise a response built from the status' error code
 * value and reason.
 *
 * <p>The type arguments are spelled out because a raw {@code Function} would type the lambda
 * parameter as {@code Object}, on which the status accessors are unavailable.
 */
Function<Status, Response> transformStatusToResponseFunc =
    status -> {
      if (status.getErrorCode() == ErrorCode.SUCCESS) {
        return new Response(Response.Status.SUCCESS);
      } else {
        return new Response(
            Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
      }
    };
private List buildRowRecordList(
@Nonnull List> floatVectors, @Nonnull List binaryVectors) {
List rowRecordList = new ArrayList<>();
int largerSize = Math.max(floatVectors.size(), binaryVectors.size());
for (int i = 0; i < largerSize; ++i) {
RowRecord.Builder rowRecordBuilder = RowRecord.newBuilder();
if (i < floatVectors.size()) {
rowRecordBuilder.addAllFloatData(floatVectors.get(i));
}
if (i < binaryVectors.size()) {
((Buffer) binaryVectors.get(i)).rewind();
rowRecordBuilder.setBinaryData(ByteString.copyFrom(binaryVectors.get(i)));
}
rowRecordList.add(rowRecordBuilder.build());
}
return rowRecordList;
}
/**
 * Splits the flat id and distance arrays of a top-k query result into one list per query.
 *
 * <p>The per-query width is the id count divided by the row count (zero when there are no rows).
 * Within each query's slice, entries are kept up to, but not including, the first id equal to
 * <code>-1</code>; distances are cut at the same position.
 *
 * @param topKQueryResult the raw query result returned by the server
 * @return a {@link SearchResponse} populated with the query count, the per-query width, and the
 *     per-query id and distance lists (both empty when the width is zero)
 */
private SearchResponse buildSearchResponse(TopKQueryResult topKQueryResult) {
  final int numQueries = (int) topKQueryResult.getRowNum();
  final int topK =
      numQueries == 0
          ? 0
          : topKQueryResult.getIdsCount()
              / numQueries; // Guaranteed to be divisible from server side

  // Fetch the flat lists once instead of on every loop iteration.
  final List<Long> allIds = topKQueryResult.getIdsList();
  final List<Float> allDistances = topKQueryResult.getDistancesList();

  final List<List<Long>> resultIdsList = new ArrayList<>();
  final List<List<Float>> resultDistancesList = new ArrayList<>();
  if (topK > 0) {
    for (int i = 0; i < numQueries; i++) {
      // Process result of query i
      final int start = i * topK;
      final int limit = start + topK;
      int pos = start;
      while (pos < limit && allIds.get(pos) != -1L) {
        pos++;
      }
      resultIdsList.add(allIds.subList(start, pos));
      resultDistancesList.add(allDistances.subList(start, pos));
    }
  }

  SearchResponse searchResponse = new SearchResponse();
  searchResponse.setNumQueries(numQueries);
  searchResponse.setTopK(topK);
  searchResponse.setResultIdsList(resultIdsList);
  searchResponse.setResultDistancesList(resultDistancesList);
  return searchResponse;
}
/**
 * Tells whether the channel exists and is currently in the READY or IDLE connectivity state.
 *
 * @return <code>false</code> when there is no channel or it is in any other state
 */
private boolean channelIsReadyOrIdle() {
  if (channel != null) {
    final ConnectivityState state = channel.getState(false);
    if (state == ConnectivityState.READY) {
      return true;
    }
    // IDLE counts as usable since a new RPC would take the channel out of idle mode.
    return state == ConnectivityState.IDLE;
  }
  return false;
}
///////////////////// Log Functions//////////////////////
/**
 * Forwards a message and its arguments to the class logger at INFO level.
 *
 * @param msg message pattern handed to the logger
 * @param params arguments handed to the logger alongside the pattern
 */
private void logInfo(final String msg, final Object... params) {
  logger.info(msg, params);
}
/**
 * Forwards a message and its arguments to the class logger at WARN level.
 *
 * @param msg message pattern handed to the logger
 * @param params arguments handed to the logger alongside the pattern
 */
private void logWarning(final String msg, final Object... params) {
  logger.warn(msg, params);
}
/**
 * Forwards a message and its arguments to the class logger at ERROR level.
 *
 * @param msg message pattern handed to the logger
 * @param params arguments handed to the logger alongside the pattern
 */
private void logError(final String msg, final Object... params) {
  logger.error(msg, params);
}
}