Browse Source

fix schema

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 4 years ago
parent
commit
4c8bd2afa8

+ 1 - 1
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.9.1</version>
+    <version>0.9.2-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>

+ 100 - 44
src/test/java/io/milvus/client/dsl/SearchDslTest.java

@@ -1,5 +1,8 @@
 package io.milvus.client.dsl;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
 import io.milvus.client.ConnectParam;
 import io.milvus.client.IndexType;
 import io.milvus.client.JsonBuilder;
@@ -8,12 +11,6 @@ import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusGrpcClient;
 import io.milvus.client.SearchParam;
 import io.milvus.client.SearchResult;
-import org.apache.commons.lang3.RandomUtils;
-import org.junit.jupiter.api.Test;
-import org.testcontainers.containers.GenericContainer;
-import org.testcontainers.junit.jupiter.Container;
-import org.testcontainers.junit.jupiter.Testcontainers;
-
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
@@ -23,19 +20,22 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 import java.util.stream.Stream;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.jupiter.api.Test;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
 
 @Testcontainers
 public class SearchDslTest {
 
   @Container
   private GenericContainer milvusContainer =
-      new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
+      new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu-d101620-4c44c0"))
           .withExposedPorts(19530);
 
-  private TestSchema schema = new TestSchema();
+  private TestFloatSchema floatSchema = new TestFloatSchema();
+  private TestBinarySchema binarySchema = new TestBinarySchema();
   private String collectionName = "test_collection";
   private int size = 1000;
 
@@ -46,9 +46,15 @@ public class SearchDslTest {
         .build();
   }
 
-  private void withMilvusService(Consumer<MilvusService> test) {
+  private void withMilvusServiceFloat(Consumer<MilvusService> test) {
+    try (MilvusClient client = new MilvusGrpcClient(connectParam(milvusContainer))) {
+      test.accept(new MilvusService(client, collectionName, floatSchema));
+    }
+  }
+
+  private void withMilvusServiceBinary(Consumer<MilvusService> test) {
     try (MilvusClient client = new MilvusGrpcClient(connectParam(milvusContainer))) {
-      test.accept(new MilvusService(client, collectionName, schema));
+      test.accept(new MilvusService(client, collectionName, binarySchema));
     }
   }
 
@@ -69,26 +75,33 @@ public class SearchDslTest {
   }
 
   @Test
-  public void testCreateCollection() {
-    withMilvusService(service -> {
+  public void testCreateCollectionFloat() {
+    withMilvusServiceFloat(service -> {
+      service.createCollection(new JsonBuilder().param("auto_id", false).build());
+      assertTrue(service.hasCollection(collectionName));
+    });
+  }
+
+  @Test
+  public void testCreateCollectionBinary() {
+    withMilvusServiceBinary(service -> {
       service.createCollection(new JsonBuilder().param("auto_id", false).build());
       assertTrue(service.hasCollection(collectionName));
     });
   }
 
   @Test
-  public void testInsert() {
-    testCreateCollection();
+  public void testInsertFloat() {
+    testCreateCollectionFloat();
 
-    withMilvusService(service -> {
+    withMilvusServiceFloat(service -> {
       service.insert(insertParam -> insertParam
           .withIds(LongStream.range(0, size).boxed().collect(Collectors.toList()))
-          .with(schema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
-          .with(schema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
-          .with(schema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
-          .with(schema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
-          .with(schema.floatVectorField, randomFloatVectors(size, schema.floatVectorField.dimension))
-          .with(schema.binaryVectorField, randomBinaryVectors(size, schema.binaryVectorField.dimension)));
+          .with(floatSchema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(floatSchema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(floatSchema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
+          .with(floatSchema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
+          .with(floatSchema.floatVectorField, randomFloatVectors(size, floatSchema.floatVectorField.dimension)));
 
       service.flush();
 
@@ -97,45 +110,88 @@ public class SearchDslTest {
   }
 
   @Test
-  public void testCreateIndex() {
-    testInsert();
+  public void testInsertBinary() {
+    testCreateCollectionBinary();
 
-    withMilvusService(service -> {
-      service.createIndex(schema.floatVectorField, IndexType.IVF_SQ8, MetricType.L2, "{\"nlist\": 256}");
-      service.createIndex(schema.binaryVectorField, IndexType.BIN_FLAT, MetricType.JACCARD, "{}");
+    withMilvusServiceBinary(service -> {
+      service.insert(insertParam -> insertParam
+          .withIds(LongStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(binarySchema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(binarySchema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
+          .with(binarySchema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
+          .with(binarySchema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
+          .with(binarySchema.binaryVectorField, randomBinaryVectors(size, binarySchema.binaryVectorField.dimension)));
+
+      service.flush();
+
+      assertEquals(size, service.countEntities());
+    });
+  }
+
+  @Test
+  public void testCreateIndexFloat() {
+    testInsertFloat();
+
+    withMilvusServiceFloat(service -> {
+      service.createIndex(floatSchema.floatVectorField, IndexType.IVF_SQ8, MetricType.L2, "{\"nlist\": 256}");
+    });
+  }
+
+  @Test
+  public void testCreateIndexBinary() {
+    testInsertBinary();
+
+    withMilvusServiceBinary(service -> {
+      service.createIndex(binarySchema.binaryVectorField, IndexType.BIN_FLAT, MetricType.JACCARD, "{}");
+    });
+  }
+
+  @Test
+  public void testGetEntityByIdFloat() {
+    withMilvusServiceFloat(service -> {
+      testInsertFloat();
+
+      Map<Long, Schema.Entity> entities = service.getEntityByID(
+          LongStream.range(0, 10).boxed().collect(Collectors.toList()),
+          Arrays.asList(floatSchema.intField, floatSchema.longField));
+
+      LongStream.range(0, 10).forEach(i -> {
+        assertEquals((int) i, entities.get(i).get(floatSchema.intField));
+        assertEquals(i, entities.get(i).get(floatSchema.longField));
+      });
     });
   }
 
   @Test
-  public void testGetEntityById() {
-    withMilvusService(service -> {
-      testInsert();
+  public void testGetEntityByIdBinary() {
+    withMilvusServiceBinary(service -> {
+      testInsertBinary();
 
       Map<Long, Schema.Entity> entities = service.getEntityByID(
           LongStream.range(0, 10).boxed().collect(Collectors.toList()),
-          Arrays.asList(schema.intField, schema.longField));
+          Arrays.asList(binarySchema.intField, binarySchema.longField));
 
       LongStream.range(0, 10).forEach(i -> {
-        assertEquals((int) i, entities.get(i).get(schema.intField));
-        assertEquals(i, entities.get(i).get(schema.longField));
+        assertEquals((int) i, entities.get(i).get(binarySchema.intField));
+        assertEquals(i, entities.get(i).get(binarySchema.longField));
       });
     });
   }
 
   @Test
-  public void testFloadVectorQuery() {
-    withMilvusService(service -> {
-      testCreateIndex();
+  public void testFloatVectorQuery() {
+    withMilvusServiceFloat(service -> {
+      testCreateIndexFloat();
 
       List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
 
       Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
 
-      List<List<Float>> vectors = entities.values().stream().map(e -> e.get(schema.floatVectorField)).collect(Collectors.toList());
+      List<List<Float>> vectors = entities.values().stream().map(e -> e.get(floatSchema.floatVectorField)).collect(Collectors.toList());
 
       Query query = Query.bool(
           Query.must(
-              schema.floatVectorField.query(vectors).param("nprobe", 16).top(1)
+              floatSchema.floatVectorField.query(vectors).param("nprobe", 16).top(1)
           )
       );
 
@@ -152,18 +208,18 @@ public class SearchDslTest {
 
   @Test
   public void testBinaryVectorQuery() {
-    withMilvusService(service -> {
-      testCreateIndex();
+    withMilvusServiceBinary(service -> {
+      testCreateIndexBinary();
 
       List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
 
       Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
 
-      List<ByteBuffer> vectors = entities.values().stream().map(e -> e.get(schema.binaryVectorField)).collect(Collectors.toList());
+      List<ByteBuffer> vectors = entities.values().stream().map(e -> e.get(binarySchema.binaryVectorField)).collect(Collectors.toList());
 
       Query query = Query.bool(
           Query.must(
-              schema.binaryVectorField.query(vectors).top(1)
+              binarySchema.binaryVectorField.query(vectors).top(1)
           )
       );
 

+ 1 - 2
src/test/java/io/milvus/client/dsl/TestSchema.java → src/test/java/io/milvus/client/dsl/TestBinarySchema.java

@@ -1,10 +1,9 @@
 package io.milvus.client.dsl;
 
-public class TestSchema extends Schema {
+public class TestBinarySchema extends Schema {
   public final Int32Field intField = new Int32Field("int32");
   public final Int64Field longField = new Int64Field("int64");
   public final FloatField floatField = new FloatField("float");
   public final DoubleField doubleField = new DoubleField("double");
-  public final FloatVectorField floatVectorField = new FloatVectorField("float_vec", 64);
   public final BinaryVectorField binaryVectorField = new BinaryVectorField("binary_vec", 64);
 }

+ 9 - 0
src/test/java/io/milvus/client/dsl/TestFloatSchema.java

@@ -0,0 +1,9 @@
+package io.milvus.client.dsl;
+
+public class TestFloatSchema extends Schema {
+  public final Int32Field intField = new Int32Field("int32");
+  public final Int64Field longField = new Int64Field("int64");
+  public final FloatField floatField = new FloatField("float");
+  public final DoubleField doubleField = new DoubleField("double");
+  public final FloatVectorField floatVectorField = new FloatVectorField("float_vec", 64);
+}