|
@@ -0,0 +1,762 @@
|
|
|
|
+/*
|
|
|
|
+ * Licensed to the Apache Software Foundation (ASF) under one
|
|
|
|
+ * or more contributor license agreements. See the NOTICE file
|
|
|
|
+ * distributed with this work for additional information
|
|
|
|
+ * regarding copyright ownership. The ASF licenses this file
|
|
|
|
+ * to you under the Apache License, Version 2.0 (the
|
|
|
|
+ * "License"); you may not use this file except in compliance
|
|
|
|
+ * with the License. You may obtain a copy of the License at
|
|
|
|
+ *
|
|
|
|
+ * http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
+ *
|
|
|
|
+ * Unless required by applicable law or agreed to in writing,
|
|
|
|
+ * software distributed under the License is distributed on an
|
|
|
|
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
|
|
+ * KIND, either express or implied. See the License for the
|
|
|
|
+ * specific language governing permissions and limitations
|
|
|
|
+ * under the License.
|
|
|
|
+ */
|
|
|
|
+
|
|
|
|
+package io.milvus.v2.client;
|
|
|
|
+
|
|
|
|
+import com.google.gson.*;
|
|
|
|
+
|
|
|
|
+import com.google.gson.reflect.TypeToken;
|
|
|
|
+import io.milvus.v2.common.ConsistencyLevel;
|
|
|
|
+import io.milvus.v2.common.DataType;
|
|
|
|
+import io.milvus.v2.common.IndexParam;
|
|
|
|
+import io.milvus.v2.service.collection.request.AddFieldReq;
|
|
|
|
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
|
|
|
+import io.milvus.v2.service.collection.request.DropCollectionReq;
|
|
|
|
+import io.milvus.v2.service.partition.request.CreatePartitionReq;
|
|
|
|
+import io.milvus.v2.service.vector.request.*;
|
|
|
|
+import io.milvus.v2.service.vector.request.data.*;
|
|
|
|
+import io.milvus.v2.service.vector.request.ranker.*;
|
|
|
|
+import io.milvus.v2.service.vector.response.*;
|
|
|
|
+import org.apache.commons.text.RandomStringGenerator;
|
|
|
|
+
|
|
|
|
+import org.junit.jupiter.api.Assertions;
|
|
|
|
+import org.junit.jupiter.api.BeforeAll;
|
|
|
|
+import org.junit.jupiter.api.AfterAll;
|
|
|
|
+import org.junit.jupiter.api.Test;
|
|
|
|
+import org.testcontainers.junit.jupiter.Container;
|
|
|
|
+import org.testcontainers.junit.jupiter.Testcontainers;
|
|
|
|
+import org.testcontainers.milvus.MilvusContainer;
|
|
|
|
+
|
|
|
|
+import java.nio.ByteBuffer;
|
|
|
|
+import java.util.*;
|
|
|
|
+
|
|
|
|
+@Testcontainers(disabledWithoutDocker = true)
|
|
|
|
+class MilvusClientV2DockerTest {
|
|
|
|
+ private static MilvusClientV2 client;
|
|
|
|
+ private static RandomStringGenerator generator;
|
|
|
|
+ private static final int dimension = 256;
|
|
|
|
+
|
|
|
|
+ private static final Gson GSON_INSTANCE = new Gson();
|
|
|
|
+
|
|
|
|
+ private static final Random RANDOM = new Random();
|
|
|
|
+
|
|
|
|
+ @Container
|
|
|
|
+ private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.1");
|
|
|
|
+
|
|
|
|
+ @BeforeAll
|
|
|
|
+ public static void setUp() {
|
|
|
|
+ ConnectConfig config = ConnectConfig.builder()
|
|
|
|
+ .uri(milvus.getEndpoint())
|
|
|
|
+ .build();
|
|
|
|
+ client = new MilvusClientV2(config);
|
|
|
|
+ generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @AfterAll
|
|
|
|
+ public static void tearDown() throws InterruptedException {
|
|
|
|
+ if (client != null) {
|
|
|
|
+ client.close(5L);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private List<Float> generateFolatVector() {
|
|
|
|
+ List<Float> vector = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < dimension; ++i) {
|
|
|
|
+ vector.add(RANDOM.nextFloat());
|
|
|
|
+ }
|
|
|
|
+ return vector;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private List<List<Float>> generateFloatVectors(int count) {
|
|
|
|
+ List<List<Float>> vectors = new ArrayList<>();
|
|
|
|
+ for (int n = 0; n < count; ++n) {
|
|
|
|
+ vectors.add(generateFolatVector());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return vectors;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private ByteBuffer generateBinaryVector() {
|
|
|
|
+ int byteCount = dimension / 8;
|
|
|
|
+ ByteBuffer vector = ByteBuffer.allocate(byteCount);
|
|
|
|
+ for (int i = 0; i < byteCount; ++i) {
|
|
|
|
+ vector.put((byte) RANDOM.nextInt(Byte.MAX_VALUE));
|
|
|
|
+ }
|
|
|
|
+ return vector;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private List<ByteBuffer> generateBinaryVectors(int count) {
|
|
|
|
+ List<ByteBuffer> vectors = new ArrayList<>();
|
|
|
|
+ for (int n = 0; n < count; ++n) {
|
|
|
|
+ vectors.add(generateBinaryVector());
|
|
|
|
+ }
|
|
|
|
+ return vectors;
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private SortedMap<Long, Float> generateSparseVector() {
|
|
|
|
+ SortedMap<Long, Float> sparse = new TreeMap<>();
|
|
|
|
+ int dim = RANDOM.nextInt(10) + 10;
|
|
|
|
+ for (int i = 0; i < dim; ++i) {
|
|
|
|
+ sparse.put((long) RANDOM.nextInt(1000000), RANDOM.nextFloat());
|
|
|
|
+ }
|
|
|
|
+ return sparse;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private List<SortedMap<Long, Float>> generateSparseVectors(int count) {
|
|
|
|
+ List<SortedMap<Long, Float>> vectors = new ArrayList<>();
|
|
|
|
+ for (int n = 0; n < count; ++n) {
|
|
|
|
+ vectors.add(generateSparseVector());
|
|
|
|
+ }
|
|
|
|
+ return vectors;
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private CreateCollectionReq.CollectionSchema baseSchema() {
|
|
|
|
+ CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
|
|
|
|
+ .build();
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("id")
|
|
|
|
+ .dataType(DataType.Int64)
|
|
|
|
+ .isPrimaryKey(Boolean.TRUE)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("bool_field")
|
|
|
|
+ .dataType(DataType.Bool)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("int8_field")
|
|
|
|
+ .dataType(DataType.Int8)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("int16_field")
|
|
|
|
+ .dataType(DataType.Int16)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("int32_field")
|
|
|
|
+ .dataType(DataType.Int32)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("int64_field")
|
|
|
|
+ .dataType(DataType.Int64)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("float_field")
|
|
|
|
+ .dataType(DataType.Float)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("double_field")
|
|
|
|
+ .dataType(DataType.Double)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("varchar_field")
|
|
|
|
+ .dataType(DataType.VarChar)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("json_field")
|
|
|
|
+ .dataType(DataType.JSON)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("arr_int_field")
|
|
|
|
+ .dataType(DataType.Array)
|
|
|
|
+ .maxCapacity(50)
|
|
|
|
+ .elementType(DataType.Int32)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("arr_float_field")
|
|
|
|
+ .dataType(DataType.Array)
|
|
|
|
+ .maxCapacity(20)
|
|
|
|
+ .elementType(DataType.Float)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("arr_varchar_field")
|
|
|
|
+ .dataType(DataType.Array)
|
|
|
|
+ .maxCapacity(10)
|
|
|
|
+ .elementType(DataType.VarChar)
|
|
|
|
+ .build());
|
|
|
|
+ return collectionSchema;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private JsonArray generateRandomArray(CreateCollectionReq.FieldSchema field) {
|
|
|
|
+ DataType dataType = field.getDataType();
|
|
|
|
+ if (dataType != DataType.Array) {
|
|
|
|
+ Assertions.fail();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ DataType eleType = field.getElementType();
|
|
|
|
+ int eleCnt = RANDOM.nextInt(field.getMaxCapacity());
|
|
|
|
+ switch (eleType) {
|
|
|
|
+ case Bool: {
|
|
|
|
+ List<Boolean> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(i%10 == 0);
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ case Int8: {
|
|
|
|
+ List<Integer> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(RANDOM.nextInt(256));
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ case Int16:
|
|
|
|
+ case Int32: {
|
|
|
|
+ List<Integer> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(RANDOM.nextInt());
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ case Int64: {
|
|
|
|
+ List<Long> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(RANDOM.nextLong());
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ case Float: {
|
|
|
|
+ List<Float> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(RANDOM.nextFloat());
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ case Double: {
|
|
|
|
+ List<Double> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(RANDOM.nextDouble());
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ case VarChar: {
|
|
|
|
+ List<String> values = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < eleCnt; i++) {
|
|
|
|
+ values.add(String.format("varchar_arr_%d", i));
|
|
|
|
+ }
|
|
|
|
+ return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ Assertions.fail();
|
|
|
|
+ }
|
|
|
|
+ return null;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private List<JsonObject> generateRandomData(CreateCollectionReq.CollectionSchema schema, long count) {
|
|
|
|
+ List<CreateCollectionReq.FieldSchema> fields = schema.getFieldSchemaList();
|
|
|
|
+ List<JsonObject> rows = new ArrayList<>();
|
|
|
|
+ for (long i = 0; i < count; i++) {
|
|
|
|
+ JsonObject row = new JsonObject();
|
|
|
|
+ for (CreateCollectionReq.FieldSchema field : fields) {
|
|
|
|
+ DataType dataType = field.getDataType();
|
|
|
|
+ switch (dataType) {
|
|
|
|
+ case Bool:
|
|
|
|
+ row.addProperty(field.getName(), i%3==0);
|
|
|
|
+ break;
|
|
|
|
+ case Int8:
|
|
|
|
+ row.addProperty(field.getName(), i%128);
|
|
|
|
+ break;
|
|
|
|
+ case Int16:
|
|
|
|
+ row.addProperty(field.getName(), i%32768);
|
|
|
|
+ break;
|
|
|
|
+ case Int32:
|
|
|
|
+ row.addProperty(field.getName(), i%65536);
|
|
|
|
+ break;
|
|
|
|
+ case Int64:
|
|
|
|
+ row.addProperty(field.getName(), i);
|
|
|
|
+ break;
|
|
|
|
+ case Float:
|
|
|
|
+ row.addProperty(field.getName(), i/8);
|
|
|
|
+ break;
|
|
|
|
+ case Double:
|
|
|
|
+ row.addProperty(field.getName(), i/3);
|
|
|
|
+ break;
|
|
|
|
+ case VarChar:
|
|
|
|
+ row.addProperty(field.getName(), String.format("varchar_%d", i));
|
|
|
|
+ break;
|
|
|
|
+ case JSON: {
|
|
|
|
+ JsonObject jsonObj = new JsonObject();
|
|
|
|
+ jsonObj.addProperty(String.format("JSON_%d", i), i);
|
|
|
|
+ row.add(field.getName(), jsonObj);
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ case Array: {
|
|
|
|
+ JsonArray array = generateRandomArray(field);
|
|
|
|
+ row.add(field.getName(), array);
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ case FloatVector: {
|
|
|
|
+ List<Float> vector = generateFolatVector();
|
|
|
|
+ row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector));
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ case BinaryVector: {
|
|
|
|
+ ByteBuffer vector = generateBinaryVector();
|
|
|
|
+ row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector.array()));
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ case SparseFloatVector: {
|
|
|
|
+ SortedMap<Long, Float> vector = generateSparseVector();
|
|
|
|
+ row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector));
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ Assertions.fail();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ rows.add(row);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return rows;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private void verifyOutput(JsonObject row, Map<String, Object> entity) {
|
|
|
|
+ Boolean b = (Boolean) entity.get("bool_field");
|
|
|
|
+ Assertions.assertEquals(row.get("bool_field").getAsBoolean(), b);
|
|
|
|
+ Integer i8 = (Integer) entity.get("int8_field");
|
|
|
|
+ Assertions.assertEquals(row.get("int8_field").getAsInt(), i8);
|
|
|
|
+ Integer i16 = (Integer) entity.get("int16_field");
|
|
|
|
+ Assertions.assertEquals(row.get("int16_field").getAsInt(), i16);
|
|
|
|
+ Integer i32 = (Integer) entity.get("int32_field");
|
|
|
|
+ Assertions.assertEquals(row.get("int32_field").getAsInt(), i32);
|
|
|
|
+ Long i64 = (Long) entity.get("int64_field");
|
|
|
|
+ Assertions.assertEquals(row.get("int64_field").getAsLong(), i64);
|
|
|
|
+ Float f32 = (Float) entity.get("float_field");
|
|
|
|
+ Assertions.assertEquals(row.get("float_field").getAsFloat(), f32);
|
|
|
|
+ Double f64 = (Double) entity.get("double_field");
|
|
|
|
+ Assertions.assertEquals(row.get("double_field").getAsDouble(), f64);
|
|
|
|
+ String str = (String) entity.get("varchar_field");
|
|
|
|
+ Assertions.assertEquals(row.get("varchar_field").getAsString(), str);
|
|
|
|
+ JsonObject jsn = (JsonObject) entity.get("json_field");
|
|
|
|
+ Assertions.assertEquals(row.get("json_field").toString(), jsn.toString());
|
|
|
|
+
|
|
|
|
+ List<Integer> arrInt = (List<Integer>) entity.get("arr_int_field");
|
|
|
|
+ List<Integer> arrIntOri = GSON_INSTANCE.fromJson(row.get("arr_int_field"), new TypeToken<List<Integer>>() {}.getType());
|
|
|
|
+ Assertions.assertEquals(arrIntOri, arrInt);
|
|
|
|
+ List<Float> arrFloat = (List<Float>) entity.get("arr_float_field");
|
|
|
|
+ List<Float> arrFloatOri = GSON_INSTANCE.fromJson(row.get("arr_float_field"), new TypeToken<List<Float>>() {}.getType());
|
|
|
|
+ Assertions.assertEquals(arrFloatOri, arrFloat);
|
|
|
|
+ List<String> arrStr = (List<String>) entity.get("arr_varchar_field");
|
|
|
|
+ List<String> arrStrOri = GSON_INSTANCE.fromJson(row.get("arr_varchar_field"), new TypeToken<List<String>>() {}.getType());
|
|
|
|
+ Assertions.assertEquals(arrStrOri, arrStr);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testFloatVectors() {
|
|
|
|
+ String randomCollectionName = generator.generate(10);
|
|
|
|
+
|
|
|
|
+ String vectorFieldName = "float_vector";
|
|
|
|
+ CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName(vectorFieldName)
|
|
|
|
+ .dataType(DataType.FloatVector)
|
|
|
|
+ .dimension(dimension)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ Map<String,Object> extraParams = new HashMap<>();
|
|
|
|
+ extraParams.put("M",16);
|
|
|
|
+ extraParams.put("efConstruction",64);
|
|
|
|
+ IndexParam indexParam = IndexParam.builder()
|
|
|
|
+ .fieldName(vectorFieldName)
|
|
|
|
+ .indexType(IndexParam.IndexType.HNSW)
|
|
|
|
+ .metricType(IndexParam.MetricType.COSINE)
|
|
|
|
+ .extraParams(extraParams)
|
|
|
|
+ .build();
|
|
|
|
+
|
|
|
|
+ CreateCollectionReq requestCreate = CreateCollectionReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .collectionSchema(collectionSchema)
|
|
|
|
+ .indexParams(Collections.singletonList(indexParam))
|
|
|
|
+ .build();
|
|
|
|
+ client.createCollection(requestCreate);
|
|
|
|
+
|
|
|
|
+ // insert rows
|
|
|
|
+ long count = 10000;
|
|
|
|
+ List<JsonObject> data = generateRandomData(collectionSchema, count);
|
|
|
|
+ InsertResp insertResp = client.insert(InsertReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .data(data)
|
|
|
|
+ .build());
|
|
|
|
+ Assertions.assertEquals(count, insertResp.getInsertCnt());
|
|
|
|
+
|
|
|
|
+ // create partition, upsert one row to the partition
|
|
|
|
+ String partitionName = "PPP";
|
|
|
|
+ client.createPartition(CreatePartitionReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .partitionName(partitionName)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ List<JsonObject> upsertData = new ArrayList<>();
|
|
|
|
+ upsertData.add(data.get((int) (count - 1)));
|
|
|
|
+ UpsertResp upsertResp = client.upsert(UpsertReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .partitionName(partitionName)
|
|
|
|
+ .data(upsertData)
|
|
|
|
+ .build());
|
|
|
|
+ Assertions.assertEquals(1, upsertResp.getUpsertCnt());
|
|
|
|
+
|
|
|
|
+ // get row count
|
|
|
|
+ QueryResp queryResp = client.query(QueryReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .filter("")
|
|
|
|
+ .outputFields(Collections.singletonList("count(*)"))
|
|
|
|
+ .consistencyLevel(ConsistencyLevel.STRONG)
|
|
|
|
+ .build());
|
|
|
|
+ List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
|
|
|
|
+ Assertions.assertEquals(1, queryResults.size());
|
|
|
|
+ Assertions.assertEquals(count + 1, queryResults.get(0).getEntity().get("count(*)"));
|
|
|
|
+
|
|
|
|
+ // search in partition
|
|
|
|
+ SearchResp searchResp = client.search(SearchReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .partitionNames(Collections.singletonList(partitionName))
|
|
|
|
+ .annsField(vectorFieldName)
|
|
|
|
+ .data(Collections.singletonList(new FloatVec(generateFolatVector())))
|
|
|
|
+ .topK(10)
|
|
|
|
+ .build());
|
|
|
|
+ List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
|
|
|
|
+ Assertions.assertEquals(1, searchResults.size());
|
|
|
|
+ Assertions.assertEquals(1, searchResults.get(0).size());
|
|
|
|
+ Assertions.assertEquals(count - 1, searchResults.get(0).get(0).getId());
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ // query entities
|
|
|
|
+ int nq = 5;
|
|
|
|
+ List<Long> targetIDs = new ArrayList<>();
|
|
|
|
+ List<BaseVector> targetVectors = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ JsonObject row = data.get(RANDOM.nextInt((int)count));
|
|
|
|
+ targetIDs.add(row.get("id").getAsLong());
|
|
|
|
+ List<Float> vector = GSON_INSTANCE.fromJson(row.get(vectorFieldName), new TypeToken<List<Float>>() {}.getType());
|
|
|
|
+ targetVectors.add(new FloatVec(vector));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ GetResp getResp = client.get(GetReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .ids(new ArrayList<>(targetIDs))
|
|
|
|
+ .outputFields(Collections.singletonList("*"))
|
|
|
|
+ .build());
|
|
|
|
+ for (QueryResp.QueryResult result : getResp.getGetResults()) {
|
|
|
|
+ boolean found = false;
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ Map<String, Object> entity = result.getEntity();
|
|
|
|
+ if (Objects.equals(targetIDs.get(i), entity.get("id"))) {
|
|
|
|
+ JsonObject row = data.get(targetIDs.get(i).intValue());
|
|
|
|
+ verifyOutput(row, entity);
|
|
|
|
+ found = true;
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ if (!found) {
|
|
|
|
+ Assertions.fail();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // search in collection
|
|
|
|
+ int topk = 10;
|
|
|
|
+ searchResp = client.search(SearchReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .annsField(vectorFieldName)
|
|
|
|
+ .data(targetVectors)
|
|
|
|
+ .topK(10)
|
|
|
|
+ .outputFields(Collections.singletonList("*"))
|
|
|
|
+ .build());
|
|
|
|
+ searchResults = searchResp.getSearchResults();
|
|
|
|
+ Assertions.assertEquals(nq, searchResults.size());
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ List<SearchResp.SearchResult> results = searchResults.get(i);
|
|
|
|
+ Assertions.assertEquals(topk, results.size());
|
|
|
|
+ SearchResp.SearchResult result = results.get(0);
|
|
|
|
+ Assertions.assertEquals(targetIDs.get(i), result.getId());
|
|
|
|
+
|
|
|
|
+ Map<String, Object> entity = result.getEntity();
|
|
|
|
+ JsonObject row = data.get(targetIDs.get(i).intValue());
|
|
|
|
+ verifyOutput(row, entity);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testBinaryVectors() {
|
|
|
|
+ String randomCollectionName = generator.generate(10);
|
|
|
|
+
|
|
|
|
+ String vectorFieldName = "binary_vector";
|
|
|
|
+ CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName(vectorFieldName)
|
|
|
|
+ .dataType(DataType.BinaryVector)
|
|
|
|
+ .dimension(dimension)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ Map<String,Object> extraParams = new HashMap<>();
|
|
|
|
+ extraParams.put("nlist",64);
|
|
|
|
+ IndexParam indexParam = IndexParam.builder()
|
|
|
|
+ .fieldName(vectorFieldName)
|
|
|
|
+ .indexType(IndexParam.IndexType.BIN_IVF_FLAT)
|
|
|
|
+ .metricType(IndexParam.MetricType.JACCARD)
|
|
|
|
+ .extraParams(extraParams)
|
|
|
|
+ .build();
|
|
|
|
+
|
|
|
|
+ CreateCollectionReq requestCreate = CreateCollectionReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .collectionSchema(collectionSchema)
|
|
|
|
+ .indexParams(Collections.singletonList(indexParam))
|
|
|
|
+ .build();
|
|
|
|
+ client.createCollection(requestCreate);
|
|
|
|
+
|
|
|
|
+ // insert rows
|
|
|
|
+ long count = 10000;
|
|
|
|
+ List<JsonObject> data = generateRandomData(collectionSchema, count);
|
|
|
|
+ InsertResp insertResp = client.insert(InsertReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .data(data)
|
|
|
|
+ .build());
|
|
|
|
+ Assertions.assertEquals(count, insertResp.getInsertCnt());
|
|
|
|
+
|
|
|
|
+ // get row count
|
|
|
|
+ QueryResp queryResp = client.query(QueryReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .filter("")
|
|
|
|
+ .outputFields(Collections.singletonList("count(*)"))
|
|
|
|
+ .consistencyLevel(ConsistencyLevel.STRONG)
|
|
|
|
+ .build());
|
|
|
|
+ List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
|
|
|
|
+ Assertions.assertEquals(1, queryResults.size());
|
|
|
|
+ Assertions.assertEquals(count, queryResults.get(0).getEntity().get("count(*)"));
|
|
|
|
+
|
|
|
|
+ // search in collection
|
|
|
|
+ int nq = 5;
|
|
|
|
+ int topk = 10;
|
|
|
|
+ List<Long> targetIDs = new ArrayList<>();
|
|
|
|
+ List<BaseVector> targetVectors = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ JsonObject row = data.get(RANDOM.nextInt((int)count));
|
|
|
|
+ targetIDs.add(row.get("id").getAsLong());
|
|
|
|
+ byte[] vector = GSON_INSTANCE.fromJson(row.get(vectorFieldName), new TypeToken<byte[]>() {}.getType());
|
|
|
|
+ targetVectors.add(new BinaryVec(ByteBuffer.wrap(vector)));
|
|
|
|
+ }
|
|
|
|
+ SearchResp searchResp = client.search(SearchReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .annsField(vectorFieldName)
|
|
|
|
+ .data(targetVectors)
|
|
|
|
+ .topK(10)
|
|
|
|
+ .build());
|
|
|
|
+ List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
|
|
|
|
+ Assertions.assertEquals(nq, searchResults.size());
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ List<SearchResp.SearchResult> results = searchResults.get(i);
|
|
|
|
+ Assertions.assertEquals(topk, results.size());
|
|
|
|
+ Assertions.assertEquals(targetIDs.get(i), results.get(0).getId());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testSparseVectors() {
|
|
|
|
+ String randomCollectionName = generator.generate(10);
|
|
|
|
+
|
|
|
|
+ String vectorFieldName = "binary_vector";
|
|
|
|
+ CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName(vectorFieldName)
|
|
|
|
+ .dataType(DataType.SparseFloatVector)
|
|
|
|
+ .dimension(dimension)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ Map<String,Object> extraParams = new HashMap<>();
|
|
|
|
+ extraParams.put("drop_ratio_build",0.2);
|
|
|
|
+ IndexParam indexParam = IndexParam.builder()
|
|
|
|
+ .fieldName(vectorFieldName)
|
|
|
|
+ .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
|
|
|
|
+ .metricType(IndexParam.MetricType.IP)
|
|
|
|
+ .extraParams(extraParams)
|
|
|
|
+ .build();
|
|
|
|
+
|
|
|
|
+ CreateCollectionReq requestCreate = CreateCollectionReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .collectionSchema(collectionSchema)
|
|
|
|
+ .indexParams(Collections.singletonList(indexParam))
|
|
|
|
+ .build();
|
|
|
|
+ client.createCollection(requestCreate);
|
|
|
|
+
|
|
|
|
+ // insert rows
|
|
|
|
+ long count = 10000;
|
|
|
|
+ List<JsonObject> data = generateRandomData(collectionSchema, count);
|
|
|
|
+ InsertResp insertResp = client.insert(InsertReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .data(data)
|
|
|
|
+ .build());
|
|
|
|
+ Assertions.assertEquals(count, insertResp.getInsertCnt());
|
|
|
|
+
|
|
|
|
+ // get row count
|
|
|
|
+ QueryResp queryResp = client.query(QueryReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .filter("")
|
|
|
|
+ .outputFields(Collections.singletonList("count(*)"))
|
|
|
|
+ .consistencyLevel(ConsistencyLevel.STRONG)
|
|
|
|
+ .build());
|
|
|
|
+ List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
|
|
|
|
+ Assertions.assertEquals(1, queryResults.size());
|
|
|
|
+ Assertions.assertEquals(count, queryResults.get(0).getEntity().get("count(*)"));
|
|
|
|
+
|
|
|
|
+ // search in collection
|
|
|
|
+ int nq = 5;
|
|
|
|
+ int topk = 10;
|
|
|
|
+ List<Long> targetIDs = new ArrayList<>();
|
|
|
|
+ List<BaseVector> targetVectors = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ JsonObject row = data.get(RANDOM.nextInt((int)count));
|
|
|
|
+ targetIDs.add(row.get("id").getAsLong());
|
|
|
|
+ SortedMap<Long, Float> vector = GSON_INSTANCE.fromJson(row.get(vectorFieldName), new TypeToken<SortedMap<Long, Float>>() {}.getType());
|
|
|
|
+ targetVectors.add(new SparseFloatVec(vector));
|
|
|
|
+ }
|
|
|
|
+ SearchResp searchResp = client.search(SearchReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .annsField(vectorFieldName)
|
|
|
|
+ .data(targetVectors)
|
|
|
|
+ .topK(topk)
|
|
|
|
+ .build());
|
|
|
|
+ List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
|
|
|
|
+ Assertions.assertEquals(nq, searchResults.size());
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ List<SearchResp.SearchResult> results = searchResults.get(i);
|
|
|
|
+ Assertions.assertEquals(targetIDs.get(i), results.get(0).getId());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testHybridSearch() {
|
|
|
|
+ String randomCollectionName = generator.generate(10);
|
|
|
|
+
|
|
|
|
+ CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("float_vector")
|
|
|
|
+ .dataType(DataType.FloatVector)
|
|
|
|
+ .dimension(dimension)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("binary_vector")
|
|
|
|
+ .dataType(DataType.BinaryVector)
|
|
|
|
+ .dimension(dimension)
|
|
|
|
+ .build());
|
|
|
|
+ collectionSchema.addField(AddFieldReq.builder()
|
|
|
|
+ .fieldName("sparse_vector")
|
|
|
|
+ .dataType(DataType.SparseFloatVector)
|
|
|
|
+ .dimension(dimension)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ List<IndexParam> indexParams = new ArrayList<>();
|
|
|
|
+ indexParams.add(IndexParam.builder()
|
|
|
|
+ .fieldName("float_vector")
|
|
|
|
+ .indexType(IndexParam.IndexType.IVF_FLAT)
|
|
|
|
+ .metricType(IndexParam.MetricType.L2)
|
|
|
|
+ .extraParams(new HashMap<String,Object>(){{put("nlist", 64);}})
|
|
|
|
+ .build());
|
|
|
|
+ indexParams.add(IndexParam.builder()
|
|
|
|
+ .fieldName("binary_vector")
|
|
|
|
+ .indexType(IndexParam.IndexType.BIN_FLAT)
|
|
|
|
+ .metricType(IndexParam.MetricType.HAMMING)
|
|
|
|
+ .build());
|
|
|
|
+ indexParams.add(IndexParam.builder()
|
|
|
|
+ .fieldName("sparse_vector")
|
|
|
|
+ .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
|
|
|
|
+ .metricType(IndexParam.MetricType.IP)
|
|
|
|
+ .extraParams(new HashMap<String,Object>(){{put("drop_ratio_build", 0.1);}})
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ CreateCollectionReq requestCreate = CreateCollectionReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .collectionSchema(collectionSchema)
|
|
|
|
+ .indexParams(indexParams)
|
|
|
|
+ .build();
|
|
|
|
+ client.createCollection(requestCreate);
|
|
|
|
+
|
|
|
|
+ // insert rows
|
|
|
|
+ long count = 10000;
|
|
|
|
+ List<JsonObject> data = generateRandomData(collectionSchema, count);
|
|
|
|
+ InsertResp insertResp = client.insert(InsertReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .data(data)
|
|
|
|
+ .build());
|
|
|
|
+ Assertions.assertEquals(count, insertResp.getInsertCnt());
|
|
|
|
+
|
|
|
|
+ // get row count
|
|
|
|
+ QueryResp queryResp = client.query(QueryReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .filter("")
|
|
|
|
+ .outputFields(Collections.singletonList("count(*)"))
|
|
|
|
+ .consistencyLevel(ConsistencyLevel.STRONG)
|
|
|
|
+ .build());
|
|
|
|
+ List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
|
|
|
|
+ Assertions.assertEquals(1, queryResults.size());
|
|
|
|
+ Assertions.assertEquals(count, queryResults.get(0).getEntity().get("count(*)"));
|
|
|
|
+
|
|
|
|
+ // hybrid search in collection
|
|
|
|
+ int nq = 5;
|
|
|
|
+ int topk = 10;
|
|
|
|
+ List<BaseVector> floatVectors = new ArrayList<>();
|
|
|
|
+ List<BaseVector> binaryVectors = new ArrayList<>();
|
|
|
|
+ List<BaseVector> sparseVectors = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ floatVectors.add(new FloatVec(generateFolatVector()));
|
|
|
|
+ binaryVectors.add(new BinaryVec(generateBinaryVector()));
|
|
|
|
+ sparseVectors.add(new SparseFloatVec(generateSparseVector()));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ List<AnnSearchReq> searchRequests = new ArrayList<>();
|
|
|
|
+ searchRequests.add(AnnSearchReq.builder()
|
|
|
|
+ .vectorFieldName("float_vector")
|
|
|
|
+ .vectors(floatVectors)
|
|
|
|
+ .params("{\"nprobe\": 10}")
|
|
|
|
+ .topK(10)
|
|
|
|
+ .build());
|
|
|
|
+ searchRequests.add(AnnSearchReq.builder()
|
|
|
|
+ .vectorFieldName("binary_vector")
|
|
|
|
+ .vectors(binaryVectors)
|
|
|
|
+ .topK(50)
|
|
|
|
+ .build());
|
|
|
|
+ searchRequests.add(AnnSearchReq.builder()
|
|
|
|
+ .vectorFieldName("sparse_vector")
|
|
|
|
+ .vectors(sparseVectors)
|
|
|
|
+ .topK(100)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
|
|
|
|
+ .collectionName(randomCollectionName)
|
|
|
|
+ .searchRequests(searchRequests)
|
|
|
|
+ .ranker(new RRFRanker(20))
|
|
|
|
+ .topK(topk)
|
|
|
|
+ .consistencyLevel(ConsistencyLevel.BOUNDED)
|
|
|
|
+ .build();
|
|
|
|
+ SearchResp searchResp = client.hybridSearch(hybridSearchReq);
|
|
|
|
+ List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
|
|
|
|
+ Assertions.assertEquals(nq, searchResults.size());
|
|
|
|
+ for (int i = 0; i < nq; i++) {
|
|
|
|
+ List<SearchResp.SearchResult> results = searchResults.get(i);
|
|
|
|
+ Assertions.assertEquals(topk, results.size());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
|
|
|
|
+ }
|
|
|
|
+}
|