|
@@ -1,5 +1,8 @@
|
|
|
package io.milvus.client.dsl;
|
|
|
|
|
|
+import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
|
+import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
|
+
|
|
|
import io.milvus.client.ConnectParam;
|
|
|
import io.milvus.client.IndexType;
|
|
|
import io.milvus.client.JsonBuilder;
|
|
@@ -8,12 +11,6 @@ import io.milvus.client.MilvusClient;
|
|
|
import io.milvus.client.MilvusGrpcClient;
|
|
|
import io.milvus.client.SearchParam;
|
|
|
import io.milvus.client.SearchResult;
|
|
|
-import org.apache.commons.lang3.RandomUtils;
|
|
|
-import org.junit.jupiter.api.Test;
|
|
|
-import org.testcontainers.containers.GenericContainer;
|
|
|
-import org.testcontainers.junit.jupiter.Container;
|
|
|
-import org.testcontainers.junit.jupiter.Testcontainers;
|
|
|
-
|
|
|
import java.nio.ByteBuffer;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.List;
|
|
@@ -23,19 +20,22 @@ import java.util.stream.Collectors;
|
|
|
import java.util.stream.IntStream;
|
|
|
import java.util.stream.LongStream;
|
|
|
import java.util.stream.Stream;
|
|
|
-
|
|
|
-import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
|
-import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
|
+import org.apache.commons.lang3.RandomUtils;
|
|
|
+import org.junit.jupiter.api.Test;
|
|
|
+import org.testcontainers.containers.GenericContainer;
|
|
|
+import org.testcontainers.junit.jupiter.Container;
|
|
|
+import org.testcontainers.junit.jupiter.Testcontainers;
|
|
|
|
|
|
@Testcontainers
|
|
|
public class SearchDslTest {
|
|
|
|
|
|
@Container
|
|
|
private GenericContainer milvusContainer =
|
|
|
- new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
|
|
|
+ new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu-d101620-4c44c0"))
|
|
|
.withExposedPorts(19530);
|
|
|
|
|
|
- private TestSchema schema = new TestSchema();
|
|
|
+ private TestFloatSchema floatSchema = new TestFloatSchema();
|
|
|
+ private TestBinarySchema binarySchema = new TestBinarySchema();
|
|
|
private String collectionName = "test_collection";
|
|
|
private int size = 1000;
|
|
|
|
|
@@ -46,9 +46,15 @@ public class SearchDslTest {
|
|
|
.build();
|
|
|
}
|
|
|
|
|
|
- private void withMilvusService(Consumer<MilvusService> test) {
|
|
|
+ private void withMilvusServiceFloat(Consumer<MilvusService> test) {
|
|
|
+ try (MilvusClient client = new MilvusGrpcClient(connectParam(milvusContainer))) {
|
|
|
+ test.accept(new MilvusService(client, collectionName, floatSchema));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void withMilvusServiceBinary(Consumer<MilvusService> test) {
|
|
|
try (MilvusClient client = new MilvusGrpcClient(connectParam(milvusContainer))) {
|
|
|
- test.accept(new MilvusService(client, collectionName, schema));
|
|
|
+ test.accept(new MilvusService(client, collectionName, binarySchema));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -69,26 +75,33 @@ public class SearchDslTest {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void testCreateCollection() {
|
|
|
- withMilvusService(service -> {
|
|
|
+ public void testCreateCollectionFloat() {
|
|
|
+ withMilvusServiceFloat(service -> {
|
|
|
+ service.createCollection(new JsonBuilder().param("auto_id", false).build());
|
|
|
+ assertTrue(service.hasCollection(collectionName));
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testCreateCollectionBinary() {
|
|
|
+ withMilvusServiceBinary(service -> {
|
|
|
service.createCollection(new JsonBuilder().param("auto_id", false).build());
|
|
|
assertTrue(service.hasCollection(collectionName));
|
|
|
});
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void testInsert() {
|
|
|
- testCreateCollection();
|
|
|
+ public void testInsertFloat() {
|
|
|
+ testCreateCollectionFloat();
|
|
|
|
|
|
- withMilvusService(service -> {
|
|
|
+ withMilvusServiceFloat(service -> {
|
|
|
service.insert(insertParam -> insertParam
|
|
|
.withIds(LongStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
- .with(schema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
- .with(schema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
- .with(schema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
|
|
|
- .with(schema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
|
|
|
- .with(schema.floatVectorField, randomFloatVectors(size, schema.floatVectorField.dimension))
|
|
|
- .with(schema.binaryVectorField, randomBinaryVectors(size, schema.binaryVectorField.dimension)));
|
|
|
+ .with(floatSchema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
+ .with(floatSchema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
+ .with(floatSchema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
|
|
|
+ .with(floatSchema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
|
|
|
+ .with(floatSchema.floatVectorField, randomFloatVectors(size, floatSchema.floatVectorField.dimension)));
|
|
|
|
|
|
service.flush();
|
|
|
|
|
@@ -97,45 +110,88 @@ public class SearchDslTest {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void testCreateIndex() {
|
|
|
- testInsert();
|
|
|
+ public void testInsertBinary() {
|
|
|
+ testCreateCollectionBinary();
|
|
|
|
|
|
- withMilvusService(service -> {
|
|
|
- service.createIndex(schema.floatVectorField, IndexType.IVF_SQ8, MetricType.L2, "{\"nlist\": 256}");
|
|
|
- service.createIndex(schema.binaryVectorField, IndexType.BIN_FLAT, MetricType.JACCARD, "{}");
|
|
|
+ withMilvusServiceBinary(service -> {
|
|
|
+ service.insert(insertParam -> insertParam
|
|
|
+ .withIds(LongStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
+ .with(binarySchema.intField, IntStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
+ .with(binarySchema.longField, LongStream.range(0, size).boxed().collect(Collectors.toList()))
|
|
|
+ .with(binarySchema.floatField, IntStream.range(0, size).boxed().map(Number::floatValue).collect(Collectors.toList()))
|
|
|
+ .with(binarySchema.doubleField, IntStream.range(0, size).boxed().map(Number::doubleValue).collect(Collectors.toList()))
|
|
|
+ .with(binarySchema.binaryVectorField, randomBinaryVectors(size, binarySchema.binaryVectorField.dimension)));
|
|
|
+
|
|
|
+ service.flush();
|
|
|
+
|
|
|
+ assertEquals(size, service.countEntities());
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testCreateIndexFloat() {
|
|
|
+ testInsertFloat();
|
|
|
+
|
|
|
+ withMilvusServiceFloat(service -> {
|
|
|
+ service.createIndex(floatSchema.floatVectorField, IndexType.IVF_SQ8, MetricType.L2, "{\"nlist\": 256}");
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testCreateIndexBinary() {
|
|
|
+ testInsertBinary();
|
|
|
+
|
|
|
+ withMilvusServiceBinary(service -> {
|
|
|
+ service.createIndex(binarySchema.binaryVectorField, IndexType.BIN_FLAT, MetricType.JACCARD, "{}");
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void testGetEntityByIdFloat() {
|
|
|
+ withMilvusServiceFloat(service -> {
|
|
|
+ testInsertFloat();
|
|
|
+
|
|
|
+ Map<Long, Schema.Entity> entities = service.getEntityByID(
|
|
|
+ LongStream.range(0, 10).boxed().collect(Collectors.toList()),
|
|
|
+ Arrays.asList(floatSchema.intField, floatSchema.longField));
|
|
|
+
|
|
|
+ LongStream.range(0, 10).forEach(i -> {
|
|
|
+ assertEquals((int) i, entities.get(i).get(floatSchema.intField));
|
|
|
+ assertEquals(i, entities.get(i).get(floatSchema.longField));
|
|
|
+ });
|
|
|
});
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void testGetEntityById() {
|
|
|
- withMilvusService(service -> {
|
|
|
- testInsert();
|
|
|
+ public void testGetEntityByIdBinary() {
|
|
|
+ withMilvusServiceBinary(service -> {
|
|
|
+ testInsertBinary();
|
|
|
|
|
|
Map<Long, Schema.Entity> entities = service.getEntityByID(
|
|
|
LongStream.range(0, 10).boxed().collect(Collectors.toList()),
|
|
|
- Arrays.asList(schema.intField, schema.longField));
|
|
|
+ Arrays.asList(binarySchema.intField, binarySchema.longField));
|
|
|
|
|
|
LongStream.range(0, 10).forEach(i -> {
|
|
|
- assertEquals((int) i, entities.get(i).get(schema.intField));
|
|
|
- assertEquals(i, entities.get(i).get(schema.longField));
|
|
|
+ assertEquals((int) i, entities.get(i).get(binarySchema.intField));
|
|
|
+ assertEquals(i, entities.get(i).get(binarySchema.longField));
|
|
|
});
|
|
|
});
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void testFloadVectorQuery() {
|
|
|
- withMilvusService(service -> {
|
|
|
- testCreateIndex();
|
|
|
+ public void testFloatVectorQuery() {
|
|
|
+ withMilvusServiceFloat(service -> {
|
|
|
+ testCreateIndexFloat();
|
|
|
|
|
|
List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
|
|
|
|
|
|
Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
|
|
|
|
|
|
- List<List<Float>> vectors = entities.values().stream().map(e -> e.get(schema.floatVectorField)).collect(Collectors.toList());
|
|
|
+ List<List<Float>> vectors = entities.values().stream().map(e -> e.get(floatSchema.floatVectorField)).collect(Collectors.toList());
|
|
|
|
|
|
Query query = Query.bool(
|
|
|
Query.must(
|
|
|
- schema.floatVectorField.query(vectors).param("nprobe", 16).top(1)
|
|
|
+ floatSchema.floatVectorField.query(vectors).param("nprobe", 16).top(1)
|
|
|
)
|
|
|
);
|
|
|
|
|
@@ -152,18 +208,18 @@ public class SearchDslTest {
|
|
|
|
|
|
@Test
|
|
|
public void testBinaryVectorQuery() {
|
|
|
- withMilvusService(service -> {
|
|
|
- testCreateIndex();
|
|
|
+ withMilvusServiceBinary(service -> {
|
|
|
+ testCreateIndexBinary();
|
|
|
|
|
|
List<Long> entityIds = LongStream.range(0, 10).boxed().collect(Collectors.toList());
|
|
|
|
|
|
Map<Long, Schema.Entity> entities = service.getEntityByID(entityIds);
|
|
|
|
|
|
- List<ByteBuffer> vectors = entities.values().stream().map(e -> e.get(schema.binaryVectorField)).collect(Collectors.toList());
|
|
|
+ List<ByteBuffer> vectors = entities.values().stream().map(e -> e.get(binarySchema.binaryVectorField)).collect(Collectors.toList());
|
|
|
|
|
|
Query query = Query.bool(
|
|
|
Query.must(
|
|
|
- schema.binaryVectorField.query(vectors).top(1)
|
|
|
+ binarySchema.binaryVectorField.query(vectors).top(1)
|
|
|
)
|
|
|
);
|
|
|
|