|
@@ -23,6 +23,7 @@ import io.grpc.ConnectivityState;
|
|
|
import io.grpc.ManagedChannel;
|
|
|
import io.grpc.ManagedChannelBuilder;
|
|
|
import io.grpc.StatusRuntimeException;
|
|
|
+import org.apache.commons.collections4.ListUtils;
|
|
|
|
|
|
import javax.annotation.Nonnull;
|
|
|
import java.text.SimpleDateFormat;
|
|
@@ -321,8 +322,9 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
|
|
|
if (!channelIsReadyOrIdle()) {
|
|
|
logWarning("You are not connected to Milvus server");
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
|
|
|
List<io.milvus.grpc.RowRecord> queryRowRecordList = getQueryRowRecordList(searchParam);
|
|
@@ -338,29 +340,32 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.setNprobe(searchParam.getNProbe())
|
|
|
.build();
|
|
|
|
|
|
- io.milvus.grpc.TopKQueryResultList response;
|
|
|
+ io.milvus.grpc.TopKQueryResult response;
|
|
|
|
|
|
try {
|
|
|
response = blockingStub.search(request);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
|
|
|
- List<List<SearchResponse.QueryResult>> queryResultsList = getQueryResultsList(response);
|
|
|
+ SearchResponse searchResponse = buildSearchResponse(response);
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.SUCCESS));
|
|
|
logInfo(
|
|
|
"Search completed successfully! Returned results for {0} queries",
|
|
|
- queryResultsList.size());
|
|
|
- return new SearchResponse(new Response(Response.Status.SUCCESS), queryResultsList);
|
|
|
+ searchResponse.getNumQueries());
|
|
|
+ return searchResponse;
|
|
|
} else {
|
|
|
logSevere("Search failed:\n{0}", response.toString());
|
|
|
- return new SearchResponse(
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(
|
|
|
new Response(
|
|
|
Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
|
|
|
- response.getStatus().getReason()),
|
|
|
- new ArrayList<>());
|
|
|
+ response.getStatus().getReason()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
} catch (StatusRuntimeException e) {
|
|
|
logSevere("search RPC failed:\n{0}", e.getStatus().toString());
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -369,8 +374,9 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
|
|
|
if (!channelIsReadyOrIdle()) {
|
|
|
logWarning("You are not connected to Milvus server");
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
|
|
|
SearchParam searchParam = searchInFilesParam.getSearchParam();
|
|
@@ -394,30 +400,33 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.setSearchParam(searchParamToSet)
|
|
|
.build();
|
|
|
|
|
|
- io.milvus.grpc.TopKQueryResultList response;
|
|
|
+ io.milvus.grpc.TopKQueryResult response;
|
|
|
|
|
|
try {
|
|
|
response = blockingStub.searchInFiles(request);
|
|
|
|
|
|
if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
|
|
|
+ SearchResponse searchResponse = buildSearchResponse(response);
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.SUCCESS));
|
|
|
logInfo("Search in files {0} completed successfully!", searchInFilesParam.getFileIds());
|
|
|
-
|
|
|
- List<List<SearchResponse.QueryResult>> queryResultsList = getQueryResultsList(response);
|
|
|
- return new SearchResponse(new Response(Response.Status.SUCCESS), queryResultsList);
|
|
|
+ return searchResponse;
|
|
|
} else {
|
|
|
logSevere(
|
|
|
"Search in files {0} failed:\n{1}",
|
|
|
searchInFilesParam.getFileIds(), response.toString());
|
|
|
- return new SearchResponse(
|
|
|
+
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(
|
|
|
new Response(
|
|
|
Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
|
|
|
- response.getStatus().getReason()),
|
|
|
- new ArrayList<>());
|
|
|
+ response.getStatus().getReason()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
} catch (StatusRuntimeException e) {
|
|
|
logSevere("searchInFiles RPC failed:\n{0}", e.getStatus().toString());
|
|
|
- return new SearchResponse(
|
|
|
- new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -728,25 +737,23 @@ public class MilvusGrpcClient implements MilvusClient {
|
|
|
.build();
|
|
|
}
|
|
|
|
|
|
- private List<List<SearchResponse.QueryResult>> getQueryResultsList(
|
|
|
- io.milvus.grpc.TopKQueryResultList searchResponse) {
|
|
|
- // TODO: refactor
|
|
|
- List<List<SearchResponse.QueryResult>> queryResultsList = new ArrayList<>();
|
|
|
- Optional<List<io.milvus.grpc.TopKQueryResult>> topKQueryResultList =
|
|
|
- Optional.ofNullable(searchResponse.getTopkQueryResultList());
|
|
|
- if (topKQueryResultList.isPresent()) {
|
|
|
- for (io.milvus.grpc.TopKQueryResult topKQueryResult : topKQueryResultList.get()) {
|
|
|
- List<SearchResponse.QueryResult> responseQueryResults = new ArrayList<>();
|
|
|
- List<io.milvus.grpc.QueryResult> queryResults = topKQueryResult.getQueryResultArraysList();
|
|
|
- for (io.milvus.grpc.QueryResult queryResult : queryResults) {
|
|
|
- SearchResponse.QueryResult responseQueryResult =
|
|
|
- new SearchResponse.QueryResult(queryResult.getId(), queryResult.getDistance());
|
|
|
- responseQueryResults.add(responseQueryResult);
|
|
|
- }
|
|
|
- queryResultsList.add(responseQueryResults);
|
|
|
- }
|
|
|
- }
|
|
|
- return queryResultsList;
|
|
|
+ private SearchResponse buildSearchResponse(io.milvus.grpc.TopKQueryResult topKQueryResult) {
|
|
|
+
|
|
|
+ final int numQueries = (int) topKQueryResult.getRowNum();
|
|
|
+ final int topK =
|
|
|
+ topKQueryResult.getIdsCount() / numQueries; // Guaranteed to be disable from server side
|
|
|
+
|
|
|
+ List<List<Long>> resultIdsList = ListUtils.partition(topKQueryResult.getIdsList(), topK);
|
|
|
+ List<List<Float>> resultDistancesList =
|
|
|
+ ListUtils.partition(topKQueryResult.getDistancesList(), topK);
|
|
|
+
|
|
|
+ SearchResponse searchResponse = new SearchResponse();
|
|
|
+ searchResponse.setNumQueries(numQueries);
|
|
|
+ searchResponse.setTopK(topK);
|
|
|
+ searchResponse.setResultIdsList(resultIdsList);
|
|
|
+ searchResponse.setResultDistancesList(resultDistancesList);
|
|
|
+
|
|
|
+ return searchResponse;
|
|
|
}
|
|
|
|
|
|
private boolean channelIsReadyOrIdle() {
|