|
@@ -23,6 +23,7 @@ import io.grpc.ConnectivityState;
|
|
|
import io.grpc.ManagedChannel;
|
|
|
import io.grpc.ManagedChannelBuilder;
|
|
|
import io.grpc.StatusRuntimeException;
|
|
|
+import org.apache.commons.collections4.ListUtils;
|
|
|
|
|
|
import javax.annotation.Nonnull;
|
|
|
import java.text.SimpleDateFormat;
|
|
@@ -69,8 +70,7 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
|
|
|
.build();
|
|
|
|
|
|
- ConnectivityState connectivityState;
|
|
|
- connectivityState = channel.getState(true);
|
|
|
+ channel.getState(true);
|
|
|
|
|
|
long timeout = connectParam.getConnectTimeout(TimeUnit.MILLISECONDS);
|
|
|
logInfo("Trying to connect...Timeout in {0} ms", timeout);
|
|
@@ -94,7 +94,9 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
throw new ConnectFailedException("Exception occurred: " + e.toString());
|
|
|
}
|
|
|
|
|
|
- logInfo("Connected successfully!\n{0}", connectParam.toString());
|
|
|
+ logInfo(
|
|
|
+ "Connection established successfully to host={0}, port={1}",
|
|
|
+ connectParam.getHost(), connectParam.getPort());
|
|
|
return new Response(Response.Status.SUCCESS);
|
|
|
}
|
|
|
|
|
@@ -321,8 +323,9 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
|
|
|
if (!channelIsReadyOrIdle()) {
|
|
|
logWarning("You are not connected to Milvus server");
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
|
|
|
List<io.milvus.grpc.RowRecord> queryRowRecordList = getQueryRowRecordList(searchParam);
|
|
@@ -338,29 +341,32 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.setNprobe(searchParam.getNProbe())
|
|
|
.build();
|
|
|
|
|
|
- io.milvus.grpc.TopKQueryResultList response;
|
|
|
+ io.milvus.grpc.TopKQueryResult response;
|
|
|
|
|
|
try {
|
|
|
response = blockingStub.search(request);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
|
|
|
- List<List<SearchResponse.QueryResult>> queryResultsList = getQueryResultsList(response);
|
|
|
+ SearchResponse searchResponse = buildSearchResponse(response);
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.SUCCESS));
|
|
|
logInfo(
|
|
|
"Search completed successfully! Returned results for {0} queries",
|
|
|
- queryResultsList.size());
|
|
|
- return new SearchResponse(new Response(Response.Status.SUCCESS), queryResultsList);
|
|
|
+ searchResponse.getNumQueries());
|
|
|
+ return searchResponse;
|
|
|
} else {
|
|
|
logSevere("Search failed:\n{0}", response.toString());
|
|
|
- return new SearchResponse(
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(
|
|
|
new Response(
|
|
|
Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
|
|
|
- response.getStatus().getReason()),
|
|
|
- new ArrayList<>());
|
|
|
+ response.getStatus().getReason()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
} catch (StatusRuntimeException e) {
|
|
|
logSevere("search RPC failed:\n{0}", e.getStatus().toString());
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -369,8 +375,9 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
|
|
|
if (!channelIsReadyOrIdle()) {
|
|
|
logWarning("You are not connected to Milvus server");
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
|
|
|
SearchParam searchParam = searchInFilesParam.getSearchParam();
|
|
@@ -394,30 +401,33 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.setSearchParam(searchParamToSet)
|
|
|
.build();
|
|
|
|
|
|
- io.milvus.grpc.TopKQueryResultList response;
|
|
|
+ io.milvus.grpc.TopKQueryResult response;
|
|
|
|
|
|
try {
|
|
|
response = blockingStub.searchInFiles(request);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
|
|
|
+ SearchResponse searchResponse = buildSearchResponse(response);
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.SUCCESS));
|
|
|
logInfo("Search in files {0} completed successfully!", searchInFilesParam.getFileIds());
|
|
|
-
|
|
|
- List<List<SearchResponse.QueryResult>> queryResultsList = getQueryResultsList(response);
|
|
|
- return new SearchResponse(new Response(Response.Status.SUCCESS), queryResultsList);
|
|
|
+ return searchResponse;
|
|
|
} else {
|
|
|
logSevere(
|
|
|
"Search in files {0} failed:\n{1}",
|
|
|
searchInFilesParam.getFileIds(), response.toString());
|
|
|
- return new SearchResponse(
|
|
|
+
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(
|
|
|
new Response(
|
|
|
Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
|
|
|
- response.getStatus().getReason()),
|
|
|
- new ArrayList<>());
|
|
|
+ response.getStatus().getReason()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
} catch (StatusRuntimeException e) {
|
|
|
logSevere("searchInFiles RPC failed:\n{0}", e.getStatus().toString());
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -728,25 +738,29 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.build();
|
|
|
}
|
|
|
|
|
|
- private List<List<SearchResponse.QueryResult>> getQueryResultsList(
|
|
|
- io.milvus.grpc.TopKQueryResultList searchResponse) {
|
|
|
- // TODO: refactor
|
|
|
- List<List<SearchResponse.QueryResult>> queryResultsList = new ArrayList<>();
|
|
|
- Optional<List<io.milvus.grpc.TopKQueryResult>> topKQueryResultList =
|
|
|
- Optional.ofNullable(searchResponse.getTopkQueryResultList());
|
|
|
- if (topKQueryResultList.isPresent()) {
|
|
|
- for (io.milvus.grpc.TopKQueryResult topKQueryResult : topKQueryResultList.get()) {
|
|
|
- List<SearchResponse.QueryResult> responseQueryResults = new ArrayList<>();
|
|
|
- List<io.milvus.grpc.QueryResult> queryResults = topKQueryResult.getQueryResultArraysList();
|
|
|
- for (io.milvus.grpc.QueryResult queryResult : queryResults) {
|
|
|
- SearchResponse.QueryResult responseQueryResult =
|
|
|
- new SearchResponse.QueryResult(queryResult.getId(), queryResult.getDistance());
|
|
|
- responseQueryResults.add(responseQueryResult);
|
|
|
- }
|
|
|
- queryResultsList.add(responseQueryResults);
|
|
|
- }
|
|
|
+ private SearchResponse buildSearchResponse(io.milvus.grpc.TopKQueryResult topKQueryResult) {
|
|
|
+
|
|
|
+ final int numQueries = (int) topKQueryResult.getRowNum();
|
|
|
+ final int topK =
|
|
|
+ numQueries == 0
|
|
|
+ ? 0
|
|
|
+ : topKQueryResult.getIdsCount()
|
|
|
+ / numQueries; // Guaranteed to be divisible from server side
|
|
|
+
|
|
|
+ List<List<Long>> resultIdsList = new ArrayList<>();
|
|
|
+ List<List<Float>> resultDistancesList = new ArrayList<>();
|
|
|
+ if (topK > 0) {
|
|
|
+ resultIdsList = ListUtils.partition(topKQueryResult.getIdsList(), topK);
|
|
|
+ resultDistancesList = ListUtils.partition(topKQueryResult.getDistancesList(), topK);
|
|
|
}
|
|
|
- return queryResultsList;
|
|
|
+
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setNumQueries(numQueries);
|
|
|
+ searchResponse.setTopK(topK);
|
|
|
+ searchResponse.setResultIdsList(resultIdsList);
|
|
|
+ searchResponse.setResultDistancesList(resultDistancesList);
|
|
|
+
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
|
|
|
private boolean channelIsReadyOrIdle() {
|