Browse Source

Add a dsl api for building search queries

jianghua 4 years ago
parent
commit
e871525645

+ 1 - 1
src/main/java/io/milvus/client/MilvusClient.java

@@ -31,7 +31,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
 
 /** The Milvus Client Interface */
-public interface MilvusClient {
+public interface MilvusClient extends AutoCloseable {
 
   String extraParamKey = "params";
 

+ 11 - 1
src/main/java/io/milvus/client/SearchParam.java

@@ -43,7 +43,7 @@ public class SearchParam {
   private static final String VECTOR_QUERY_KEY = "vector";
   private static final String VECTOR_QUERY_PLACEHOLDER = "placeholder";
 
-  private io.milvus.grpc.SearchParam.Builder builder;
+  private final io.milvus.grpc.SearchParam.Builder builder;
 
   public static SearchParam create(String collectionName) {
     return new SearchParam(collectionName);
@@ -54,6 +54,11 @@ public class SearchParam {
     builder.setCollectionName(collectionName);
   }
 
+  public SearchParam setDsl(JSONObject json) {
+    builder.setDsl(json.toString());
+    return this;
+  }
+
   public SearchParam setDsl(String dsl) {
     try {
       JSONObject dslJson = new JSONObject(dsl);
@@ -95,6 +100,11 @@ public class SearchParam {
     }
   }
 
+  public SearchParam addQueries(VectorParam vectorParam) {
+    builder.addVectorParam(vectorParam);
+    return this;
+  }
+
   public SearchParam setPartitionTags(List<String> partitionTags) {
     builder.addAllPartitionTagArray(partitionTags);
     return this;

+ 42 - 0
src/main/java/io/milvus/client/dsl/BoolQuery.java

@@ -0,0 +1,42 @@
+package io.milvus.client.dsl;
+
+import io.milvus.client.SearchParam;
+import org.json.JSONArray;
+import org.json.JSONObject;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class BoolQuery extends Query {
+  private final Type type;
+  private final List<Query> subqueries;
+
+  BoolQuery(Type type, List<Query> subqueries) {
+    this.type = type;
+    this.subqueries = subqueries;
+  }
+
+  enum Type {
+    MUST, MUST_NOT, SHOULD,
+
+    BOOL {
+      @Override
+      public Object buildSearchParam(SearchParam searchParam, List<Query> subqueries) {
+        JSONObject outer = new JSONObject();
+        subqueries.forEach(query -> query.buildSearchParam(searchParam, outer));
+        return outer;
+      }
+    };
+
+    public Object buildSearchParam(SearchParam searchParam, List<Query> subqueries) {
+      return new JSONArray(subqueries.stream()
+          .map(query -> query.buildSearchParam(searchParam, new JSONObject()))
+          .collect(Collectors.toList()));
+    }
+  }
+
+  @Override
+  protected JSONObject buildSearchParam(SearchParam searchParam, JSONObject outer) {
+    return outer.put(type.name().toLowerCase(), type.buildSearchParam(searchParam, subqueries));
+  }
+}

+ 30 - 0
src/main/java/io/milvus/client/dsl/InsertParam.java

@@ -0,0 +1,30 @@
+package io.milvus.client.dsl;
+
+import java.util.List;
+
+public class InsertParam {
+  private final io.milvus.client.InsertParam insertParam;
+
+  InsertParam(String collectionName) {
+    this.insertParam = io.milvus.client.InsertParam.create(collectionName);
+  }
+
+  public InsertParam withIds(List<Long> ids) {
+    insertParam.setEntityIds(ids);
+    return this;
+  }
+
+  public <T> InsertParam with(Schema.Field<T> field, List<T> data) {
+    insertParam.addField(field.name, field.dataType, data);
+    return this;
+  }
+
+  public <T> InsertParam with(Schema.VectorField<T> vectorField, List<T> data) {
+    insertParam.addVectorField(vectorField.name, vectorField.dataType, data);
+    return this;
+  }
+
+  io.milvus.client.InsertParam getInsertParam() {
+    return insertParam;
+  }
+}

+ 117 - 0
src/main/java/io/milvus/client/dsl/MilvusService.java

@@ -0,0 +1,117 @@
+package io.milvus.client.dsl;
+
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import io.milvus.client.Index;
+import io.milvus.client.IndexType;
+import io.milvus.client.MetricType;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.SearchParam;
+import io.milvus.client.SearchResult;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+public class MilvusService {
+  private final MilvusClient client;
+  private final String collectionName;
+  private final Schema schema;
+
+  public MilvusService(MilvusClient client, String  collectionName, Schema schema) {
+    this.client = client;
+    this.collectionName = collectionName;
+    this.schema = schema;
+  }
+
+  public MilvusService withTimeout(int timeout, TimeUnit unit) {
+    return new MilvusService(client.withTimeout(timeout, unit), collectionName, schema);
+  }
+
+  public void close() {
+    client.close();
+  }
+
+  public long countEntities() {
+    return client.countEntities(collectionName);
+  }
+
+  public void createCollection() {
+    createCollection("{}");
+  }
+
+  public void createCollection(String paramsInJson) {
+    client.createCollection(schema.mapToCollection(collectionName).setParamsInJson(paramsInJson));
+  }
+
+  public void createIndex(
+      Schema.VectorField vectorField, IndexType indexType, MetricType metricType, String paramsInJson) {
+    Futures.getUnchecked(createIndexAsync(vectorField, indexType, metricType, paramsInJson));
+  }
+
+  public ListenableFuture<Void> createIndexAsync(
+      Schema.VectorField vectorField, IndexType indexType, MetricType metricType, String paramsInJson) {
+    return client.createIndexAsync(
+        Index.create(collectionName, vectorField.name)
+            .setIndexType(indexType)
+            .setMetricType(metricType)
+            .setParamsInJson(paramsInJson));
+  }
+
+  public void deleteEntityByID(List<Long> ids) {
+    client.deleteEntityByID(collectionName, ids);
+  }
+
+  public void dropCollection() {
+    client.dropCollection(collectionName);
+  }
+
+  public void flush() {
+    client.flush(collectionName);
+  }
+
+  public ListenableFuture<Void> flushAsync() {
+    return client.flushAsync(collectionName);
+  }
+
+  public Map<Long, Schema.Entity> getEntityByID(List<Long> ids) {
+    return getEntityByID(ids, Collections.emptyList());
+  }
+
+  public Map<Long, Schema.Entity> getEntityByID(List<Long> ids, List<Schema.Field<?>> fields) {
+    List<String> fieldNames = fields.stream().map(f -> f.name).collect(Collectors.toList());
+    return client.getEntityByID(collectionName, ids, fieldNames)
+        .entrySet().stream().collect(Collectors.toMap(
+            e -> e.getKey(),
+            e -> schema.new Entity(e.getValue())));
+  }
+
+  public boolean hasCollection(String collectionName) {
+    return client.hasCollection(collectionName);
+  }
+
+  public List<Long> insert(Consumer<InsertParam> insertParamBuilder) {
+    return Futures.getUnchecked(insertAsync(insertParamBuilder));
+  }
+
+  public ListenableFuture<List<Long>> insertAsync(Consumer<InsertParam> insertParamBuilder) {
+    InsertParam insertParam = schema.insertInto(collectionName);
+    insertParamBuilder.accept(insertParam);
+    return client.insertAsync(insertParam.getInsertParam());
+  }
+
+  public SearchResult search(SearchParam searchParam) {
+    return client.search(searchParam);
+  }
+
+  public ListenableFuture<SearchResult> searchAsync(SearchParam searchParam) {
+    return client.searchAsync(searchParam);
+  }
+
+  public SearchParam buildSearchParam(Query query) {
+    return query.buildSearchParam(collectionName);
+  }
+}

+ 34 - 0
src/main/java/io/milvus/client/dsl/Query.java

@@ -0,0 +1,34 @@
+package io.milvus.client.dsl;
+
+import io.milvus.client.SearchParam;
+import org.json.JSONObject;
+
+import java.util.Arrays;
+
+public abstract class Query {
+
+  public static BoolQuery bool(Query... subqueries) {
+    return new BoolQuery(BoolQuery.Type.BOOL, Arrays.asList(subqueries));
+  }
+
+  public static BoolQuery must(Query... subqueries) {
+    return new BoolQuery(BoolQuery.Type.MUST, Arrays.asList(subqueries));
+  }
+
+  public static BoolQuery must_not(Query... subqueries) {
+    return new BoolQuery(BoolQuery.Type.MUST_NOT, Arrays.asList(subqueries));
+  }
+
+  public static BoolQuery should(Query... subqueries) {
+    return new BoolQuery(BoolQuery.Type.SHOULD, Arrays.asList(subqueries));
+  }
+
+  public SearchParam buildSearchParam(String collectionName) {
+    SearchParam searchParam = SearchParam.create(collectionName);
+    JSONObject json = buildSearchParam(searchParam, new JSONObject());
+    searchParam.setDsl(json);
+    return searchParam;
+  }
+
+  protected abstract JSONObject buildSearchParam(SearchParam searchParam, JSONObject json);
+}

+ 71 - 0
src/main/java/io/milvus/client/dsl/RangeQuery.java

@@ -0,0 +1,71 @@
+package io.milvus.client.dsl;
+
+import io.milvus.client.SearchParam;
+import org.json.JSONObject;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class RangeQuery<T> extends Query {
+  private Schema.Field<T> field;
+  private List<Expr> exprs = new ArrayList<>();
+
+  RangeQuery(Schema.Field field) {
+    this.field = field;
+  }
+
+  public RangeQuery<T> gt(T value) {
+    exprs.add(new Expr(Type.GT, value));
+    return this;
+  }
+
+  public RangeQuery<T> gte(T value) {
+    exprs.add(new Expr(Type.GTE, value));
+    return this;
+  }
+
+  public RangeQuery<T> lt(T value) {
+    exprs.add(new Expr(Type.LT, value));
+    return this;
+  }
+
+  public RangeQuery<T> lte(T value) {
+    exprs.add(new Expr(Type.LTE, value));
+    return this;
+  }
+
+  public RangeQuery<T> eq(T value) {
+    exprs.add(new Expr(Type.EQ, value));
+    return this;
+  }
+
+  public RangeQuery<T> ne(T value) {
+    exprs.add(new Expr(Type.NE, value));
+    return this;
+  }
+
+  @Override
+  protected JSONObject buildSearchParam(SearchParam searchParam, JSONObject outer) {
+    return outer.put("range", new JSONObject().put(field.name, buildSearchParam(exprs)));
+  }
+
+  private JSONObject buildSearchParam(List<Expr> exprs) {
+    JSONObject json = new JSONObject();
+    exprs.forEach(e -> json.put(e.type.name().toLowerCase(), e.value));
+    return json;
+  }
+
+  public enum Type {
+    GT, GTE, LT, LTE, EQ, NE;
+  }
+
+  private class Expr {
+    Type type;
+    T value;
+
+    Expr(Type type, T value) {
+      this.type = type;
+      this.value = value;
+    }
+  }
+}

+ 144 - 0
src/main/java/io/milvus/client/dsl/Schema.java

@@ -0,0 +1,144 @@
+package io.milvus.client.dsl;
+
+import io.milvus.client.CollectionMapping;
+import io.milvus.client.DataType;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+public abstract class Schema {
+  private final Map<String, Field> fields = new LinkedHashMap<>();
+
+  Field<?> getField(String name) {
+    return fields.get(name);
+  }
+
+  CollectionMapping mapToCollection(String collectionName) {
+    CollectionMapping mapping = CollectionMapping.create(collectionName);
+    fields.values().forEach(f -> {
+      if (f instanceof ScalarField) {
+        mapping.addField(f.name, f.dataType);
+      } else if (f instanceof VectorField) {
+        mapping.addVectorField(f.name, f.dataType, ((VectorField<?>) f).dimension);
+      }
+    });
+    return mapping;
+  }
+
+  InsertParam insertInto(String collectionName) {
+    return new InsertParam(collectionName);
+  }
+
+  public class Field<T> {
+    public final String name;
+    public final DataType dataType;
+
+    private Field(String name, DataType dataType) {
+      this.name = name;
+      this.dataType = dataType;
+      if (fields.putIfAbsent(name, this) != null) {
+        throw new IllegalArgumentException("Field name conflict: " + name);
+      }
+    }
+  }
+
+  public class ScalarField<T> extends Field<T> {
+    private ScalarField(String name, DataType dataType) {
+      super(name, dataType);
+    }
+
+    public RangeQuery<T> gt(T value) {
+      return new RangeQuery<T>(this).gt(value);
+    }
+
+    public RangeQuery<T> gte(T value) {
+      return new RangeQuery<T>(this).gte(value);
+    }
+
+    public RangeQuery<T> lt(T value) {
+      return new RangeQuery<T>(this).lt(value);
+    }
+
+    public RangeQuery<T> lte(T value) {
+      return new RangeQuery<T>(this).lte(value);
+    }
+
+    public RangeQuery<T> eq(T value) {
+      return new RangeQuery<T>(this).eq(value);
+    }
+
+    public RangeQuery<T> ne(T value) {
+      return new RangeQuery<T>(this).ne(value);
+    }
+
+    @SuppressWarnings("unchecked")
+    public TermQuery<T> in(T... values) {
+      return new TermQuery<>(this, TermQuery.Type.IN, Arrays.asList(values));
+    }
+  }
+
+  public class Int32Field extends ScalarField<Integer> {
+    public Int32Field(String name) {
+      super(name, DataType.INT32);
+    }
+  }
+
+  public class Int64Field extends ScalarField<Long> {
+    public Int64Field(String name) {
+      super(name, DataType.INT64);
+    }
+  }
+
+  public class FloatField extends ScalarField<Float> {
+    public FloatField(String name) {
+      super(name, DataType.FLOAT);
+    }
+  }
+
+  public class DoubleField extends ScalarField<Double> {
+    public DoubleField(String name) {
+      super(name, DataType.DOUBLE);
+    }
+  }
+
+  public class VectorField<T> extends Field<T> {
+    public final int dimension;
+
+    private VectorField(String name, DataType dataType, int dimension) {
+      super(name, dataType);
+      this.dimension = dimension;
+    }
+
+    public VectorQuery<T> query(List<T> queries) {
+      return new VectorQuery<>(this, queries);
+    }
+  }
+
+  public class FloatVectorField extends VectorField<List<Float>> {
+    public FloatVectorField(String name, int dimension) {
+      super(name, DataType.VECTOR_FLOAT, dimension);
+    }
+  }
+
+  public class BinaryVectorField extends VectorField<ByteBuffer> {
+    public BinaryVectorField(String name, int dimension) {
+      super(name, DataType.VECTOR_BINARY, dimension);
+    }
+  }
+
+  public class Entity {
+    private final Map<String, Object> properties;
+
+    Entity(Map<String, Object> properties) {
+      this.properties = properties;
+    }
+
+    @SuppressWarnings("unchecked")
+    public <T> T get(Field<T> field) {
+      return (T) properties.get(field.name);
+    }
+  }
+}

+ 35 - 0
src/main/java/io/milvus/client/dsl/TermQuery.java

@@ -0,0 +1,35 @@
+package io.milvus.client.dsl;
+
+import io.milvus.client.SearchParam;
+import org.json.JSONArray;
+import org.json.JSONObject;
+
+import java.util.Collection;
+
+public class TermQuery<T> extends Query {
+  private final Schema.Field<T> field;
+  private final Type type;
+  private final Object param;
+
+  public TermQuery(Schema.Field<T> field, Type type, Object param) {
+    this.field = field;
+    this.type = type;
+    this.param = param;
+  }
+
+  @Override
+  protected JSONObject buildSearchParam(SearchParam searchParam, JSONObject outer) {
+    return outer.put("term", new JSONObject().put(field.name, type.toJson(param)));
+  }
+
+  enum Type {
+    IN {
+      @Override
+      Object toJson(Object param) {
+        return new JSONArray((Collection<?>) param);
+      }
+    };
+
+    abstract Object toJson(Object param);
+  }
+}

+ 98 - 0
src/main/java/io/milvus/client/dsl/VectorQuery.java

@@ -0,0 +1,98 @@
+package io.milvus.client.dsl;
+
+import com.google.protobuf.UnsafeByteOperations;
+import io.milvus.client.MetricType;
+import io.milvus.client.SearchParam;
+import io.milvus.grpc.VectorParam;
+import io.milvus.grpc.VectorRecord;
+import io.milvus.grpc.VectorRowRecord;
+import org.json.JSONObject;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class VectorQuery<T> extends Query {
+  private final Schema.VectorField<T> field;
+  private final List<T> queries;
+  private String placeholder;
+  private int topK = 10;
+  private float boost = 1.0f;
+  private MetricType metricType;
+  private JSONObject params = new JSONObject();
+
+  VectorQuery(Schema.VectorField<T> field, List<T> queries) {
+    this.field = field;
+    this.queries = queries;
+    this.placeholder = field.name;
+    this.metricType = field instanceof Schema.FloatVectorField ? MetricType.L2 : MetricType.JACCARD;
+  }
+
+  public VectorQuery<T> placeholder(String placeholder) {
+    this.placeholder = placeholder;
+    return this;
+  }
+
+  public VectorQuery<T> top(int topK) {
+    this.topK = topK;
+    return this;
+  }
+
+  public VectorQuery<T> boost(float value) {
+    this.boost = value;
+    return this;
+  }
+
+  public VectorQuery<T> metricType(MetricType metricType) {
+    this.metricType = metricType;
+    return this;
+  }
+
+  public VectorQuery<T> param(String key, Object value) {
+    params.put(key, value);
+    return this;
+  }
+
+  public VectorQuery<T> paramsInJson(String paramsInJson) {
+    params = new JSONObject(paramsInJson);
+    return this;
+  }
+
+  @SuppressWarnings("unchecked")
+  void buildSearchParam(SearchParam searchParam) {
+    VectorRecord vectorRecord = null;
+    if (field instanceof Schema.FloatVectorField) {
+      vectorRecord = VectorRecord.newBuilder().addAllRecords(
+          ((List<List<Float>>) this.queries).stream().map(vector ->
+              VectorRowRecord.newBuilder().addAllFloatData(vector).build())
+              .collect(Collectors.toList()))
+          .build();
+    } else if (field instanceof Schema.BinaryVectorField) {
+      vectorRecord = VectorRecord.newBuilder().addAllRecords(
+          ((List<ByteBuffer>) this.queries).stream().map(vector ->
+            VectorRowRecord.newBuilder().setBinaryData(UnsafeByteOperations.unsafeWrap(vector)).build())
+              .collect(Collectors.toList()))
+          .build();
+    }
+
+    VectorParam vectorParam = VectorParam.newBuilder()
+        .setJson(new JSONObject()
+            .put(placeholder, new JSONObject()
+                .put(field.name, new JSONObject()
+                    .put("topk", topK)
+                    .put("metric_type", metricType.name())
+                    .put("boost", boost)
+                    .put("params", params)))
+                .toString())
+        .setRowRecord(vectorRecord)
+        .build();
+
+    searchParam.addQueries(vectorParam);
+  }
+
+  @Override
+  protected JSONObject buildSearchParam(SearchParam searchParam, JSONObject outer) {
+    buildSearchParam(searchParam);
+    return outer.put("vector", placeholder);
+  }
+}

+ 179 - 0
src/test/java/io/milvus/client/dsl/SearchDslTest.java

@@ -0,0 +1,179 @@
+package io.milvus.client.dsl;
+
+import io.milvus.client.ConnectParam;
+import io.milvus.client.IndexType;
+import io.milvus.client.JsonBuilder;
+import io.milvus.client.MetricType;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusGrpcClient;
+import io.milvus.client.SearchParam;
+import io.milvus.client.SearchResult;
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.jupiter.api.Test;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@Testcontainers
+public class SearchDslTest {
+
+  @Container
+  private GenericContainer milvusContainer =
+      new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
+          .withExposedPorts(19530);
+
+  private TestSchema schema = new TestSchema();
+  private String collectionName = "test_collection";
+  private int size = 1000;
+
+  private ConnectParam connectParam(GenericContainer milvusContainer) {
+    return new ConnectParam.Builder()
+        .withHost(milvusContainer.getHost())
+        .withPort(milvusContainer.getFirstMappedPort())
+        .build();
+  }
+
+  private void withMilvusService(Consumer<MilvusService> test) {
+    try (MilvusClient client = new MilvusGrpcClient(connectParam(milvusContainer))) {
+      test.accept(new MilvusService(client, collectionName, schema));
+    }
+  }
+
+  private List<Float> randomFloatVector(int dimension) {
+    return Stream.generate(RandomUtils::nextFloat).limit(dimension).collect(Collectors.toList());
+  }
+
+  private List<List<Float>> randomFloatVectors(int size, int dimension) {
+    return Stream.generate(() -> randomFloatVector(dimension)).limit(size).collect(Collectors.toList());
+  }
+
+  private ByteBuffer randomBinaryVector(int dimension) {
+    return ByteBuffer.wrap(RandomUtils.nextBytes(dimension / 8));
+  }
+
+  private List<ByteBuffer> randomBinaryVectors(int size, int dimension) {
+    return Stream.generate(() -> randomBinaryVector(dimension)).limit(size).collect(Collectors.toList());
+  }
+
+  @Test
+  public void testCreateCollection() {
+    withMilvusService(service -> {
+      service.createCollection(new JsonBuilder().param("auto_id", false).build());
+      assertTrue(service.hasCollection(collectionName));
+    });
+  }
+
+  @Test
+  public void testInsert() {
+    testCreateCollection();
+
+    withMilvusService(service -> {
+      service.insert(insertParam -> insertParam
+          .withIds(LongStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(schema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(schema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(schema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
+          .with(schema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
+          .with(schema.floatVectorField, randomFloatVectors(size, schema.floatVectorField.dimension))
+          .with(schema.binaryVectorField, randomBinaryVectors(size, schema.binaryVectorField.dimension)));
+
+      service.flush();
+
+      assertEquals(size, service.countEntities());
+    });
+  }
+
+  @Test
+  public void testCreateIndex() {
+    testInsert();
+
+    withMilvusService(service -> {
+      service.createIndex(schema.floatVectorField, IndexType.IVF_SQ8, MetricType.L2, "{\"nlist\": 256}");
+      service.createIndex(schema.binaryVectorField, IndexType.BIN_FLAT, MetricType.JACCARD, "{}");
+    });
+  }
+
+  @Test
+  public void testGetEntityById() {
+    withMilvusService(service -> {
+      testInsert();
+
+      Map<Long, Schema.Entity> entities = service.getEntityByID(
+          LongStream.range(0, 10).boxed().collect(Collectors.toList()),
+          Arrays.asList(schema.intField, schema.longField));
+
+      LongStream.range(0, 10).forEach(i -> {
+        assertEquals((int) i, entities.get(i).get(schema.intField));
+        assertEquals(i, entities.get(i).get(schema.longField));
+      });
+    });
+  }
+
+  @Test
+  public void testFloadVectorQuery() {
+    withMilvusService(service -> {
+      testCreateIndex();
+
+      List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
+
+      Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
+
+      List<List<Float>> vectors = entities.values().stream().map(e -> e.get(schema.floatVectorField)).collect(Collectors.toList());
+
+      Query query = Query.bool(
+          Query.must(
+              schema.floatVectorField.query(vectors).param("nprobe", 16).top(1)
+          )
+      );
+
+      SearchParam searchParam = service.buildSearchParam(query)
+          .setParamsInJson(new JsonBuilder().param("fields", Arrays.asList("int64", "float_vec")).build());
+
+      SearchResult searchResult = service.search(searchParam);
+      assertEquals(entityIds,
+          searchResult.getResultIdsList().stream()
+              .map(ids -> ids.get(0))
+              .collect(Collectors.toList()));
+    });
+  }
+
+  @Test
+  public void testBinaryVectorQuery() {
+    withMilvusService(service -> {
+      testCreateIndex();
+
+      List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
+
+      Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
+
+      List<ByteBuffer> vectors = entities.values().stream().map(e -> e.get(schema.binaryVectorField)).collect(Collectors.toList());
+
+      Query query = Query.bool(
+          Query.must(
+              schema.binaryVectorField.query(vectors).top(1)
+          )
+      );
+
+      SearchParam searchParam = service.buildSearchParam(query);
+
+      SearchResult searchResult = service.search(searchParam);
+      assertEquals(entityIds,
+          searchResult.getResultIdsList().stream()
+              .map(ids -> ids.get(0))
+              .collect(Collectors.toList()));
+    });
+  }
+}

+ 10 - 0
src/test/java/io/milvus/client/dsl/TestSchema.java

@@ -0,0 +1,10 @@
+package io.milvus.client.dsl;
+
+public class TestSchema extends Schema {
+  public final Int32Field intField = new Int32Field("int32");
+  public final Int64Field longField = new Int64Field("int64");
+  public final FloatField floatField = new FloatField("float");
+  public final DoubleField doubleField = new DoubleField("double");
+  public final FloatVectorField floatVectorField = new FloatVectorField("float_vec", 64);
+  public final BinaryVectorField binaryVectorField = new BinaryVectorField("binary_vec", 64);
+}