12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172 |
- /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
- package io.milvus.client;
- import com.google.common.util.concurrent.FutureCallback;
- import com.google.common.util.concurrent.Futures;
- import com.google.common.util.concurrent.ListenableFuture;
- import com.google.common.util.concurrent.MoreExecutors;
- import io.grpc.NameResolverProvider;
- import io.grpc.NameResolverRegistry;
- import io.milvus.client.InsertParam.Builder;
- import io.milvus.client.exception.ClientSideMilvusException;
- import io.milvus.client.exception.InitializationException;
- import io.milvus.client.exception.ServerSideMilvusException;
- import io.milvus.client.exception.UnsupportedServerVersion;
- import io.milvus.grpc.ErrorCode;
- import org.apache.commons.lang3.ArrayUtils;
- import org.apache.commons.text.RandomStringGenerator;
- import org.checkerframework.checker.nullness.compatqual.NullableDecl;
- import org.json.JSONArray;
- import org.json.JSONObject;
- import org.junit.jupiter.api.condition.DisabledIfSystemProperty;
- import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
- import org.testcontainers.containers.GenericContainer;
- import org.testcontainers.junit.jupiter.Container;
- import org.testcontainers.junit.jupiter.Testcontainers;
- import java.net.InetSocketAddress;
- import java.nio.ByteBuffer;
- import java.util.ArrayList;
- import java.util.Arrays;
- import java.util.Collections;
- import java.util.List;
- import java.util.Map;
- import java.util.Random;
- import java.util.SplittableRandom;
- import java.util.concurrent.ExecutionException;
- import java.util.concurrent.TimeUnit;
- import java.util.stream.Collectors;
- import java.util.stream.DoubleStream;
- import java.util.stream.IntStream;
- import java.util.stream.LongStream;
- import static org.junit.jupiter.api.Assertions.assertArrayEquals;
- import static org.junit.jupiter.api.Assertions.assertEquals;
- import static org.junit.jupiter.api.Assertions.assertFalse;
- import static org.junit.jupiter.api.Assertions.assertThrows;
- import static org.junit.jupiter.api.Assertions.assertTrue;
- @Testcontainers
- @EnabledIfSystemProperty(named = "with-containers", matches = "true")
- class ContainerMilvusClientTest extends MilvusClientTest {
- @Container
- private GenericContainer milvusContainer =
- new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
- .withExposedPorts(19530);
- @Container
- private static GenericContainer milvusContainer2 =
- new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
- .withExposedPorts(19530);
- @Override
- protected ConnectParam.Builder connectParamBuilder() {
- return connectParamBuilder(milvusContainer);
- }
- @org.junit.jupiter.api.Test
- void loadBalancing() {
- NameResolverProvider testNameResolverProvider = new StaticNameResolverProvider(
- new InetSocketAddress(milvusContainer.getHost(), milvusContainer.getFirstMappedPort()),
- new InetSocketAddress(milvusContainer2.getHost(), milvusContainer2.getFirstMappedPort()));
- NameResolverRegistry.getDefaultRegistry().register(testNameResolverProvider);
- ConnectParam connectParam = connectParamBuilder()
- .withTarget(testNameResolverProvider.getDefaultScheme() + ":///test")
- .build();
- MilvusClient loadBalancingClient = new MilvusGrpcClient(connectParam);
- assertEquals(50, IntStream.range(0, 100)
- .filter(i -> loadBalancingClient.hasCollection(randomCollectionName).hasCollection())
- .count());
- }
- }
- @Testcontainers
- @DisabledIfSystemProperty(named = "with-containers", matches = "true")
- class MilvusClientTest {
- private MilvusClient client;
- private RandomStringGenerator generator;
- protected String randomCollectionName;
- private int size;
- private int dimension;
- protected ConnectParam.Builder connectParamBuilder() {
- return connectParamBuilder("localhost", 19530);
- }
- protected ConnectParam.Builder connectParamBuilder(GenericContainer milvusContainer) {
- return connectParamBuilder(milvusContainer.getHost(), milvusContainer.getFirstMappedPort());
- }
- protected ConnectParam.Builder connectParamBuilder(String host, int port) {
- return new ConnectParam.Builder().withHost(host).withPort(port);
- }
- protected void assertErrorCode(ErrorCode errorCode, Runnable runnable) {
- assertEquals(errorCode, assertThrows(ServerSideMilvusException.class, runnable::run).getErrorCode());
- }
- // Helper function that generates random float vectors
- static List<List<Float>> generateFloatVectors(int vectorCount, int dimension) {
- SplittableRandom splittableRandom = new SplittableRandom();
- List<List<Float>> vectors = new ArrayList<>(vectorCount);
- for (int i = 0; i < vectorCount; ++i) {
- splittableRandom = splittableRandom.split();
- DoubleStream doubleStream = splittableRandom.doubles(dimension);
- List<Float> vector =
- doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
- vectors.add(vector);
- }
- return vectors;
- }
- // Helper function that generates random binary vectors
- static List<List<Byte>> generateBinaryVectors(int vectorCount, int dimension) {
- Random random = new Random();
- List<List<Byte>> vectors = new ArrayList<>(vectorCount);
- final int dimensionInByte = dimension / 8;
- for (int i = 0; i < vectorCount; ++i) {
- ByteBuffer byteBuffer = ByteBuffer.allocate(dimensionInByte);
- random.nextBytes(byteBuffer.array());
- byte[] b = new byte[byteBuffer.remaining()];
- byteBuffer.get(b);
- vectors.add(Arrays.asList(ArrayUtils.toObject(b)));
- }
- return vectors;
- }
- // Helper function that normalizes a vector if you are using IP (Inner Product) as your metric
- // type
- static List<Float> normalizeVector(List<Float> vector) {
- float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
- final float norm = (float) Math.sqrt(squareSum);
- vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
- return vector;
- }
- // Helper function that generate a simple DSL statement with vector filtering only
- static String generateSimpleDSL(Long topK, String query) {
- return String.format(
- "{\"bool\": {"
- + "\"must\": [{"
- + " \"vector\": {"
- + " \"float_vec\": {"
- + " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
- + " }}}]}}", topK, query);
- }
- // Helper function that generate a complex DSL statement with scalar field filtering
- static String generateComplexDSL(Long topK, String query) {
- return String.format(
- "{\"bool\": {"
- + "\"must\": [{"
- + " \"must\": [{"
- + " \"vector\": {"
- + " \"float_vec\": {"
- + " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
- + "}}}]}]}}",
- topK, query);
- }
- // Helper function that generate a complex DSL statement with scalar field filtering
- static String generateComplexDSLBinary(Long topK, String query) {
- return String.format(
- "{\"bool\": {"
- + "\"must\": [{"
- + " \"vector\": {"
- + " \"binary_vec\": {"
- + " \"topk\": %d, \"metric_type\": \"JACCARD\", \"type\": \"binary\", \"query\": %s, \"params\": {\"nprobe\": 20}"
- + " }}}]}}",
- topK, query);
- }
- @org.junit.jupiter.api.BeforeEach
- void setUp() throws Exception {
- ConnectParam connectParam = connectParamBuilder().build();
- client = new MilvusGrpcClient(connectParam);
- generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
- randomCollectionName = generator.generate(10);
- size = 100000;
- dimension = 128;
- CollectionMapping collectionMapping = CollectionMapping
- .create(randomCollectionName)
- .addField("int64", DataType.INT64)
- .addField("float", DataType.FLOAT)
- .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension)
- .setParamsInJson(new JsonBuilder()
- .param("segment_row_limit", 50000)
- .param("auto_id", false)
- .build());
- client.createCollection(collectionMapping);
- }
- @org.junit.jupiter.api.AfterEach
- void tearDown() {
- client.dropCollection(randomCollectionName);
- client.close();
- }
- @org.junit.jupiter.api.Test
- void idleTest() throws InterruptedException {
- ConnectParam connectParam = connectParamBuilder()
- .withIdleTimeout(1, TimeUnit.SECONDS)
- .build();
- MilvusClient client = new MilvusGrpcClient(connectParam);
- TimeUnit.SECONDS.sleep(2);
- // A new RPC would take the channel out of idle mode
- assertTrue(client.listCollections().ok());
- }
- @org.junit.jupiter.api.Test
- void setInvalidConnectParam() {
- assertThrows(
- IllegalArgumentException.class,
- () -> {
- ConnectParam connectParam = new ConnectParam.Builder().withPort(66666).build();
- });
- assertThrows(
- IllegalArgumentException.class,
- () -> {
- ConnectParam connectParam =
- new ConnectParam.Builder().withConnectTimeout(-1, TimeUnit.MILLISECONDS).build();
- });
- assertThrows(
- IllegalArgumentException.class,
- () -> {
- ConnectParam connectParam =
- new ConnectParam.Builder().withKeepAliveTime(-1, TimeUnit.MILLISECONDS).build();
- });
- assertThrows(
- IllegalArgumentException.class,
- () -> {
- ConnectParam connectParam =
- new ConnectParam.Builder().withKeepAliveTimeout(-1, TimeUnit.MILLISECONDS).build();
- });
- assertThrows(
- IllegalArgumentException.class,
- () -> {
- ConnectParam connectParam =
- new ConnectParam.Builder().withIdleTimeout(-1, TimeUnit.MILLISECONDS).build();
- });
- }
- @org.junit.jupiter.api.Test
- void connectUnreachableHost() {
- ConnectParam connectParam = connectParamBuilder("250.250.250.250", 19530).build();
- assertThrows(InitializationException.class, () -> new MilvusGrpcClient(connectParam));
- }
- @org.junit.jupiter.api.Test
- void unsupportedServerVersion() {
- GenericContainer unsupportedMilvusContainer =
- new GenericContainer("milvusdb/milvus:0.9.1-cpu-d052920-e04ed5")
- .withExposedPorts(19530);
- try {
- unsupportedMilvusContainer.start();
- ConnectParam connectParam = connectParamBuilder(unsupportedMilvusContainer).build();
- assertThrows(UnsupportedServerVersion.class, () -> new MilvusGrpcClient(connectParam));
- } finally {
- unsupportedMilvusContainer.stop();
- }
- }
- @org.junit.jupiter.api.Test
- void grpcTimeout() {
- insert();
- MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
- Response response = timeoutClient.createIndex(
- new Index.Builder(randomCollectionName, "float_vec")
- .withParamsInJson(new JsonBuilder()
- .param("index_type", "IVF_FLAT")
- .param("metric_type", "L2")
- .indexParam("nlist", 2048)
- .build())
- .build());
- assertEquals(Response.Status.RPC_ERROR, response.getStatus());
- }
- @org.junit.jupiter.api.Test
- void createInvalidCollection() {
- // invalid collection name
- CollectionMapping invalidCollectionName = CollectionMapping
- .create("╯°□°)╯")
- .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension);
- assertErrorCode(ErrorCode.ILLEGAL_COLLECTION_NAME, () -> client.createCollection(invalidCollectionName));
- // invalid field
- CollectionMapping withoutField = CollectionMapping.create("validCollectionName");
- assertThrows(ClientSideMilvusException.class, () -> client.createCollection(withoutField));
- // invalid segment_row_count
- CollectionMapping invalidSegmentRowCount = CollectionMapping
- .create("validCollectionName")
- .addField("int64", DataType.INT64)
- .addField("float", DataType.FLOAT)
- .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension)
- .setParamsInJson(new JsonBuilder().param("segment_row_limit", -1000).build());
- assertErrorCode(ErrorCode.ILLEGAL_ARGUMENT, () -> client.createCollection(invalidSegmentRowCount));
- }
- @org.junit.jupiter.api.Test
- void hasCollection() {
- assertTrue(client.hasCollection(randomCollectionName));
- }
- @org.junit.jupiter.api.Test
- void dropCollection() {
- String nonExistingCollectionName = generator.generate(10);
- Response dropCollectionResponse = client.dropCollection(nonExistingCollectionName);
- assertFalse(dropCollectionResponse.ok());
- assertEquals(Response.Status.COLLECTION_NOT_EXISTS, dropCollectionResponse.getStatus());
- }
- @org.junit.jupiter.api.Test
- @SuppressWarnings("unchecked")
- void partitionTest() {
- final String tag1 = "tag1";
- Response createPartitionResponse = client.createPartition(randomCollectionName, tag1);
- assertTrue(createPartitionResponse.ok());
- final String tag2 = "tag2";
- createPartitionResponse = client.createPartition(randomCollectionName, tag2);
- assertTrue(createPartitionResponse.ok());
- ListPartitionsResponse listPartitionsResponse = client.listPartitions(randomCollectionName);
- assertTrue(listPartitionsResponse.ok());
- assertEquals(3, listPartitionsResponse.getPartitionList().size()); // two tags plus _default
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- List<Long> entityIds1 = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(entityIds1)
- .withPartitionTag(tag1)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> entityIds2 = LongStream.range(size, size * 2).boxed().collect(Collectors.toList());
- insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(entityIds2)
- .withPartitionTag(tag2)
- .build();
- insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- assertTrue(client.flush(randomCollectionName).ok());
- assertEquals(size * 2,
- client.countEntities(randomCollectionName).getCollectionEntityCount());
- final int searchSize = 1;
- final long topK = 10;
- List<List<Float>> vectorsToSearch1 = vectors.subList(0, searchSize);
- List<String> partitionTags1 = new ArrayList<>();
- partitionTags1.add(tag1);
- SearchParam searchParam1 =
- new SearchParam.Builder(randomCollectionName)
- .withDSL(generateSimpleDSL(topK, vectorsToSearch1.toString()))
- .withPartitionTags(partitionTags1)
- .build();
- SearchResponse searchResponse1 = client.search(searchParam1);
- assertTrue(searchResponse1.ok());
- List<List<Long>> resultIdsList1 = searchResponse1.getResultIdsList();
- assertEquals(searchSize, resultIdsList1.size());
- assertTrue(entityIds1.containsAll(resultIdsList1.get(0)));
- List<List<Float>> vectorsToSearch2 = vectors.subList(0, searchSize);
- List<String> partitionTags2 = new ArrayList<>();
- partitionTags2.add(tag2);
- SearchParam searchParam2 =
- new SearchParam.Builder(randomCollectionName)
- .withDSL(generateSimpleDSL(topK, vectorsToSearch2.toString()))
- .withPartitionTags(partitionTags2)
- .build();
- SearchResponse searchResponse2 = client.search(searchParam2);
- assertTrue(searchResponse2.ok());
- List<List<Long>> resultIdsList2 = searchResponse2.getResultIdsList();
- assertEquals(searchSize, resultIdsList2.size());
- assertTrue(entityIds2.containsAll(resultIdsList2.get(0)));
- assertTrue(Collections.disjoint(resultIdsList1, resultIdsList2));
- HasPartitionResponse testHasPartition = client.hasPartition(randomCollectionName, tag1);
- assertTrue(testHasPartition.hasPartition());
- Response dropPartitionResponse = client.dropPartition(randomCollectionName, tag1);
- assertTrue(dropPartitionResponse.ok());
- testHasPartition = client.hasPartition(randomCollectionName, tag1);
- assertFalse(testHasPartition.hasPartition());
- dropPartitionResponse = client.dropPartition(randomCollectionName, tag2);
- assertTrue(dropPartitionResponse.ok());
- }
- @org.junit.jupiter.api.Test
- void createIndex() {
- insert();
- assertTrue(client.flush(randomCollectionName).ok());
- Index index =
- new Index.Builder(randomCollectionName, "float_vec")
- .withParamsInJson(new JsonBuilder().param("index_type", "IVF_SQ8")
- .param("metric_type", "L2")
- .indexParam("nlist", 2048)
- .build())
- .build();
- Response createIndexResponse = client.createIndex(index);
- assertTrue(createIndexResponse.ok());
- // also test drop index here
- Response dropIndexResponse = client.dropIndex(randomCollectionName, "float_vec");
- assertTrue(dropIndexResponse.ok());
- }
- @org.junit.jupiter.api.Test
- void createIndexAsync() throws ExecutionException, InterruptedException {
- insert();
- assertTrue(client.flush(randomCollectionName).ok());
- Index index =
- new Index.Builder(randomCollectionName, "float_vec")
- .withParamsInJson(new JsonBuilder().param("index_type", "IVF_SQ8")
- .param("metric_type", "L2")
- .indexParam("nlist", 2048)
- .build())
- .build();
- ListenableFuture<Response> createIndexResponseFuture = client.createIndexAsync(index);
- Futures.addCallback(
- createIndexResponseFuture,
- new FutureCallback<Response>() {
- @Override
- public void onSuccess(@NullableDecl Response createIndexResponse) {
- assert createIndexResponse != null;
- assertTrue(createIndexResponse.ok());
- }
- @Override
- public void onFailure(Throwable t) {
- System.out.println(t.getMessage());
- }
- }, MoreExecutors.directExecutor()
- );
- }
- @org.junit.jupiter.api.Test
- void insert() {
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- List<Long> entityIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(entityIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- assertEquals(size, insertResponse.getEntityIds().size());
- }
- @org.junit.jupiter.api.Test
- void insertAsync() throws ExecutionException, InterruptedException {
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- List<Long> entityIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(entityIds)
- .build();
- ListenableFuture<InsertResponse> insertResponseFuture = client.insertAsync(insertParam);
- Futures.addCallback(
- insertResponseFuture,
- new FutureCallback<InsertResponse>() {
- @Override
- public void onSuccess(@NullableDecl InsertResponse insertResponse) {
- assert insertResponse != null;
- assertTrue(insertResponse.ok());
- assertEquals(size, insertResponse.getEntityIds().size());
- }
- @Override
- public void onFailure(Throwable t) {
- System.out.println(t.getMessage());
- }
- }, MoreExecutors.directExecutor()
- );
- }
- @org.junit.jupiter.api.Test
- void insertBinary() {
- final int binaryDimension = 10000;
- String binaryCollectionName = generator.generate(10);
- CollectionMapping collectionMapping = CollectionMapping
- .create(binaryCollectionName)
- .addVectorField("binary_vec", DataType.VECTOR_BINARY, binaryDimension);
- client.createCollection(collectionMapping);
- List<List<Byte>> vectors = generateBinaryVectors(size, binaryDimension);
- InsertParam insertParam =
- new Builder(binaryCollectionName)
- .field(new FieldBuilder("binary_vec", DataType.VECTOR_BINARY)
- .values(vectors)
- .build())
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- assertEquals(size, insertResponse.getEntityIds().size());
- Index index =
- new Index.Builder(binaryCollectionName, "binary_vec")
- .withParamsInJson(new JsonBuilder().param("index_type", "BIN_IVF_FLAT")
- .param("metric_type", "JACCARD")
- .indexParam("nlist", 100)
- .build())
- .build();
- Response createIndexResponse = client.createIndex(index);
- assertTrue(createIndexResponse.ok());
- // also test drop index here
- Response dropIndexResponse = client.dropIndex(binaryCollectionName, "binary_vec");
- assertTrue(dropIndexResponse.ok());
- assertTrue(client.dropCollection(binaryCollectionName).ok());
- }
- @org.junit.jupiter.api.Test
- void search() {
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
- List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(insertIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> entityIds = insertResponse.getEntityIds();
- assertEquals(size, entityIds.size());
- assertTrue(client.flush(randomCollectionName).ok());
- final int searchSize = 5;
- List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
- final long topK = 10;
- SearchParam searchParam =
- new SearchParam.Builder(randomCollectionName)
- .withDSL(generateComplexDSL(topK, vectorsToSearch.toString()))
- .withParamsInJson(new JsonBuilder().param("fields",
- new ArrayList<>(Arrays.asList("int64", "float_vec"))).build())
- .build();
- SearchResponse searchResponse = client.search(searchParam);
- assertTrue(searchResponse.ok());
- List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
- assertEquals(searchSize, resultIdsList.size());
- List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
- assertEquals(searchSize, resultDistancesList.size());
- List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
- assertEquals(searchSize, queryResultsList.size());
- final double epsilon = 0.001;
- for (int i = 0; i < searchSize; i++) {
- SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
- assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
- assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
- assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
- assertTrue(Math.abs(resultDistancesList.get(i).get(0)) < epsilon);
- }
- }
- @org.junit.jupiter.api.Test
- void searchAsync() throws ExecutionException, InterruptedException {
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
- List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(insertIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> entityIds = insertResponse.getEntityIds();
- assertEquals(size, entityIds.size());
- assertTrue(client.flush(randomCollectionName).ok());
- final int searchSize = 5;
- List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
- final long topK = 10;
- SearchParam searchParam =
- new SearchParam.Builder(randomCollectionName)
- .withDSL(generateComplexDSL(topK, vectorsToSearch.toString()))
- .withParamsInJson(new JsonBuilder().param("fields",
- new ArrayList<>(Arrays.asList("int64", "float"))).build())
- .build();
- ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
- SearchResponse searchResponse = searchResponseFuture.get();
- assertTrue(searchResponse.ok());
- List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
- assertEquals(searchSize, resultIdsList.size());
- List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
- assertEquals(searchSize, resultDistancesList.size());
- List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
- assertEquals(searchSize, queryResultsList.size());
- final double epsilon = 0.001;
- for (int i = 0; i < searchSize; i++) {
- SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
- assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
- assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
- assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
- assertTrue(Math.abs(resultDistancesList.get(i).get(0)) < epsilon);
- }
- }
- /**
- * Exercises binary-vector search end to end on a throwaway collection: create, insert
- * {@code size} entities, flush, search with the first few inserted vectors, and check that
- * each query's top hit is the vector it was built from. The collection is always dropped,
- * even when an assertion fails part-way through.
- */
- @org.junit.jupiter.api.Test
- void searchBinary() {
- final int binaryDimension = 64;
- String binaryCollectionName = generator.generate(10);
- CollectionMapping collectionMapping = CollectionMapping
- .create(binaryCollectionName)
- .addField("int64", DataType.INT64)
- .addField("float", DataType.FLOAT)
- .addVectorField("binary_vec", DataType.VECTOR_BINARY, binaryDimension);
- client.createCollection(collectionMapping);
- // Set once the collection has been dropped on the success path; otherwise the finally
- // block cleans up so a failed assertion does not leave an orphan collection behind.
- boolean dropped = false;
- try {
- // field list for insert
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- List<List<Byte>> vectors = generateBinaryVectors(size, binaryDimension);
- InsertParam insertParam =
- new Builder(binaryCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("binary_vec", DataType.VECTOR_BINARY)
- .values(vectors)
- .build())
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> entityIds = insertResponse.getEntityIds();
- assertEquals(size, entityIds.size());
- assertTrue(client.flush(binaryCollectionName).ok());
- final int searchSize = 5;
- // Query with vectors that were just inserted so the expected nearest hit is known.
- List<List<Byte>> vectorsToSearch = vectors.subList(0, searchSize);
- final long topK = 10;
- SearchParam searchParam =
- new SearchParam.Builder(binaryCollectionName)
- .withDSL(generateComplexDSLBinary(topK, vectorsToSearch.toString()))
- .build();
- SearchResponse searchResponse = client.search(searchParam);
- assertTrue(searchResponse.ok());
- List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
- assertEquals(searchSize, resultIdsList.size());
- List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
- assertEquals(searchSize, resultDistancesList.size());
- List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
- assertEquals(searchSize, queryResultsList.size());
- for (int i = 0; i < searchSize; i++) {
- SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
- assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
- assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
- }
- assertTrue(client.dropCollection(binaryCollectionName).ok());
- dropped = true;
- } finally {
- if (!dropped) {
- // Best-effort cleanup: the result is ignored so the original failure is the one reported.
- client.dropCollection(binaryCollectionName);
- }
- }
- }
- /**
- * Checks that collection info for the shared test collection is returned, carries the right
- * collection name, and exposes a {@code dim} param on its {@code float_vec} field. Also checks
- * that asking for a non-existent collection fails and yields no mapping.
- */
- @org.junit.jupiter.api.Test
- void getCollectionInfo() {
- GetCollectionInfoResponse getCollectionInfoResponse =
- client.getCollectionInfo(randomCollectionName);
- assertTrue(getCollectionInfoResponse.ok());
- assertTrue(getCollectionInfoResponse.getCollectionMapping().isPresent());
- // JUnit convention is assertEquals(expected, actual).
- assertEquals(
- randomCollectionName,
- getCollectionInfoResponse.getCollectionMapping().get().getCollectionName());
- List<? extends Map<String, Object>> fields = getCollectionInfoResponse.getCollectionMapping()
- .get().getFields();
- boolean floatVecFound = false;
- for (Map<String, Object> field : fields) {
- // Constant-first equals: a field map without a "field" entry must not NPE the test.
- if ("float_vec".equals(field.get("field"))) {
- floatVecFound = true;
- JSONObject params = new JSONObject(field.get("params").toString());
- assertTrue(params.has("dim"));
- }
- }
- // Without this check, a mapping that lost its vector field would pass vacuously.
- assertTrue(floatVecFound);
- String nonExistingCollectionName = generator.generate(10);
- getCollectionInfoResponse = client.getCollectionInfo(nonExistingCollectionName);
- assertFalse(getCollectionInfoResponse.ok());
- assertFalse(getCollectionInfoResponse.getCollectionMapping().isPresent());
- }
- /** Checks that listing collections succeeds and includes the shared test collection. */
- @org.junit.jupiter.api.Test
- void listCollections() {
- ListCollectionsResponse response = client.listCollections();
- assertTrue(response.ok());
- boolean listed = response.getCollectionNames().contains(randomCollectionName);
- assertTrue(listed);
- }
- /** Checks that the server-status request completes with an ok response. */
- @org.junit.jupiter.api.Test
- void serverStatus() {
- assertTrue(client.getServerStatus().ok());
- }
- /** Checks that the server-version request completes with an ok response. */
- @org.junit.jupiter.api.Test
- void serverVersion() {
- assertTrue(client.getServerVersion().ok());
- }
- /**
- * Populates the shared collection, flushes it, and checks that the reported entity count
- * equals {@code size}.
- */
- @org.junit.jupiter.api.Test
- void countEntities() {
- // insert() is defined elsewhere in this class; this test expects it to add `size` entities.
- insert();
- assertTrue(client.flush(randomCollectionName).ok());
- CountEntitiesResponse countResponse = client.countEntities(randomCollectionName);
- assertTrue(countResponse.ok());
- assertEquals(size, countResponse.getCollectionEntityCount());
- }
- /**
- * Populates and flushes the shared collection, then checks that loadCollection on it
- * returns an ok response.
- */
- @org.junit.jupiter.api.Test
- void loadCollection() {
- insert();
- assertTrue(client.flush(randomCollectionName).ok());
- assertTrue(client.loadCollection(randomCollectionName).ok());
- }
- /**
- * Populates and flushes the shared collection, then checks the JSON stats message. Both the
- * collection-level and the first partition's {@code row_count} must equal {@code size}, and
- * that partition must be tagged {@code _default}.
- */
- @org.junit.jupiter.api.Test
- void getCollectionStats() {
- insert();
- assertTrue(client.flush(randomCollectionName).ok());
- Response getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
- assertTrue(getCollectionStatsResponse.ok());
- // The stats payload is returned as a JSON string in the response message.
- String jsonString = getCollectionStatsResponse.getMessage();
- JSONObject jsonInfo = new JSONObject(jsonString);
- // assertEquals takes (expected, actual) so failure messages report the right values.
- assertEquals(size, jsonInfo.getInt("row_count"));
- JSONArray partitions = jsonInfo.getJSONArray("partitions");
- JSONObject partitionInfo = partitions.getJSONObject(0);
- assertEquals("_default", partitionInfo.getString("tag"));
- assertEquals(size, partitionInfo.getInt("row_count"));
- }
- @org.junit.jupiter.api.Test
- @SuppressWarnings("unchecked")
- void getEntityByID() {
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
- List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(insertIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> entityIds = insertResponse.getEntityIds();
- assertEquals(size, entityIds.size());
- assertTrue(client.flush(randomCollectionName).ok());
- GetEntityByIDResponse getEntityByIDResponse =
- client.getEntityByID(randomCollectionName, entityIds.subList(0, 100));
- assertTrue(getEntityByIDResponse.ok());
- int vecIndex = 0;
- List<Map<String, Object>> fieldsMap = getEntityByIDResponse.getFieldsMap();
- assertTrue(fieldsMap.get(vecIndex).get("float_vec") instanceof List);
- List<Float> first = (List<Float>) (fieldsMap.get(vecIndex).get("float_vec"));
- assertArrayEquals(first.toArray(), vectors.get(0).toArray());
- }
- /**
- * Round-trips binary vectors on a throwaway collection created with {@code auto_id=false}.
- * It inserts {@code size} vectors with ids 0..size-1, fetches the first ones back by id, and
- * checks the count and the first vector. The collection is always dropped, even when an
- * assertion fails part-way through.
- */
- @org.junit.jupiter.api.Test
- @SuppressWarnings("unchecked")
- void getEntityByIDBinary() {
- final int binaryDimension = 64;
- // Number of leading ids fetched back; assumes size >= fetchCount.
- final int fetchCount = 100;
- String binaryCollectionName = generator.generate(10);
- CollectionMapping collectionMapping = CollectionMapping
- .create(binaryCollectionName)
- .addVectorField("binary_vec", DataType.VECTOR_BINARY, binaryDimension)
- .setParamsInJson(new JsonBuilder().param("auto_id", false).build());
- client.createCollection(collectionMapping);
- // Set once the collection has been dropped on the success path; otherwise the finally
- // block cleans up so this test does not leak a collection on every run or failure.
- boolean dropped = false;
- try {
- List<List<Byte>> vectors = generateBinaryVectors(size, binaryDimension);
- List<Long> entityIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(binaryCollectionName)
- .field(new FieldBuilder("binary_vec", DataType.VECTOR_BINARY)
- .values(vectors)
- .build())
- .withEntityIds(entityIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- assertEquals(size, insertResponse.getEntityIds().size());
- assertTrue(client.flush(binaryCollectionName).ok());
- GetEntityByIDResponse getEntityByIDResponse =
- client.getEntityByID(binaryCollectionName, entityIds.subList(0, fetchCount));
- assertTrue(getEntityByIDResponse.ok());
- List<Map<String, Object>> fieldsMap = getEntityByIDResponse.getFieldsMap();
- // (expected, actual) order so failure messages are not inverted.
- assertEquals(fetchCount, fieldsMap.size());
- assertTrue(fieldsMap.get(0).get("binary_vec") instanceof List);
- List<Byte> first = (List<Byte>) (fieldsMap.get(0).get("binary_vec"));
- assertArrayEquals(vectors.get(0).toArray(), first.toArray());
- assertTrue(client.dropCollection(binaryCollectionName).ok());
- dropped = true;
- } finally {
- if (!dropped) {
- // Best-effort cleanup: the result is ignored so the original failure is the one reported.
- client.dropCollection(binaryCollectionName);
- }
- }
- }
- /**
- * Populates and flushes the shared collection, reads the id of the first segment of the first
- * partition from the JSON stats, and checks that listing the ids in that segment succeeds and
- * returns at least one id.
- */
- @org.junit.jupiter.api.Test
- void getEntityIds() {
- insert();
- assertTrue(client.flush(randomCollectionName).ok());
- Response statsResponse = client.getCollectionStats(randomCollectionName);
- assertTrue(statsResponse.ok());
- JSONObject stats = new JSONObject(statsResponse.getMessage());
- JSONArray partitions = stats.getJSONArray("partitions");
- JSONArray segments = partitions.getJSONObject(0).getJSONArray("segments");
- long segmentId = segments.getJSONObject(0).getLong("id");
- ListIDInSegmentResponse idsResponse = client.listIDInSegment(randomCollectionName, segmentId);
- assertTrue(idsResponse.ok());
- assertFalse(idsResponse.getIds().isEmpty());
- }
- @org.junit.jupiter.api.Test
- void deleteEntityByID() {
- List<Long> intValues = new ArrayList<>(size);
- List<Float> floatValues = new ArrayList<>(size);
- List<List<Float>> vectors = generateFloatVectors(size, dimension);
- for (int i = 0; i < size; i++) {
- intValues.add((long) i);
- floatValues.add((float) i);
- }
- vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
- List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(insertIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- assertEquals(size, insertResponse.getEntityIds().size());
- assertTrue(client.flush(randomCollectionName).ok());
- assertTrue(client.deleteEntityByID(randomCollectionName,
- insertResponse.getEntityIds().subList(0, 100)).ok());
- assertTrue(client.flush(randomCollectionName).ok());
- assertEquals(client.countEntities(randomCollectionName).getCollectionEntityCount(), size - 100);
- }
- /** Checks that a synchronous flush of the shared collection returns an ok response. */
- @org.junit.jupiter.api.Test
- void flush() {
- boolean flushed = client.flush(randomCollectionName).ok();
- assertTrue(flushed);
- }
- /**
- * Checks that an asynchronous flush of the shared collection, once awaited, yields an ok
- * response.
- */
- @org.junit.jupiter.api.Test
- void flushAsync() throws ExecutionException, InterruptedException {
- boolean flushed = client.flushAsync(randomCollectionName).get().ok();
- assertTrue(flushed);
- }
- /**
- * Inserts {@code size} entities into the shared collection and deletes the first half. It then
- * compacts with threshold 0.2 and checks that the first partition's {@code data_size} in the
- * JSON stats is smaller afterwards.
- */
- @org.junit.jupiter.api.Test
- void compact() {
- List<Long> intValues = LongStream.range(0, size).boxed().collect(Collectors.toList());
- List<Float> floatValues = intValues.stream().map(Long::floatValue).collect(Collectors.toList());
- List<List<Float>> vectors = generateFloatVectors(size, dimension)
- .stream()
- .map(MilvusClientTest::normalizeVector)
- .collect(Collectors.toList());
- // Explicit entity ids 0..size-1 (same values as the int64 column, kept as a separate list).
- List<Long> insertIds = new ArrayList<>(intValues);
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(insertIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> insertedIds = insertResponse.getEntityIds();
- assertEquals(size, insertedIds.size());
- assertTrue(client.flush(randomCollectionName).ok());
- Response statsBefore = client.getCollectionStats(randomCollectionName);
- assertTrue(statsBefore.ok());
- long dataSizeBefore = new JSONObject(statsBefore.getMessage())
- .getJSONArray("partitions")
- .getJSONObject(0)
- .getLong("data_size");
- List<Long> firstHalf = insertedIds.subList(0, size / 2);
- assertTrue(client.deleteEntityByID(randomCollectionName, firstHalf).ok());
- assertTrue(client.flush(randomCollectionName).ok());
- CompactParam compactParam =
- new CompactParam.Builder(randomCollectionName).withThreshold(0.2).build();
- assertTrue(client.compact(compactParam).ok());
- Response statsAfter = client.getCollectionStats(randomCollectionName);
- assertTrue(statsAfter.ok());
- long dataSizeAfter = new JSONObject(statsAfter.getMessage())
- .getJSONArray("partitions")
- .getJSONObject(0)
- .getLong("data_size");
- assertTrue(dataSizeAfter < dataSizeBefore);
- }
- /**
- * Inserts {@code size} entities into the shared collection and deletes the first half. It then
- * awaits an asynchronous compaction with threshold 0.8 and checks that the first segment's
- * {@code data_size} did not shrink.
- */
- @org.junit.jupiter.api.Test
- void compactAsync() throws ExecutionException, InterruptedException {
- List<Long> intValues = LongStream.range(0, size).boxed().collect(Collectors.toList());
- List<Float> floatValues = intValues.stream().map(Long::floatValue).collect(Collectors.toList());
- List<List<Float>> vectors = generateFloatVectors(size, dimension)
- .stream()
- .map(MilvusClientTest::normalizeVector)
- .collect(Collectors.toList());
- // Explicit entity ids 0..size-1 (same values as the int64 column, kept as a separate list).
- List<Long> insertIds = new ArrayList<>(intValues);
- InsertParam insertParam =
- new Builder(randomCollectionName)
- .field(new FieldBuilder("int64", DataType.INT64)
- .values(intValues)
- .build())
- .field(new FieldBuilder("float", DataType.FLOAT)
- .values(floatValues)
- .build())
- .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
- .values(vectors)
- .build())
- .withEntityIds(insertIds)
- .build();
- InsertResponse insertResponse = client.insert(insertParam);
- assertTrue(insertResponse.ok());
- List<Long> insertedIds = insertResponse.getEntityIds();
- assertEquals(size, insertedIds.size());
- assertTrue(client.flush(randomCollectionName).ok());
- Response statsBefore = client.getCollectionStats(randomCollectionName);
- assertTrue(statsBefore.ok());
- long dataSizeBefore = new JSONObject(statsBefore.getMessage())
- .getJSONArray("partitions")
- .getJSONObject(0)
- .getJSONArray("segments")
- .getJSONObject(0)
- .getLong("data_size");
- List<Long> firstHalf = insertedIds.subList(0, size / 2);
- assertTrue(client.deleteEntityByID(randomCollectionName, firstHalf).ok());
- assertTrue(client.flush(randomCollectionName).ok());
- CompactParam compactParam =
- new CompactParam.Builder(randomCollectionName).withThreshold(0.8).build();
- assertTrue(client.compactAsync(compactParam).get().ok());
- Response statsAfter = client.getCollectionStats(randomCollectionName);
- assertTrue(statsAfter.ok());
- long dataSizeAfter = new JSONObject(statsAfter.getMessage())
- .getJSONArray("partitions")
- .getJSONObject(0)
- .getJSONArray("segments")
- .getJSONObject(0)
- .getLong("data_size");
- // Half of the entities were deleted (0.5), which is below the 0.8 threshold, so no
- // compaction is expected and the segment must not have shrunk.
- assertFalse(dataSizeAfter < dataSizeBefore);
- }
- }
|