Browse Source

Simplify `search`

jianghua 4 years ago
parent
commit
e81fdbe456

+ 6 - 10
src/main/java/io/milvus/client/MilvusClient.java

@@ -235,13 +235,11 @@ public interface MilvusClient {
    * </code>
    * </code>
    * </pre>
    * </pre>
    *
    *
-   * @return <code>SearchResponse</code>
+   * @return <code>SearchResult</code>
    * @see SearchParam
    * @see SearchParam
-   * @see SearchResponse
-   * @see SearchResponse.QueryResult
-   * @see Response
+   * @see SearchResult
    */
    */
-  SearchResponse search(SearchParam searchParam);
+  SearchResult search(SearchParam searchParam);
 
 
   /**
   /**
    * Searches entities specified by <code>searchParam</code> asynchronously
    * Searches entities specified by <code>searchParam</code> asynchronously
@@ -258,14 +256,12 @@ public interface MilvusClient {
    * </code>
    * </code>
    * </pre>
    * </pre>
    *
    *
-   * @return a <code>ListenableFuture</code> object which holds the <code>SearchResponse</code>
+   * @return a <code>ListenableFuture</code> object which holds the <code>SearchResult</code>
    * @see SearchParam
    * @see SearchParam
-   * @see SearchResponse
-   * @see SearchResponse.QueryResult
-   * @see Response
+   * @see SearchResult
    * @see ListenableFuture
    * @see ListenableFuture
    */
    */
-  ListenableFuture<SearchResponse> searchAsync(SearchParam searchParam);
+  ListenableFuture<SearchResult> searchAsync(SearchParam searchParam);
 
 
   /**
   /**
    * Gets collection info
    * Gets collection info

+ 15 - 330
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -23,7 +23,6 @@ import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.common.util.concurrent.MoreExecutors;
-import com.google.protobuf.ByteString;
 import io.grpc.CallOptions;
 import io.grpc.CallOptions;
 import io.grpc.Channel;
 import io.grpc.Channel;
 import io.grpc.ClientCall;
 import io.grpc.ClientCall;
@@ -40,7 +39,6 @@ import io.milvus.client.exception.UnsupportedServerVersion;
 import io.milvus.grpc.*;
 import io.milvus.grpc.*;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.ArrayUtils;
 import org.json.JSONArray;
 import org.json.JSONArray;
-import org.json.JSONException;
 import org.json.JSONObject;
 import org.json.JSONObject;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
@@ -344,288 +342,20 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
   }
   }
 
 
   @Override
   @Override
-  public SearchResponse search(@Nonnull SearchParam searchParam) {
-
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
-      return searchResponse;
-    }
-
-    // convert DSL to json object and parse to extract vectors
-    List<VectorParam> vectorParamList = new ArrayList<>();
-    JSONObject jsonObject;
-    List<Object> parsedDSL;
-    try {
-      jsonObject = new JSONObject(searchParam.getDSL());
-      parsedDSL = parseDSL(jsonObject);
-    } catch (JSONException err){
-      logError("DSL must be in correct json format. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT, err.toString()));
-      return searchResponse;
-    }
-    if (parsedDSL.size() != 3) {
-      logError("DSL must include vector query. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
-      return searchResponse;
-    }
-
-    // use placeholder and vectors to create VectorParam list
-    String key = parsedDSL.get(2).toString();
-    JSONObject outer = (JSONObject) parsedDSL.get(1);
-    JSONObject value = (JSONObject) outer.get(key);
-    if (!value.has("topk") || !value.has("query") || !value.has("type")) {
-      logError("Invalid DSL vector field argument. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
-      return searchResponse;
-    }
-    List<VectorRowRecord> vectorRowRecordList = new ArrayList<>();
-    if (value.get("type").toString().equals("float")) {
-      JSONArray arr = (JSONArray) value.get("query");
-      for (int i = 0; i < arr.length(); i++) {
-        JSONArray innerArr = (JSONArray) (arr.get(i));
-        List<Float> floatList = new ArrayList<>();
-        for (int j = 0; j < innerArr.length(); j++) {
-          Double num = (Double) innerArr.get(j);
-          floatList.add(num.floatValue());
-        }
-        vectorRowRecordList.add(
-            VectorRowRecord.newBuilder()
-                .addAllFloatData(floatList)
-                .build());
-      }
-    } else if (value.get("type").toString().equals("binary")) {
-      JSONArray arr = (JSONArray) value.get("query");
-      for (int i = 0; i < arr.length(); i++) {
-        JSONArray innerArr = (JSONArray) (arr.get(i));
-        ByteBuffer byteBuffer = ByteBuffer.allocate(innerArr.length());
-        for (int j = 0; j < innerArr.length(); j++) {
-          byteBuffer = byteBuffer.put(j, ((Integer) innerArr.get(j)).byteValue());
-        }
-        vectorRowRecordList.add(
-            VectorRowRecord.newBuilder()
-                .setBinaryData(ByteString.copyFrom(byteBuffer))
-                .build());
-      }
-    } else {
-      logError("DSL vector type must be float or binary. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
-      return searchResponse;
-    }
-
-    VectorRecord vectorRecord =
-        VectorRecord.newBuilder()
-            .addAllRecords(vectorRowRecordList)
-            .build();
-
-    JSONObject json = new JSONObject();
-    value.remove("type");
-    value.remove("query");
-    json.put("placeholder", outer);
-    VectorParam vectorParam =
-        VectorParam.newBuilder()
-            .setJson(json.toString())
-            .setRowRecord(vectorRecord)
-            .build();
-    vectorParamList.add(vectorParam);
-
-    KeyValuePair extraParam =
-        KeyValuePair.newBuilder()
-            .setKey(extraParamKey)
-            .setValue(searchParam.getParamsInJson())
-            .build();
-
-    io.milvus.grpc.SearchParam request =
-        io.milvus.grpc.SearchParam.newBuilder()
-            .setCollectionName(searchParam.getCollectionName())
-            .setDsl(jsonObject.toString())
-            .addAllVectorParam(vectorParamList)
-            .addAllPartitionTagArray(searchParam.getPartitionTags())
-            .addExtraParams(extraParam)
-            .build();
-
-    QueryResult response;
-
-    try {
-      response = blockingStub().search(request);
-
-      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-        SearchResponse searchResponse = buildSearchResponse(response);
-        searchResponse.setResponse(new Response(Response.Status.SUCCESS));
-        logInfo(
-            "Search completed successfully! Returned results for {} queries",
-            searchResponse.getNumQueries());
-        return searchResponse;
-      } else {
-        logError("Search failed:\n{}", response.getStatus().toString());
-        SearchResponse searchResponse = new SearchResponse();
-        searchResponse.setResponse(
-            new Response(
-                Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()));
-        return searchResponse;
-      }
-    } catch (StatusRuntimeException e) {
-      logError("search RPC failed:\n{}", e.getStatus().toString());
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
-      return searchResponse;
-    }
+  public SearchResult search(@Nonnull SearchParam searchParam) {
+    return translateExceptions(() -> Futures.getUnchecked(searchAsync(searchParam)));
   }
   }
 
 
   @Override
   @Override
-  public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
-
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
-      return Futures.immediateFuture(searchResponse);
-    }
-
-    // convert DSL to json object and parse to extract vectors
-    List<VectorParam> vectorParamList = new ArrayList<>();
-    JSONObject jsonObject;
-    List<Object> parsedDSL;
-    try {
-      jsonObject = new JSONObject(searchParam.getDSL());
-      parsedDSL = parseDSL(jsonObject);
-    } catch (JSONException err){
-      logError("DSL must be in correct json format. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT, err.toString()));
-      return Futures.immediateFuture(searchResponse);
-    }
-    if (parsedDSL.size() != 3) {
-      logError("DSL must include vector query. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
-      return Futures.immediateFuture(searchResponse);
-    }
-
-    // use placeholder and vectors to create VectorParam list
-    String key = parsedDSL.get(2).toString();
-    JSONObject outer = (JSONObject) parsedDSL.get(1);
-    JSONObject value = (JSONObject) outer.get(key);
-    if (!value.has("topk") || !value.has("query") || !value.has("type")) {
-      logError("Invalid DSL vector field argument. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
-      return Futures.immediateFuture(searchResponse);
-    }
-    List<VectorRowRecord> vectorRowRecordList = new ArrayList<>();
-    if (value.get("type").toString().equals("float")) {
-      JSONArray arr = (JSONArray) value.get("query");
-      for (int i = 0; i < arr.length(); i++) {
-        JSONArray innerArr = (JSONArray) (arr.get(i));
-        List<Float> floatList = new ArrayList<>();
-        for (int j = 0; j < innerArr.length(); j++) {
-          Double num = (Double) innerArr.get(j);
-          floatList.add(num.floatValue());
-        }
-        vectorRowRecordList.add(
-            VectorRowRecord.newBuilder()
-                .addAllFloatData(floatList)
-                .build());
-      }
-    } else if (value.get("type").toString().equals("binary")) {
-      JSONArray arr = (JSONArray) value.get("query");
-      for (int i = 0; i < arr.length(); i++) {
-        JSONArray innerArr = (JSONArray) (arr.get(i));
-        ByteBuffer byteBuffer = ByteBuffer.allocate(innerArr.length());
-        for (int j = 0; j < innerArr.length(); j++) {
-          byteBuffer = byteBuffer.put(j, ((Integer) innerArr.get(j)).byteValue());
-        }
-        vectorRowRecordList.add(
-            VectorRowRecord.newBuilder()
-                .setBinaryData(ByteString.copyFrom(byteBuffer))
-                .build());
-      }
-    } else {
-      logError("DSL vector type must be float or binary. Refer to examples for more information.");
-      SearchResponse searchResponse = new SearchResponse();
-      searchResponse.setResponse(new Response(Response.Status.ILLEGAL_ARGUMENT));
-      return Futures.immediateFuture(searchResponse);
-    }
-
-    VectorRecord vectorRecord =
-        VectorRecord.newBuilder()
-            .addAllRecords(vectorRowRecordList)
-            .build();
-
-    JSONObject json = new JSONObject();
-    value.remove("type");
-    value.remove("query");
-    json.put("placeholder", outer);
-    VectorParam vectorParam =
-        VectorParam.newBuilder()
-            .setJson(json.toString())
-            .setRowRecord(vectorRecord)
-            .build();
-    vectorParamList.add(vectorParam);
-
-    KeyValuePair extraParam =
-        KeyValuePair.newBuilder()
-            .setKey(extraParamKey)
-            .setValue(searchParam.getParamsInJson())
-            .build();
-
-    io.milvus.grpc.SearchParam request =
-        io.milvus.grpc.SearchParam.newBuilder()
-            .setCollectionName(searchParam.getCollectionName())
-            .setDsl(jsonObject.toString())
-            .addAllVectorParam(vectorParamList)
-            .addAllPartitionTagArray(searchParam.getPartitionTags())
-            .addExtraParams(extraParam)
-            .build();
-
-    ListenableFuture<QueryResult> response;
-
-    response = futureStub().search(request);
-
-    Futures.addCallback(
-        response,
-        new FutureCallback<QueryResult>() {
-          @Override
-          public void onSuccess(QueryResult result) {
-            if (result.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-              logInfo(
-                  "SearchAsync completed successfully! Returned results for {} queries",
-                  result.getRowNum());
-            } else {
-              logError("SearchAsync failed:\n{}", result.getStatus().toString());
-            }
-          }
-
-          @Override
-          public void onFailure(Throwable t) {
-            logError("SearchAsync failed:\n{}", t.getMessage());
-          }
-        },
-        MoreExecutors.directExecutor());
-
-    Function<QueryResult, SearchResponse> transformFunc =
-        topKQueryResult -> {
-          if (topKQueryResult.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-            SearchResponse searchResponse = buildSearchResponse(topKQueryResult);
-            searchResponse.setResponse(new Response(Response.Status.SUCCESS));
-            return searchResponse;
-          } else {
-            SearchResponse searchResponse = new SearchResponse();
-            searchResponse.setResponse(
-                new Response(
-                    Response.Status.valueOf(topKQueryResult.getStatus().getErrorCodeValue()),
-                    topKQueryResult.getStatus().getReason()));
-            return searchResponse;
-          }
-        };
-
-    return Futures.transform(response, transformFunc::apply, MoreExecutors.directExecutor());
+  public ListenableFuture<SearchResult> searchAsync(@Nonnull SearchParam searchParam) {
+    return translateExceptions(() -> {
+      io.milvus.grpc.SearchParam request = searchParam.grpc();
+      ListenableFuture<QueryResult> responseFuture = futureStub().search(request);
+      return Futures.transform(responseFuture, queryResult -> {
+        checkResponseStatus(queryResult.getStatus());
+        return buildSearchResponse(queryResult);
+      }, MoreExecutors.directExecutor());
+    });
   }
   }
 
 
   @Override
   @Override
@@ -1196,14 +926,9 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
         }
         }
       };
       };
 
 
-  private SearchResponse buildSearchResponse(QueryResult topKQueryResult) {
-
+  private SearchResult buildSearchResponse(QueryResult topKQueryResult) {
     final int numQueries = (int) topKQueryResult.getRowNum();
     final int numQueries = (int) topKQueryResult.getRowNum();
-    final int topK =
-        numQueries == 0
-            ? 0
-            : topKQueryResult.getDistancesCount()
-                / numQueries; // Guaranteed to be divisible from server side
+    final int topK = numQueries == 0 ? 0 : topKQueryResult.getDistancesCount() / numQueries;
 
 
     List<List<Long>> resultIdsList = new ArrayList<>(numQueries);
     List<List<Long>> resultIdsList = new ArrayList<>(numQueries);
     List<List<Float>> resultDistancesList = new ArrayList<>(numQueries);
     List<List<Float>> resultDistancesList = new ArrayList<>(numQueries);
@@ -1238,10 +963,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
             if (vectorRowRecordList.get(j).getFloatDataCount() > 0) {
             if (vectorRowRecordList.get(j).getFloatDataCount() > 0) {
               fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getFloatDataList());
               fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getFloatDataList());
             } else {
             } else {
-              ByteBuffer bb = vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer();
-              byte[] b = new byte[bb.remaining()];
-              bb.get(b);
-              fieldsMap.get(j).put(fieldName, Arrays.asList(ArrayUtils.toObject(b)));
+              fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer());
             }
             }
           }
           }
         }
         }
@@ -1261,14 +983,7 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
       }
       }
     }
     }
 
 
-    SearchResponse searchResponse = new SearchResponse();
-    searchResponse.setNumQueries(numQueries);
-    searchResponse.setTopK(topK);
-    searchResponse.setResultIdsList(resultIdsList);
-    searchResponse.setResultDistancesList(resultDistancesList);
-    searchResponse.setFieldsMap(resultFieldsMap);
-
-    return searchResponse;
+    return new SearchResult(numQueries, topK, resultIdsList, resultDistancesList, resultFieldsMap);
   }
   }
 
 
   private String kvListToString(List<KeyValuePair> kv) {
   private String kvListToString(List<KeyValuePair> kv) {
@@ -1280,36 +995,6 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
     return jsonObject.toString();
     return jsonObject.toString();
   }
   }
 
 
-  private List<Object> parseDSL(JSONObject dsl) {
-    Iterator<String> keys = dsl.keys();
-    while (keys.hasNext()) {
-      String key = keys.next();
-      if (key.equals("vector")) {
-        // replace dsl vector data by a placeholder string
-        List<Object> res = new ArrayList<>();
-        JSONObject vecData = (JSONObject) dsl.get(key);
-        String name = vecData.keys().next();
-        dsl.put(key, "placeholder");
-        res.add(dsl);
-        res.add(vecData);
-        res.add(name);
-        return res;
-      }
-      if (dsl.get(key).getClass() == JSONObject.class) {
-        List<Object> res = parseDSL((JSONObject) dsl.get(key));
-        if (res.size() > 0) return res;
-      } else if (dsl.get(key).getClass() == JSONArray.class) {
-        JSONArray arr = (JSONArray) dsl.get(key);
-        for (int i = 0; i < arr.length(); i++) {
-          JSONObject jsonObject = arr.getJSONObject(i);
-          List<Object> res = parseDSL(jsonObject);
-          if (res.size() > 0) return res;
-        }
-      }
-    }
-    return new ArrayList<>();
-  }
-
   ///////////////////// Log Functions//////////////////////
   ///////////////////// Log Functions//////////////////////
 
 
   private void logInfo(String msg, Object... params) {
   private void logInfo(String msg, Object... params) {

+ 128 - 106
src/main/java/io/milvus/client/SearchParam.java

@@ -19,133 +19,155 @@
 
 
 package io.milvus.client;
 package io.milvus.client;
 
 
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.UnsafeByteOperations;
+import io.milvus.client.exception.InvalidDsl;
+import io.milvus.grpc.KeyValuePair;
+import io.milvus.grpc.VectorParam;
+import io.milvus.grpc.VectorRecord;
+import io.milvus.grpc.VectorRowRecord;
+import org.json.JSONArray;
+import org.json.JSONException;
+import org.json.JSONObject;
+
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.List;
-import java.util.Map;
-import javax.annotation.Nonnull;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
 
 
 /** Contains parameters for <code>search</code> */
 /** Contains parameters for <code>search</code> */
 public class SearchParam {
 public class SearchParam {
+  private static final String VECTOR_QUERY_KEY = "vector";
+  private static final String VECTOR_QUERY_PLACEHOLDER = "placeholder";
 
 
-  private final String collectionName;
-  private final String dsl;
-  private final List<String> partitionTags;
-  private final String paramsInJson;
+  private io.milvus.grpc.SearchParam.Builder builder;
 
 
-  private SearchParam(@Nonnull Builder builder) {
-    this.collectionName = builder.collectionName;
-    this.dsl = builder.dsl;
-    this.partitionTags = builder.partitionTags;
-    this.paramsInJson = builder.paramsInJson;
+  public static SearchParam create(String collectionName) {
+    return new SearchParam(collectionName);
   }
   }
 
 
-  public String getCollectionName() {
-    return collectionName;
+  private SearchParam(String collectionName) {
+    builder = io.milvus.grpc.SearchParam.newBuilder();
+    builder.setCollectionName(collectionName);
   }
   }
 
 
-  public String getDSL() { return dsl; }
-
-  public List<String> getPartitionTags() {
-    return partitionTags;
+  public SearchParam setDsl(String dsl) {
+    try {
+      JSONObject dslJson = new JSONObject(dsl);
+      JSONObject vectorQueryParent = locateVectorQuery(dslJson)
+          .orElseThrow(() -> new InvalidDsl("A vector query must be specified", dsl));
+      JSONObject vectorQueries = vectorQueryParent.getJSONObject(VECTOR_QUERY_KEY);
+      vectorQueryParent.put(VECTOR_QUERY_KEY, VECTOR_QUERY_PLACEHOLDER);
+      String vectorQueryField = vectorQueries.keys().next();
+      JSONObject vectorQuery = vectorQueries.getJSONObject(vectorQueryField);
+      String vectorQueryType = vectorQuery.getString("type");
+      JSONArray vectorQueryData = vectorQuery.getJSONArray("query");
+
+      VectorRecord vectorRecord;
+      switch (vectorQueryType) {
+        case "float":
+          vectorRecord = toFloatVectorRecord(vectorQueryData);
+          break;
+        case "binary":
+          vectorRecord = toBinaryVectorRecord(vectorQueryData);
+          break;
+        default:
+          throw new InvalidDsl("Unsupported vector type: " + vectorQueryType, dsl);
+      }
+
+      JSONObject json = new JSONObject();
+      vectorQuery.remove("type");
+      vectorQuery.remove("query");
+      json.put("placeholder", vectorQueries);
+      VectorParam vectorParam = VectorParam.newBuilder()
+          .setJson(json.toString())
+          .setRowRecord(vectorRecord)
+          .build();
+
+      builder.setDsl(dslJson.toString())
+          .addAllVectorParam(ImmutableList.of(vectorParam));
+      return this;
+    } catch (JSONException e) {
+      throw new InvalidDsl(e.getMessage(), dsl);
+    }
   }
   }
 
 
-  public String getParamsInJson() {
-    return paramsInJson;
+  public SearchParam setPartitionTags(List<String> partitionTags) {
+    builder.addAllPartitionTagArray(partitionTags);
+    return this;
   }
   }
 
 
-  /** Builder for <code>SearchParam</code> */
-  public static class Builder {
-    // Required parameter
-    private final String collectionName;
+  public SearchParam setParamsInJson(String paramsInJson) {
+    builder.addExtraParams(KeyValuePair.newBuilder()
+        .setKey(MilvusClient.extraParamKey)
+        .setValue(paramsInJson)
+        .build());
+    return this;
+  }
 
 
-    // Optional parameters - initialized to default values
-    private List<String> partitionTags = new ArrayList<>();
-    private String dsl = "{}";
-    private String paramsInJson = "{}";
+  io.milvus.grpc.SearchParam grpc() {
+    return builder.build();
+  }
 
 
-    /** @param collectionName collection to search from */
-    public Builder(@Nonnull String collectionName) {
-      this.collectionName = collectionName;
-    }
+  private Optional<JSONObject> locateVectorQuery(Object obj) {
+    return obj instanceof JSONObject ? locateVectorQuery((JSONObject) obj)
+        : obj instanceof JSONArray ? locateVectorQuery((JSONArray) obj)
+        : Optional.empty();
+  }
 
 
-    /**
-     * The DSL statement for search. DSL provides a more convenient and idiomatic way to write and
-     * manipulate queries. It is in JSON format (passed into builder as String), and an example of
-     * DSL statement is as follows.
-     *
-     * <pre>
-     *   <code>
-     * {
-     *     "bool": {
-     *         "must": [
-     *             {
-     *                 "term": {
-     *                     "A": [1, 2, 5]
-     *                 }
-     *             },
-     *             {
-     *                 "range": {
-     *                     "B": {"GT": 1, "LT": 100}
-     *                 }
-     *             },
-     *             {
-     *                 "vector": {
-     *                     "Vec": {
-     *                         "topk": 10, "type": "float", "query": list_of_vecs, "params": {"nprobe": 10}
-     *                     }
-     *                 }
-     *             }
-     *         ],
-     *     },
-     * }
-     *   </code>
-     * </pre>
-     *
-     * <p>Note that "vector" must be included in DSL. The "params" in "Vec" is different for different
-     * index types. Refer to Milvus documentation for more information about DSL.</p>
-     *
-     * <p>A "type" key must be present in "Vec" field to indicate whether your query vectors are
-     * "float" or "binary".</p>
-     *
-     * @param dsl The DSL String in JSON format
-     * @return <code>Builder</code>
-     */
-    public SearchParam.Builder withDSL(@Nonnull String dsl) {
-      this.dsl = dsl;
-      return this;
-    }
+  private Optional<JSONObject> locateVectorQuery(JSONArray array) {
+    return StreamSupport.stream(array.spliterator(), false)
+        .map(this::locateVectorQuery)
+        .filter(Optional::isPresent)
+        .map(Optional::get)
+        .findFirst();
+  }
 
 
-    /**
-     * Optional. Search vectors with corresponding <code>partitionTags</code>. Default to an empty
-     * <code>List</code>
-     *
-     * @param partitionTags a <code>List</code> of partition tags
-     * @return <code>Builder</code>
-     */
-    public SearchParam.Builder withPartitionTags(@Nonnull List<String> partitionTags) {
-      this.partitionTags = partitionTags;
-      return this;
+  private Optional<JSONObject> locateVectorQuery(JSONObject obj) {
+    if (obj.opt(VECTOR_QUERY_KEY) instanceof JSONObject) {
+      return Optional.of(obj);
     }
     }
+    return obj.keySet().stream()
+        .map(key -> locateVectorQuery(obj.get(key)))
+        .filter(Optional::isPresent)
+        .map(Optional::get)
+        .findFirst();
+  }
 
 
-    /**
-     * Optional. Default to empty <code>String</code>. This is to specify the fields you would like
-     * Milvus server to return from query results. No field information will be returned if this
-     * is not specified. Example:
-     * <pre>
-     *   {"fields": ["B", "D"]}
-     * </pre>
-     *
-     * @param paramsInJson extra parameters in JSON format
-     * @return <code>Builder</code>
-     */
-    public SearchParam.Builder withParamsInJson(@Nonnull String paramsInJson) {
-      this.paramsInJson = paramsInJson;
-      return this;
-    }
+  private VectorRecord toFloatVectorRecord(JSONArray data) {
+    return VectorRecord.newBuilder().addAllRecords(
+        StreamSupport.stream(data.spliterator(), false)
+            .map(element -> (JSONArray) element)
+            .map(array -> {
+              int dimension = array.length();
+              List<Float> vector = new ArrayList<>(dimension);
+              for (int i = 0; i < dimension; i++) {
+                vector.add(array.getFloat(i));
+              }
+              return VectorRowRecord.newBuilder().addAllFloatData(vector).build();
+            })
+            .collect(Collectors.toList()))
+        .build();
+  }
 
 
-    public SearchParam build() {
-      return new SearchParam(this);
-    }
+  private VectorRecord toBinaryVectorRecord(JSONArray data) {
+    return VectorRecord.newBuilder().addAllRecords(
+        StreamSupport.stream(data.spliterator(), false)
+            .map(element -> (JSONArray) element)
+            .map(array -> {
+              int dimension = array.length();
+              ByteBuffer bytes = ByteBuffer.allocate(dimension);
+              for (int i = 0; i < dimension; i++) {
+                bytes.put(array.getNumber(i).byteValue());
+              }
+              bytes.flip();
+              ByteString vector = UnsafeByteOperations.unsafeWrap(bytes);
+              return VectorRowRecord.newBuilder().setBinaryData(vector).build();
+            })
+            .collect(Collectors.toList()))
+        .build();
   }
   }
 }
 }

+ 0 - 150
src/main/java/io/milvus/client/SearchResponse.java

@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package io.milvus.client;
-
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.LongStream;
-
-/**
- * Contains the returned <code>response</code> and query results for <code>search</code>
- */
-public class SearchResponse {
-
-  private Response response;
-  private int numQueries;
-  private long topK;
-  private List<List<Long>> resultIdsList;
-  private List<List<Float>> resultDistancesList;
-  private List<List<Map<String, Object>>> fieldsMap;
-
-  public int getNumQueries() {
-    return numQueries;
-  }
-
-  void setNumQueries(int numQueries) {
-    this.numQueries = numQueries;
-  }
-
-  public long getTopK() {
-    return topK;
-  }
-
-  void setTopK(long topK) {
-    this.topK = topK;
-  }
-
-  /**
-   * @return a <code>List</code> of <code>QueryResult</code>s. Each inner <code>List</code> contains
-   *     the query result of an entity.
-   */
-  public List<List<QueryResult>> getQueryResultsList() {
-    return IntStream.range(0, numQueries)
-        .mapToObj(
-            i ->
-                IntStream.range(0, resultIdsList.get(i).size())
-                    .mapToObj(
-                        j ->
-                            new QueryResult(
-                                resultIdsList.get(i).get(j),
-                                resultDistancesList.get(i).get(j)))
-                    .collect(Collectors.toList()))
-        .collect(Collectors.toList());
-  }
-
-  /**
-   * @return a <code>List</code> of result ids. Each inner <code>List</code> contains the result ids
-   *     of an entity.
-   */
-  public List<List<Long>> getResultIdsList() {
-    return resultIdsList;
-  }
-
-  void setResultIdsList(List<List<Long>> resultIdsList) {
-    this.resultIdsList = resultIdsList;
-  }
-
-  /**
-   * @return a <code>List</code> of result distances. Each inner <code>List</code> contains
-   *     the result distances of an entity.
-   */
-  public List<List<Float>> getResultDistancesList() {
-    return resultDistancesList;
-  }
-
-  void setResultDistancesList(List<List<Float>> resultDistancesList) {
-    this.resultDistancesList = resultDistancesList;
-  }
-
-  public Response getResponse() {
-    return response;
-  }
-
-  void setResponse(Response response) {
-    this.response = response;
-  }
-
-  /**
-   * @return A <code>List</code> of map with fields information. Each inner <code>List</code> contains
-   * a <code>Map</code> of field names to records in a row.
-   * The record object can be one of int, long, float, double, List<Float> or List<Byte>
-   * depending on the field's <code>DataType</code> you specified.
-   */
-  public List<List<Map<String, Object>>> getFieldsMap() { return fieldsMap; }
-
-  void setFieldsMap(List<List<Map<String, Object>>> fieldsMap) {
-    this.fieldsMap = fieldsMap;
-  }
-
-  /** @return <code>true</code> if the response status equals SUCCESS */
-  public boolean ok() {
-    return response.ok();
-  }
-
-  @Override
-  public String toString() {
-    return String.format(
-        "SearchResponse {%s, returned results for %d queries}", response.toString(), numQueries);
-  }
-
-  /**
-   * Represents a single result of an entity query. Contains the result <code>entityId</code> and its
-   * <code>distance</code> to the entity being queried
-   */
-  public static class QueryResult {
-    private final long entityId;
-    private final float distance;
-
-    QueryResult(long entityId, float distance) {
-      this.entityId = entityId;
-      this.distance = distance;
-    }
-
-    public long getEntityId() {
-      return entityId;
-    }
-
-    public float getDistance() {
-      return distance;
-    }
-  }
-}

+ 72 - 0
src/main/java/io/milvus/client/SearchResult.java

@@ -0,0 +1,72 @@
+package io.milvus.client;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class SearchResult {
+  private int numQueries;
+  private long topK;
+  private List<List<Long>> resultIdsList;
+  private List<List<Float>> resultDistancesList;
+  private List<List<Map<String, Object>>> fieldsMap;
+
+  public SearchResult(int numQueries,
+                      long topK,
+                      List<List<Long>> resultIdsList,
+                      List<List<Float>> resultDistancesList,
+                      List<List<Map<String, Object>>> fieldsMap) {
+    this.numQueries = numQueries;
+    this.topK = topK;
+    this.resultIdsList = resultIdsList;
+    this.resultDistancesList = resultDistancesList;
+    this.fieldsMap = fieldsMap;
+  }
+
+  public int getNumQueries() {
+    return numQueries;
+  }
+
+  public long getTopK() {
+    return topK;
+  }
+
+  public List<List<Long>> getResultIdsList() {
+    return resultIdsList;
+  }
+
+  public List<List<Float>> getResultDistancesList() {
+    return resultDistancesList;
+  }
+
+  public List<List<Map<String, Object>>> getFieldsMap() {
+    return fieldsMap;
+  }
+
+  public List<List<QueryResult>> getQueryResultsList() {
+    return IntStream.range(0, numQueries)
+        .mapToObj(i -> IntStream.range(0, resultIdsList.get(i).size())
+            .mapToObj(j -> new QueryResult(resultIdsList.get(i).get(j), resultDistancesList.get(i).get(j)))
+            .collect(Collectors.toList()))
+        .collect(Collectors.toList());
+  }
+
+  public static class QueryResult {
+    private final long entityId;
+    private final float distance;
+
+    QueryResult(long entityId, float distance) {
+      this.entityId = entityId;
+      this.distance = distance;
+    }
+
+    public long getEntityId() {
+      return entityId;
+    }
+
+    public float getDistance() {
+      return distance;
+    }
+  }
+}

+ 15 - 0
src/main/java/io/milvus/client/exception/InvalidDsl.java

@@ -0,0 +1,15 @@
+package io.milvus.client.exception;
+
+public class InvalidDsl extends ClientSideMilvusException {
+  private String dsl;
+
+  public InvalidDsl(String dsl, String message) {
+    super(null, message);
+    this.dsl = dsl;
+  }
+
+  @Override
+  protected String getErrorMessage() {
+    return super.getErrorMessage() + ": " + dsl;
+  }
+}

+ 28 - 90
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -396,28 +396,23 @@ class MilvusClientTest {
     List<List<Float>> vectorsToSearch1 = vectors.subList(0, searchSize);
     List<List<Float>> vectorsToSearch1 = vectors.subList(0, searchSize);
     List<String> partitionTags1 = new ArrayList<>();
     List<String> partitionTags1 = new ArrayList<>();
     partitionTags1.add(tag1);
     partitionTags1.add(tag1);
-    SearchParam searchParam1 =
-        new SearchParam.Builder(randomCollectionName)
-            .withDSL(generateSimpleDSL(topK, vectorsToSearch1.toString()))
-            .withPartitionTags(partitionTags1)
-            .build();
-    SearchResponse searchResponse1 = client.search(searchParam1);
-    assertTrue(searchResponse1.ok());
-    List<List<Long>> resultIdsList1 = searchResponse1.getResultIdsList();
+    SearchParam searchParam1 = SearchParam
+        .create(randomCollectionName)
+        .setDsl(generateSimpleDSL(topK, vectorsToSearch1.toString()))
+        .setPartitionTags(partitionTags1);
+    SearchResult searchResult1 = client.search(searchParam1);
+    List<List<Long>> resultIdsList1 = searchResult1.getResultIdsList();
     assertEquals(searchSize, resultIdsList1.size());
     assertEquals(searchSize, resultIdsList1.size());
     assertTrue(entityIds1.containsAll(resultIdsList1.get(0)));
     assertTrue(entityIds1.containsAll(resultIdsList1.get(0)));
 
 
     List<List<Float>> vectorsToSearch2 = vectors.subList(0, searchSize);
     List<List<Float>> vectorsToSearch2 = vectors.subList(0, searchSize);
     List<String> partitionTags2 = new ArrayList<>();
     List<String> partitionTags2 = new ArrayList<>();
     partitionTags2.add(tag2);
     partitionTags2.add(tag2);
-    SearchParam searchParam2 =
-        new SearchParam.Builder(randomCollectionName)
-            .withDSL(generateSimpleDSL(topK, vectorsToSearch2.toString()))
-            .withPartitionTags(partitionTags2)
-            .build();
-    SearchResponse searchResponse2 = client.search(searchParam2);
-    assertTrue(searchResponse2.ok());
-    List<List<Long>> resultIdsList2 = searchResponse2.getResultIdsList();
+    SearchParam searchParam2 = SearchParam.create(randomCollectionName)
+        .setDsl(generateSimpleDSL(topK, vectorsToSearch2.toString()))
+        .setPartitionTags(partitionTags2);
+    SearchResult searchResult2 = client.search(searchParam2);
+    List<List<Long>> resultIdsList2 = searchResult2.getResultIdsList();
     assertEquals(searchSize, resultIdsList2.size());
     assertEquals(searchSize, resultIdsList2.size());
     assertTrue(entityIds2.containsAll(resultIdsList2.get(0)));
     assertTrue(entityIds2.containsAll(resultIdsList2.get(0)));
 
 
@@ -529,77 +524,22 @@ class MilvusClientTest {
     List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
     List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
 
 
     final long topK = 10;
     final long topK = 10;
-    SearchParam searchParam =
-        new SearchParam.Builder(randomCollectionName)
-            .withDSL(generateComplexDSL(topK, vectorsToSearch.toString()))
-            .withParamsInJson(new JsonBuilder().param("fields",
-                new ArrayList<>(Arrays.asList("int64", "float_vec"))).build())
-            .build();
-    SearchResponse searchResponse = client.search(searchParam);
-    assertTrue(searchResponse.ok());
-    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
-    assertEquals(searchSize, resultIdsList.size());
-    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
-    assertEquals(searchSize, resultDistancesList.size());
-    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
-    assertEquals(searchSize, queryResultsList.size());
-
-    final double epsilon = 0.001;
-    for (int i = 0; i < searchSize; i++) {
-      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
-      assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
-      assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
-      assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
-      assertTrue(Math.abs(resultDistancesList.get(i).get(0)) < epsilon);
-    }
-  }
-
-  @org.junit.jupiter.api.Test
-  void searchAsync() throws ExecutionException, InterruptedException {
-    List<Long> intValues = new ArrayList<>(size);
-    List<Float> floatValues = new ArrayList<>(size);
-    List<List<Float>> vectors = generateFloatVectors(size, dimension);
-    for (int i = 0; i < size; i++) {
-      intValues.add((long) i);
-      floatValues.add((float) i);
-    }
-    vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
-
-    List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
-    InsertParam insertParam = InsertParam
+    SearchParam searchParam = SearchParam
         .create(randomCollectionName)
         .create(randomCollectionName)
-        .addField("int64", DataType.INT64, intValues)
-        .addField("float", DataType.FLOAT, floatValues)
-        .addVectorField("float_vec", DataType.VECTOR_FLOAT, vectors)
-        .setEntityIds(insertIds);
-    List<Long> entityIds = client.insert(insertParam);
-    assertEquals(size, entityIds.size());
-
-    assertTrue(client.flush(randomCollectionName).ok());
-
-    final int searchSize = 5;
-    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
-
-    final long topK = 10;
-    SearchParam searchParam =
-        new SearchParam.Builder(randomCollectionName)
-            .withDSL(generateComplexDSL(topK, vectorsToSearch.toString()))
-            .withParamsInJson(new JsonBuilder().param("fields",
-                new ArrayList<>(Arrays.asList("int64", "float"))).build())
-            .build();
-    ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
-    SearchResponse searchResponse = searchResponseFuture.get();
-    assertTrue(searchResponse.ok());
-    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+        .setDsl(generateComplexDSL(topK, vectorsToSearch.toString()))
+        .setParamsInJson(new JsonBuilder().param("fields",
+            new ArrayList<>(Arrays.asList("int64", "float_vec"))).build());
+    SearchResult searchResult = client.search(searchParam);
+    List<List<Long>> resultIdsList = searchResult.getResultIdsList();
     assertEquals(searchSize, resultIdsList.size());
     assertEquals(searchSize, resultIdsList.size());
-    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    List<List<Float>> resultDistancesList = searchResult.getResultDistancesList();
     assertEquals(searchSize, resultDistancesList.size());
     assertEquals(searchSize, resultDistancesList.size());
-    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    List<List<SearchResult.QueryResult>> queryResultsList = searchResult.getQueryResultsList();
     assertEquals(searchSize, queryResultsList.size());
     assertEquals(searchSize, queryResultsList.size());
 
 
     final double epsilon = 0.001;
     final double epsilon = 0.001;
     for (int i = 0; i < searchSize; i++) {
     for (int i = 0; i < searchSize; i++) {
-      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      SearchResult.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
       assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
       assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
       assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
       assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
       assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
       assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
@@ -645,21 +585,19 @@ class MilvusClientTest {
         .collect(Collectors.toList());
         .collect(Collectors.toList());
 
 
     final long topK = 10;
     final long topK = 10;
-    SearchParam searchParam =
-        new SearchParam.Builder(binaryCollectionName)
-            .withDSL(generateComplexDSLBinary(topK, vectorsToSearch.toString()))
-            .build();
-    SearchResponse searchResponse = client.search(searchParam);
-    assertTrue(searchResponse.ok());
-    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+    SearchParam searchParam = SearchParam
+        .create(binaryCollectionName)
+        .setDsl(generateComplexDSLBinary(topK, vectorsToSearch.toString()));
+    SearchResult searchResult = client.search(searchParam);
+    List<List<Long>> resultIdsList = searchResult.getResultIdsList();
     assertEquals(searchSize, resultIdsList.size());
     assertEquals(searchSize, resultIdsList.size());
-    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    List<List<Float>> resultDistancesList = searchResult.getResultDistancesList();
     assertEquals(searchSize, resultDistancesList.size());
     assertEquals(searchSize, resultDistancesList.size());
-    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    List<List<SearchResult.QueryResult>> queryResultsList = searchResult.getQueryResultsList();
     assertEquals(searchSize, queryResultsList.size());
     assertEquals(searchSize, queryResultsList.size());
 
 
     for (int i = 0; i < searchSize; i++) {
     for (int i = 0; i < searchSize; i++) {
-      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      SearchResult.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
       assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
       assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
       assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
       assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
     }
     }