|
@@ -23,7 +23,6 @@ import com.google.common.util.concurrent.FutureCallback;
|
|
|
import com.google.common.util.concurrent.Futures;
|
|
|
import com.google.common.util.concurrent.ListenableFuture;
|
|
|
import com.google.common.util.concurrent.MoreExecutors;
|
|
|
-import com.google.protobuf.ByteString;
|
|
|
import io.grpc.CallOptions;
|
|
|
import io.grpc.Channel;
|
|
|
import io.grpc.ClientCall;
|
|
@@ -40,7 +39,6 @@ import io.milvus.client.exception.UnsupportedServerVersion;
|
|
|
import io.milvus.grpc.*;
|
|
|
import org.apache.commons.lang3.ArrayUtils;
|
|
|
import org.json.JSONArray;
|
|
|
-import org.json.JSONException;
|
|
|
import org.json.JSONObject;
|
|
|
import org.slf4j.Logger;
|
|
|
import org.slf4j.LoggerFactory;
|
|
@@ -344,288 +342,20 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public SearchResponse search(@Nonnull SearchParam searchParam) {
|
|
|
-
|
|
|
- if (!maybeAvailable()) {
|
|
|
- logWarning("You are not connected to Milvus server");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
-
|
|
|
- // convert DSL to json object and parse to extract vectors
|
|
|
- List<VectorParam> vectorParamList = new ArrayList<>();
|
|
|
- JSONObject jsonObject;
|
|
|
- List<Object> parsedDSL;
|
|
|
- try {
|
|
|
- jsonObject = new JSONObject(searchParam.getDSL());
|
|
|
- parsedDSL = parseDSL(jsonObject);
|
|
|
- } catch (JSONException err){
|
|
|
- logError("DSL must be in correct json format. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT, err.toString()));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
- if (parsedDSL.size() != 3) {
|
|
|
- logError("DSL must include vector query. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
-
|
|
|
- // use placeholder and vectors to create VectorParam list
|
|
|
- String key = parsedDSL.get(2).toString();
|
|
|
- JSONObject outer = (JSONObject) parsedDSL.get(1);
|
|
|
- JSONObject value = (JSONObject) outer.get(key);
|
|
|
- if (!value.has("topk") || !value.has("query") || !value.has("type")) {
|
|
|
- logError("Invalid DSL vector field argument. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
- List<VectorRowRecord> vectorRowRecordList = new ArrayList<>();
|
|
|
- if (value.get("type").toString().equals("float")) {
|
|
|
- JSONArray arr = (JSONArray) value.get("query");
|
|
|
- for (int i = 0; i < arr.length(); i++) {
|
|
|
- JSONArray innerArr = (JSONArray) (arr.get(i));
|
|
|
- List<Float> floatList = new ArrayList<>();
|
|
|
- for (int j = 0; j < innerArr.length(); j++) {
|
|
|
- Double num = (Double) innerArr.get(j);
|
|
|
- floatList.add(num.floatValue());
|
|
|
- }
|
|
|
- vectorRowRecordList.add(
|
|
|
- VectorRowRecord.newBuilder()
|
|
|
- .addAllFloatData(floatList)
|
|
|
- .build());
|
|
|
- }
|
|
|
- } else if (value.get("type").toString().equals("binary")) {
|
|
|
- JSONArray arr = (JSONArray) value.get("query");
|
|
|
- for (int i = 0; i < arr.length(); i++) {
|
|
|
- JSONArray innerArr = (JSONArray) (arr.get(i));
|
|
|
- ByteBuffer byteBuffer = ByteBuffer.allocate(innerArr.length());
|
|
|
- for (int j = 0; j < innerArr.length(); j++) {
|
|
|
- byteBuffer = byteBuffer.put(j, ((Integer) innerArr.get(j)).byteValue());
|
|
|
- }
|
|
|
- vectorRowRecordList.add(
|
|
|
- VectorRowRecord.newBuilder()
|
|
|
- .setBinaryData(ByteString.copyFrom(byteBuffer))
|
|
|
- .build());
|
|
|
- }
|
|
|
- } else {
|
|
|
- logError("DSL vector type must be float or binary. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
-
|
|
|
- VectorRecord vectorRecord =
|
|
|
- VectorRecord.newBuilder()
|
|
|
- .addAllRecords(vectorRowRecordList)
|
|
|
- .build();
|
|
|
-
|
|
|
- JSONObject json = new JSONObject();
|
|
|
- value.remove("type");
|
|
|
- value.remove("query");
|
|
|
- json.put("placeholder", outer);
|
|
|
- VectorParam vectorParam =
|
|
|
- VectorParam.newBuilder()
|
|
|
- .setJson(json.toString())
|
|
|
- .setRowRecord(vectorRecord)
|
|
|
- .build();
|
|
|
- vectorParamList.add(vectorParam);
|
|
|
-
|
|
|
- KeyValuePair extraParam =
|
|
|
- KeyValuePair.newBuilder()
|
|
|
- .setKey(extraParamKey)
|
|
|
- .setValue(searchParam.getParamsInJson())
|
|
|
- .build();
|
|
|
-
|
|
|
- io.milvus.grpc.SearchParam request =
|
|
|
- io.milvus.grpc.SearchParam.newBuilder()
|
|
|
- .setCollectionName(searchParam.getCollectionName())
|
|
|
- .setDsl(jsonObject.toString())
|
|
|
- .addAllVectorParam(vectorParamList)
|
|
|
- .addAllPartitionTagArray(searchParam.getPartitionTags())
|
|
|
- .addExtraParams(extraParam)
|
|
|
- .build();
|
|
|
-
|
|
|
- QueryResult response;
|
|
|
-
|
|
|
- try {
|
|
|
- response = blockingStub().search(request);
|
|
|
-
|
|
|
- if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
|
|
|
- SearchResponse searchResponse = buildSearchResponse(response);
|
|
|
- searchResponse.setResponse(new Response(Response.Status.SUCCESS));
|
|
|
- logInfo(
|
|
|
- "Search completed successfully! Returned results for {} queries",
|
|
|
- searchResponse.getNumQueries());
|
|
|
- return searchResponse;
|
|
|
- } else {
|
|
|
- logError("Search failed:\n{}", response.getStatus().toString());
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(
|
|
|
- new Response(
|
|
|
- Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
|
|
|
- response.getStatus().getReason()));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
- } catch (StatusRuntimeException e) {
|
|
|
- logError("search RPC failed:\n{}", e.getStatus().toString());
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
+ public SearchResult search(@Nonnull SearchParam searchParam) {
|
|
|
+ return translateExceptions(() -> Futures.getUnchecked(searchAsync(searchParam)));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
|
|
|
-
|
|
|
- if (!maybeAvailable()) {
|
|
|
- logWarning("You are not connected to Milvus server");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
|
|
|
- return Futures.immediateFuture(searchResponse);
|
|
|
- }
|
|
|
-
|
|
|
- // convert DSL to json object and parse to extract vectors
|
|
|
- List<VectorParam> vectorParamList = new ArrayList<>();
|
|
|
- JSONObject jsonObject;
|
|
|
- List<Object> parsedDSL;
|
|
|
- try {
|
|
|
- jsonObject = new JSONObject(searchParam.getDSL());
|
|
|
- parsedDSL = parseDSL(jsonObject);
|
|
|
- } catch (JSONException err){
|
|
|
- logError("DSL must be in correct json format. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT, err.toString()));
|
|
|
- return Futures.immediateFuture(searchResponse);
|
|
|
- }
|
|
|
- if (parsedDSL.size() != 3) {
|
|
|
- logError("DSL must include vector query. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
|
|
|
- return Futures.immediateFuture(searchResponse);
|
|
|
- }
|
|
|
-
|
|
|
- // use placeholder and vectors to create VectorParam list
|
|
|
- String key = parsedDSL.get(2).toString();
|
|
|
- JSONObject outer = (JSONObject) parsedDSL.get(1);
|
|
|
- JSONObject value = (JSONObject) outer.get(key);
|
|
|
- if (!value.has("topk") || !value.has("query") || !value.has("type")) {
|
|
|
- logError("Invalid DSL vector field argument. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
|
|
|
- return Futures.immediateFuture(searchResponse);
|
|
|
- }
|
|
|
- List<VectorRowRecord> vectorRowRecordList = new ArrayList<>();
|
|
|
- if (value.get("type").toString().equals("float")) {
|
|
|
- JSONArray arr = (JSONArray) value.get("query");
|
|
|
- for (int i = 0; i < arr.length(); i++) {
|
|
|
- JSONArray innerArr = (JSONArray) (arr.get(i));
|
|
|
- List<Float> floatList = new ArrayList<>();
|
|
|
- for (int j = 0; j < innerArr.length(); j++) {
|
|
|
- Double num = (Double) innerArr.get(j);
|
|
|
- floatList.add(num.floatValue());
|
|
|
- }
|
|
|
- vectorRowRecordList.add(
|
|
|
- VectorRowRecord.newBuilder()
|
|
|
- .addAllFloatData(floatList)
|
|
|
- .build());
|
|
|
- }
|
|
|
- } else if (value.get("type").toString().equals("binary")) {
|
|
|
- JSONArray arr = (JSONArray) value.get("query");
|
|
|
- for (int i = 0; i < arr.length(); i++) {
|
|
|
- JSONArray innerArr = (JSONArray) (arr.get(i));
|
|
|
- ByteBuffer byteBuffer = ByteBuffer.allocate(innerArr.length());
|
|
|
- for (int j = 0; j < innerArr.length(); j++) {
|
|
|
- byteBuffer = byteBuffer.put(j, ((Integer) innerArr.get(j)).byteValue());
|
|
|
- }
|
|
|
- vectorRowRecordList.add(
|
|
|
- VectorRowRecord.newBuilder()
|
|
|
- .setBinaryData(ByteString.copyFrom(byteBuffer))
|
|
|
- .build());
|
|
|
- }
|
|
|
- } else {
|
|
|
- logError("DSL vector type must be float or binary. Refer to examples for more information.");
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
|
|
|
- return Futures.immediateFuture(searchResponse);
|
|
|
- }
|
|
|
-
|
|
|
- VectorRecord vectorRecord =
|
|
|
- VectorRecord.newBuilder()
|
|
|
- .addAllRecords(vectorRowRecordList)
|
|
|
- .build();
|
|
|
-
|
|
|
- JSONObject json = new JSONObject();
|
|
|
- value.remove("type");
|
|
|
- value.remove("query");
|
|
|
- json.put("placeholder", outer);
|
|
|
- VectorParam vectorParam =
|
|
|
- VectorParam.newBuilder()
|
|
|
- .setJson(json.toString())
|
|
|
- .setRowRecord(vectorRecord)
|
|
|
- .build();
|
|
|
- vectorParamList.add(vectorParam);
|
|
|
-
|
|
|
- KeyValuePair extraParam =
|
|
|
- KeyValuePair.newBuilder()
|
|
|
- .setKey(extraParamKey)
|
|
|
- .setValue(searchParam.getParamsInJson())
|
|
|
- .build();
|
|
|
-
|
|
|
- io.milvus.grpc.SearchParam request =
|
|
|
- io.milvus.grpc.SearchParam.newBuilder()
|
|
|
- .setCollectionName(searchParam.getCollectionName())
|
|
|
- .setDsl(jsonObject.toString())
|
|
|
- .addAllVectorParam(vectorParamList)
|
|
|
- .addAllPartitionTagArray(searchParam.getPartitionTags())
|
|
|
- .addExtraParams(extraParam)
|
|
|
- .build();
|
|
|
-
|
|
|
- ListenableFuture<QueryResult> response;
|
|
|
-
|
|
|
- response = futureStub().search(request);
|
|
|
-
|
|
|
- Futures.addCallback(
|
|
|
- response,
|
|
|
- new FutureCallback<QueryResult>() {
|
|
|
- @Override
|
|
|
- public void onSuccess(QueryResult result) {
|
|
|
- if (result.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
|
|
|
- logInfo(
|
|
|
- "SearchAsync completed successfully! Returned results for {} queries",
|
|
|
- result.getRowNum());
|
|
|
- } else {
|
|
|
- logError("SearchAsync failed:\n{}", result.getStatus().toString());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onFailure(Throwable t) {
|
|
|
- logError("SearchAsync failed:\n{}", t.getMessage());
|
|
|
- }
|
|
|
- },
|
|
|
- MoreExecutors.directExecutor());
|
|
|
-
|
|
|
- Function<QueryResult, SearchResponse> transformFunc =
|
|
|
- topKQueryResult -> {
|
|
|
- if (topKQueryResult.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
|
|
|
- SearchResponse searchResponse = buildSearchResponse(topKQueryResult);
|
|
|
- searchResponse.setResponse(new Response(Response.Status.SUCCESS));
|
|
|
- return searchResponse;
|
|
|
- } else {
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setResponse(
|
|
|
- new Response(
|
|
|
- Response.Status.valueOf(topKQueryResult.getStatus().getErrorCodeValue()),
|
|
|
- topKQueryResult.getStatus().getReason()));
|
|
|
- return searchResponse;
|
|
|
- }
|
|
|
- };
|
|
|
-
|
|
|
- return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
|
|
|
+ public ListenableFuture<SearchResult> searchAsync(@Nonnull SearchParam searchParam) {
|
|
|
+ return translateExceptions(() -> {
|
|
|
+ io.milvus.grpc.SearchParam request = searchParam.grpc();
|
|
|
+ ListenableFuture<QueryResult> responseFuture = futureStub().search(request);
|
|
|
+ return Futures.transform(responseFuture, queryResult -> {
|
|
|
+ checkResponseStatus(queryResult.getStatus());
|
|
|
+ return buildSearchResponse(queryResult);
|
|
|
+ }, MoreExecutors.directExecutor());
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -1196,14 +926,9 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
- private SearchResponse buildSearchResponse(QueryResult topKQueryResult) {
|
|
|
-
|
|
|
+ private SearchResult buildSearchResponse(QueryResult topKQueryResult) {
|
|
|
final int numQueries = (int) topKQueryResult.getRowNum();
|
|
|
- final int topK =
|
|
|
- numQueries == 0
|
|
|
- ? 0
|
|
|
- : topKQueryResult.getDistancesCount()
|
|
|
- / numQueries; // Guaranteed to be divisible from server side
|
|
|
+ final int topK = numQueries == 0 ? 0 : topKQueryResult.getDistancesCount() / numQueries;
|
|
|
|
|
|
List<List<Long>> resultIdsList = new ArrayList<>(numQueries);
|
|
|
List<List<Float>> resultDistancesList = new ArrayList<>(numQueries);
|
|
@@ -1238,10 +963,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
if (vectorRowRecordList.get(j).getFloatDataCount() > 0) {
|
|
|
fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getFloatDataList());
|
|
|
} else {
|
|
|
- ByteBuffer bb = vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer();
|
|
|
- byte[] b = new byte[bb.remaining()];
|
|
|
- bb.get(b);
|
|
|
- fieldsMap.get(j).put(fieldName, Arrays.asList(ArrayUtils.toObject(b)));
|
|
|
+ fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer());
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1261,14 +983,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- SearchResponse searchResponse = new SearchResponse();
|
|
|
- searchResponse.setNumQueries(numQueries);
|
|
|
- searchResponse.setTopK(topK);
|
|
|
- searchResponse.setResultIdsList(resultIdsList);
|
|
|
- searchResponse.setResultDistancesList(resultDistancesList);
|
|
|
- searchResponse.setFieldsMap(resultFieldsMap);
|
|
|
-
|
|
|
- return searchResponse;
|
|
|
+ return new SearchResult(numQueries, topK, resultIdsList, resultDistancesList, resultFieldsMap);
|
|
|
}
|
|
|
|
|
|
private String kvListToString(List<KeyValuePair> kv) {
|
|
@@ -1280,36 +995,6 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
|
|
|
return jsonObject.toString();
|
|
|
}
|
|
|
|
|
|
- private List<Object> parseDSL(JSONObject dsl) {
|
|
|
- Iterator<String> keys = dsl.keys();
|
|
|
- while (keys.hasNext()) {
|
|
|
- String key = keys.next();
|
|
|
- if (key.equals("vector")) {
|
|
|
- // replace dsl vector data by a placeholder string
|
|
|
- List<Object> res = new ArrayList<>();
|
|
|
- JSONObject vecData = (JSONObject) dsl.get(key);
|
|
|
- String name = vecData.keys().next();
|
|
|
- dsl.put(key, "placeholder");
|
|
|
- res.add(dsl);
|
|
|
- res.add(vecData);
|
|
|
- res.add(name);
|
|
|
- return res;
|
|
|
- }
|
|
|
- if (dsl.get(key).getClass() == JSONObject.class) {
|
|
|
- List<Object> res = parseDSL((JSONObject) dsl.get(key));
|
|
|
- if (res.size() > 0) return res;
|
|
|
- } else if (dsl.get(key).getClass() == JSONArray.class) {
|
|
|
- JSONArray arr = (JSONArray) dsl.get(key);
|
|
|
- for (int i = 0; i < arr.length(); i++) {
|
|
|
- JSONObject jsonObject = arr.getJSONObject(i);
|
|
|
- List<Object> res = parseDSL(jsonObject);
|
|
|
- if (res.size() > 0) return res;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return new ArrayList<>();
|
|
|
- }
|
|
|
-
|
|
|
///////////////////// Log Functions//////////////////////
|
|
|
|
|
|
private void logInfo(String msg, Object... params) {
|