MilvusGrpcClientTest.java 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. package io.milvus.client;
  20. import com.google.common.util.concurrent.FutureCallback;
  21. import com.google.common.util.concurrent.Futures;
  22. import com.google.common.util.concurrent.ListenableFuture;
  23. import com.google.common.util.concurrent.MoreExecutors;
  24. import io.grpc.NameResolverProvider;
  25. import io.grpc.NameResolverRegistry;
  26. import io.milvus.client.InsertParam.Builder;
  27. import io.milvus.client.exception.ClientSideMilvusException;
  28. import io.milvus.client.exception.InitializationException;
  29. import io.milvus.client.exception.ServerSideMilvusException;
  30. import io.milvus.client.exception.UnsupportedServerVersion;
  31. import io.milvus.grpc.ErrorCode;
  32. import org.apache.commons.lang3.ArrayUtils;
  33. import org.apache.commons.text.RandomStringGenerator;
  34. import org.checkerframework.checker.nullness.compatqual.NullableDecl;
  35. import org.json.JSONArray;
  36. import org.json.JSONObject;
  37. import org.junit.jupiter.api.condition.DisabledIfSystemProperty;
  38. import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
  39. import org.testcontainers.containers.GenericContainer;
  40. import org.testcontainers.junit.jupiter.Container;
  41. import org.testcontainers.junit.jupiter.Testcontainers;
  42. import java.net.InetSocketAddress;
  43. import java.nio.ByteBuffer;
  44. import java.util.ArrayList;
  45. import java.util.Arrays;
  46. import java.util.Collections;
  47. import java.util.List;
  48. import java.util.Map;
  49. import java.util.Random;
  50. import java.util.SplittableRandom;
  51. import java.util.concurrent.ExecutionException;
  52. import java.util.concurrent.TimeUnit;
  53. import java.util.stream.Collectors;
  54. import java.util.stream.DoubleStream;
  55. import java.util.stream.IntStream;
  56. import java.util.stream.LongStream;
  57. import static org.junit.jupiter.api.Assertions.assertArrayEquals;
  58. import static org.junit.jupiter.api.Assertions.assertEquals;
  59. import static org.junit.jupiter.api.Assertions.assertFalse;
  60. import static org.junit.jupiter.api.Assertions.assertThrows;
  61. import static org.junit.jupiter.api.Assertions.assertTrue;
  62. @Testcontainers
  63. @EnabledIfSystemProperty(named = "with-containers", matches = "true")
  64. class ContainerMilvusClientTest extends MilvusClientTest {
  65. @Container
  66. private GenericContainer milvusContainer =
  67. new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
  68. .withExposedPorts(19530);
  69. @Container
  70. private static GenericContainer milvusContainer2 =
  71. new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
  72. .withExposedPorts(19530);
  73. @Override
  74. protected ConnectParam.Builder connectParamBuilder() {
  75. return connectParamBuilder(milvusContainer);
  76. }
  77. @org.junit.jupiter.api.Test
  78. void loadBalancing() {
  79. NameResolverProvider testNameResolverProvider = new StaticNameResolverProvider(
  80. new InetSocketAddress(milvusContainer.getHost(), milvusContainer.getFirstMappedPort()),
  81. new InetSocketAddress(milvusContainer2.getHost(), milvusContainer2.getFirstMappedPort()));
  82. NameResolverRegistry.getDefaultRegistry().register(testNameResolverProvider);
  83. ConnectParam connectParam = connectParamBuilder()
  84. .withTarget(testNameResolverProvider.getDefaultScheme() + ":///test")
  85. .build();
  86. MilvusClient loadBalancingClient = new MilvusGrpcClient(connectParam);
  87. assertEquals(50, IntStream.range(0, 100)
  88. .filter(i -> loadBalancingClient.hasCollection(randomCollectionName).hasCollection())
  89. .count());
  90. }
  91. }
  92. @Testcontainers
  93. @DisabledIfSystemProperty(named = "with-containers", matches = "true")
  94. class MilvusClientTest {
  95. private MilvusClient client;
  96. private RandomStringGenerator generator;
  97. protected String randomCollectionName;
  98. private int size;
  99. private int dimension;
  100. protected ConnectParam.Builder connectParamBuilder() {
  101. return connectParamBuilder("localhost", 19530);
  102. }
  103. protected ConnectParam.Builder connectParamBuilder(GenericContainer milvusContainer) {
  104. return connectParamBuilder(milvusContainer.getHost(), milvusContainer.getFirstMappedPort());
  105. }
  106. protected ConnectParam.Builder connectParamBuilder(String host, int port) {
  107. return new ConnectParam.Builder().withHost(host).withPort(port);
  108. }
  109. protected void assertErrorCode(ErrorCode errorCode, Runnable runnable) {
  110. assertEquals(errorCode, assertThrows(ServerSideMilvusException.class, runnable::run).getErrorCode());
  111. }
  112. // Helper function that generates random float vectors
  113. static List<List<Float>> generateFloatVectors(int vectorCount, int dimension) {
  114. SplittableRandom splittableRandom = new SplittableRandom();
  115. List<List<Float>> vectors = new ArrayList<>(vectorCount);
  116. for (int i = 0; i < vectorCount; ++i) {
  117. splittableRandom = splittableRandom.split();
  118. DoubleStream doubleStream = splittableRandom.doubles(dimension);
  119. List<Float> vector =
  120. doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
  121. vectors.add(vector);
  122. }
  123. return vectors;
  124. }
  125. // Helper function that generates random binary vectors
  126. static List<List<Byte>> generateBinaryVectors(int vectorCount, int dimension) {
  127. Random random = new Random();
  128. List<List<Byte>> vectors = new ArrayList<>(vectorCount);
  129. final int dimensionInByte = dimension / 8;
  130. for (int i = 0; i < vectorCount; ++i) {
  131. ByteBuffer byteBuffer = ByteBuffer.allocate(dimensionInByte);
  132. random.nextBytes(byteBuffer.array());
  133. byte[] b = new byte[byteBuffer.remaining()];
  134. byteBuffer.get(b);
  135. vectors.add(Arrays.asList(ArrayUtils.toObject(b)));
  136. }
  137. return vectors;
  138. }
  139. // Helper function that normalizes a vector if you are using IP (Inner Product) as your metric
  140. // type
  141. static List<Float> normalizeVector(List<Float> vector) {
  142. float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
  143. final float norm = (float) Math.sqrt(squareSum);
  144. vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
  145. return vector;
  146. }
  147. // Helper function that generate a simple DSL statement with vector filtering only
  148. static String generateSimpleDSL(Long topK, String query) {
  149. return String.format(
  150. "{\"bool\": {"
  151. + "\"must\": [{"
  152. + " \"vector\": {"
  153. + " \"float_vec\": {"
  154. + " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
  155. + " }}}]}}", topK, query);
  156. }
  157. // Helper function that generate a complex DSL statement with scalar field filtering
  158. static String generateComplexDSL(Long topK, String query) {
  159. return String.format(
  160. "{\"bool\": {"
  161. + "\"must\": [{"
  162. + " \"must\": [{"
  163. + " \"vector\": {"
  164. + " \"float_vec\": {"
  165. + " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
  166. + "}}}]}]}}",
  167. topK, query);
  168. }
  169. // Helper function that generate a complex DSL statement with scalar field filtering
  170. static String generateComplexDSLBinary(Long topK, String query) {
  171. return String.format(
  172. "{\"bool\": {"
  173. + "\"must\": [{"
  174. + " \"vector\": {"
  175. + " \"binary_vec\": {"
  176. + " \"topk\": %d, \"metric_type\": \"JACCARD\", \"type\": \"binary\", \"query\": %s, \"params\": {\"nprobe\": 20}"
  177. + " }}}]}}",
  178. topK, query);
  179. }
  180. @org.junit.jupiter.api.BeforeEach
  181. void setUp() throws Exception {
  182. ConnectParam connectParam = connectParamBuilder().build();
  183. client = new MilvusGrpcClient(connectParam);
  184. generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
  185. randomCollectionName = generator.generate(10);
  186. size = 100000;
  187. dimension = 128;
  188. CollectionMapping collectionMapping = CollectionMapping
  189. .create(randomCollectionName)
  190. .addField("int64", DataType.INT64)
  191. .addField("float", DataType.FLOAT)
  192. .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension)
  193. .setParamsInJson(new JsonBuilder()
  194. .param("segment_row_limit", 50000)
  195. .param("auto_id", false)
  196. .build());
  197. client.createCollection(collectionMapping);
  198. }
  199. @org.junit.jupiter.api.AfterEach
  200. void tearDown() {
  201. client.dropCollection(randomCollectionName);
  202. client.close();
  203. }
  204. @org.junit.jupiter.api.Test
  205. void idleTest() throws InterruptedException {
  206. ConnectParam connectParam = connectParamBuilder()
  207. .withIdleTimeout(1, TimeUnit.SECONDS)
  208. .build();
  209. MilvusClient client = new MilvusGrpcClient(connectParam);
  210. TimeUnit.SECONDS.sleep(2);
  211. // A new RPC would take the channel out of idle mode
  212. assertTrue(client.listCollections().ok());
  213. }
  214. @org.junit.jupiter.api.Test
  215. void setInvalidConnectParam() {
  216. assertThrows(
  217. IllegalArgumentException.class,
  218. () -> {
  219. ConnectParam connectParam = new ConnectParam.Builder().withPort(66666).build();
  220. });
  221. assertThrows(
  222. IllegalArgumentException.class,
  223. () -> {
  224. ConnectParam connectParam =
  225. new ConnectParam.Builder().withConnectTimeout(-1, TimeUnit.MILLISECONDS).build();
  226. });
  227. assertThrows(
  228. IllegalArgumentException.class,
  229. () -> {
  230. ConnectParam connectParam =
  231. new ConnectParam.Builder().withKeepAliveTime(-1, TimeUnit.MILLISECONDS).build();
  232. });
  233. assertThrows(
  234. IllegalArgumentException.class,
  235. () -> {
  236. ConnectParam connectParam =
  237. new ConnectParam.Builder().withKeepAliveTimeout(-1, TimeUnit.MILLISECONDS).build();
  238. });
  239. assertThrows(
  240. IllegalArgumentException.class,
  241. () -> {
  242. ConnectParam connectParam =
  243. new ConnectParam.Builder().withIdleTimeout(-1, TimeUnit.MILLISECONDS).build();
  244. });
  245. }
  246. @org.junit.jupiter.api.Test
  247. void connectUnreachableHost() {
  248. ConnectParam connectParam = connectParamBuilder("250.250.250.250", 19530).build();
  249. assertThrows(InitializationException.class, () -> new MilvusGrpcClient(connectParam));
  250. }
  251. @org.junit.jupiter.api.Test
  252. void unsupportedServerVersion() {
  253. GenericContainer unsupportedMilvusContainer =
  254. new GenericContainer("milvusdb/milvus:0.9.1-cpu-d052920-e04ed5")
  255. .withExposedPorts(19530);
  256. try {
  257. unsupportedMilvusContainer.start();
  258. ConnectParam connectParam = connectParamBuilder(unsupportedMilvusContainer).build();
  259. assertThrows(UnsupportedServerVersion.class, () -> new MilvusGrpcClient(connectParam));
  260. } finally {
  261. unsupportedMilvusContainer.stop();
  262. }
  263. }
  264. @org.junit.jupiter.api.Test
  265. void grpcTimeout() {
  266. insert();
  267. MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
  268. Response response = timeoutClient.createIndex(
  269. new Index.Builder(randomCollectionName, "float_vec")
  270. .withParamsInJson(new JsonBuilder()
  271. .param("index_type", "IVF_FLAT")
  272. .param("metric_type", "L2")
  273. .indexParam("nlist", 2048)
  274. .build())
  275. .build());
  276. assertEquals(Response.Status.RPC_ERROR, response.getStatus());
  277. }
  278. @org.junit.jupiter.api.Test
  279. void createInvalidCollection() {
  280. // invalid collection name
  281. CollectionMapping invalidCollectionName = CollectionMapping
  282. .create("╯°□°)╯")
  283. .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension);
  284. assertErrorCode(ErrorCode.ILLEGAL_COLLECTION_NAME, () -> client.createCollection(invalidCollectionName));
  285. // invalid field
  286. CollectionMapping withoutField = CollectionMapping.create("validCollectionName");
  287. assertThrows(ClientSideMilvusException.class, () -> client.createCollection(withoutField));
  288. // invalid segment_row_count
  289. CollectionMapping invalidSegmentRowCount = CollectionMapping
  290. .create("validCollectionName")
  291. .addField("int64", DataType.INT64)
  292. .addField("float", DataType.FLOAT)
  293. .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension)
  294. .setParamsInJson(new JsonBuilder().param("segment_row_limit", -1000).build());
  295. assertErrorCode(ErrorCode.ILLEGAL_ARGUMENT, () -> client.createCollection(invalidSegmentRowCount));
  296. }
  297. @org.junit.jupiter.api.Test
  298. void hasCollection() {
  299. assertTrue(client.hasCollection(randomCollectionName));
  300. }
  301. @org.junit.jupiter.api.Test
  302. void dropCollection() {
  303. String nonExistingCollectionName = generator.generate(10);
  304. Response dropCollectionResponse = client.dropCollection(nonExistingCollectionName);
  305. assertFalse(dropCollectionResponse.ok());
  306. assertEquals(Response.Status.COLLECTION_NOT_EXISTS, dropCollectionResponse.getStatus());
  307. }
  308. @org.junit.jupiter.api.Test
  309. @SuppressWarnings("unchecked")
  310. void partitionTest() {
  311. final String tag1 = "tag1";
  312. Response createPartitionResponse = client.createPartition(randomCollectionName, tag1);
  313. assertTrue(createPartitionResponse.ok());
  314. final String tag2 = "tag2";
  315. createPartitionResponse = client.createPartition(randomCollectionName, tag2);
  316. assertTrue(createPartitionResponse.ok());
  317. ListPartitionsResponse listPartitionsResponse = client.listPartitions(randomCollectionName);
  318. assertTrue(listPartitionsResponse.ok());
  319. assertEquals(3, listPartitionsResponse.getPartitionList().size()); // two tags plus _default
  320. List<Long> intValues = new ArrayList<>(size);
  321. List<Float> floatValues = new ArrayList<>(size);
  322. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  323. for (int i = 0; i < size; i++) {
  324. intValues.add((long) i);
  325. floatValues.add((float) i);
  326. }
  327. List<Long> entityIds1 = LongStream.range(0, size).boxed().collect(Collectors.toList());
  328. InsertParam insertParam =
  329. new Builder(randomCollectionName)
  330. .field(new FieldBuilder("int64", DataType.INT64)
  331. .values(intValues)
  332. .build())
  333. .field(new FieldBuilder("float", DataType.FLOAT)
  334. .values(floatValues)
  335. .build())
  336. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  337. .values(vectors)
  338. .build())
  339. .withEntityIds(entityIds1)
  340. .withPartitionTag(tag1)
  341. .build();
  342. InsertResponse insertResponse = client.insert(insertParam);
  343. assertTrue(insertResponse.ok());
  344. List<Long> entityIds2 = LongStream.range(size, size * 2).boxed().collect(Collectors.toList());
  345. insertParam =
  346. new Builder(randomCollectionName)
  347. .field(new FieldBuilder("int64", DataType.INT64)
  348. .values(intValues)
  349. .build())
  350. .field(new FieldBuilder("float", DataType.FLOAT)
  351. .values(floatValues)
  352. .build())
  353. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  354. .values(vectors)
  355. .build())
  356. .withEntityIds(entityIds2)
  357. .withPartitionTag(tag2)
  358. .build();
  359. insertResponse = client.insert(insertParam);
  360. assertTrue(insertResponse.ok());
  361. assertTrue(client.flush(randomCollectionName).ok());
  362. assertEquals(size * 2,
  363. client.countEntities(randomCollectionName).getCollectionEntityCount());
  364. final int searchSize = 1;
  365. final long topK = 10;
  366. List<List<Float>> vectorsToSearch1 = vectors.subList(0, searchSize);
  367. List<String> partitionTags1 = new ArrayList<>();
  368. partitionTags1.add(tag1);
  369. SearchParam searchParam1 =
  370. new SearchParam.Builder(randomCollectionName)
  371. .withDSL(generateSimpleDSL(topK, vectorsToSearch1.toString()))
  372. .withPartitionTags(partitionTags1)
  373. .build();
  374. SearchResponse searchResponse1 = client.search(searchParam1);
  375. assertTrue(searchResponse1.ok());
  376. List<List<Long>> resultIdsList1 = searchResponse1.getResultIdsList();
  377. assertEquals(searchSize, resultIdsList1.size());
  378. assertTrue(entityIds1.containsAll(resultIdsList1.get(0)));
  379. List<List<Float>> vectorsToSearch2 = vectors.subList(0, searchSize);
  380. List<String> partitionTags2 = new ArrayList<>();
  381. partitionTags2.add(tag2);
  382. SearchParam searchParam2 =
  383. new SearchParam.Builder(randomCollectionName)
  384. .withDSL(generateSimpleDSL(topK, vectorsToSearch2.toString()))
  385. .withPartitionTags(partitionTags2)
  386. .build();
  387. SearchResponse searchResponse2 = client.search(searchParam2);
  388. assertTrue(searchResponse2.ok());
  389. List<List<Long>> resultIdsList2 = searchResponse2.getResultIdsList();
  390. assertEquals(searchSize, resultIdsList2.size());
  391. assertTrue(entityIds2.containsAll(resultIdsList2.get(0)));
  392. assertTrue(Collections.disjoint(resultIdsList1, resultIdsList2));
  393. HasPartitionResponse testHasPartition = client.hasPartition(randomCollectionName, tag1);
  394. assertTrue(testHasPartition.hasPartition());
  395. Response dropPartitionResponse = client.dropPartition(randomCollectionName, tag1);
  396. assertTrue(dropPartitionResponse.ok());
  397. testHasPartition = client.hasPartition(randomCollectionName, tag1);
  398. assertFalse(testHasPartition.hasPartition());
  399. dropPartitionResponse = client.dropPartition(randomCollectionName, tag2);
  400. assertTrue(dropPartitionResponse.ok());
  401. }
  402. @org.junit.jupiter.api.Test
  403. void createIndex() {
  404. insert();
  405. assertTrue(client.flush(randomCollectionName).ok());
  406. Index index =
  407. new Index.Builder(randomCollectionName, "float_vec")
  408. .withParamsInJson(new JsonBuilder().param("index_type", "IVF_SQ8")
  409. .param("metric_type", "L2")
  410. .indexParam("nlist", 2048)
  411. .build())
  412. .build();
  413. Response createIndexResponse = client.createIndex(index);
  414. assertTrue(createIndexResponse.ok());
  415. // also test drop index here
  416. Response dropIndexResponse = client.dropIndex(randomCollectionName, "float_vec");
  417. assertTrue(dropIndexResponse.ok());
  418. }
  419. @org.junit.jupiter.api.Test
  420. void createIndexAsync() throws ExecutionException, InterruptedException {
  421. insert();
  422. assertTrue(client.flush(randomCollectionName).ok());
  423. Index index =
  424. new Index.Builder(randomCollectionName, "float_vec")
  425. .withParamsInJson(new JsonBuilder().param("index_type", "IVF_SQ8")
  426. .param("metric_type", "L2")
  427. .indexParam("nlist", 2048)
  428. .build())
  429. .build();
  430. ListenableFuture<Response> createIndexResponseFuture = client.createIndexAsync(index);
  431. Futures.addCallback(
  432. createIndexResponseFuture,
  433. new FutureCallback<Response>() {
  434. @Override
  435. public void onSuccess(@NullableDecl Response createIndexResponse) {
  436. assert createIndexResponse != null;
  437. assertTrue(createIndexResponse.ok());
  438. }
  439. @Override
  440. public void onFailure(Throwable t) {
  441. System.out.println(t.getMessage());
  442. }
  443. }, MoreExecutors.directExecutor()
  444. );
  445. }
  446. @org.junit.jupiter.api.Test
  447. void insert() {
  448. List<Long> intValues = new ArrayList<>(size);
  449. List<Float> floatValues = new ArrayList<>(size);
  450. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  451. for (int i = 0; i < size; i++) {
  452. intValues.add((long) i);
  453. floatValues.add((float) i);
  454. }
  455. List<Long> entityIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  456. InsertParam insertParam =
  457. new Builder(randomCollectionName)
  458. .field(new FieldBuilder("int64", DataType.INT64)
  459. .values(intValues)
  460. .build())
  461. .field(new FieldBuilder("float", DataType.FLOAT)
  462. .values(floatValues)
  463. .build())
  464. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  465. .values(vectors)
  466. .build())
  467. .withEntityIds(entityIds)
  468. .build();
  469. InsertResponse insertResponse = client.insert(insertParam);
  470. assertTrue(insertResponse.ok());
  471. assertEquals(size, insertResponse.getEntityIds().size());
  472. }
  473. @org.junit.jupiter.api.Test
  474. void insertAsync() throws ExecutionException, InterruptedException {
  475. List<Long> intValues = new ArrayList<>(size);
  476. List<Float> floatValues = new ArrayList<>(size);
  477. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  478. for (int i = 0; i < size; i++) {
  479. intValues.add((long) i);
  480. floatValues.add((float) i);
  481. }
  482. List<Long> entityIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  483. InsertParam insertParam =
  484. new Builder(randomCollectionName)
  485. .field(new FieldBuilder("int64", DataType.INT64)
  486. .values(intValues)
  487. .build())
  488. .field(new FieldBuilder("float", DataType.FLOAT)
  489. .values(floatValues)
  490. .build())
  491. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  492. .values(vectors)
  493. .build())
  494. .withEntityIds(entityIds)
  495. .build();
  496. ListenableFuture<InsertResponse> insertResponseFuture = client.insertAsync(insertParam);
  497. Futures.addCallback(
  498. insertResponseFuture,
  499. new FutureCallback<InsertResponse>() {
  500. @Override
  501. public void onSuccess(@NullableDecl InsertResponse insertResponse) {
  502. assert insertResponse != null;
  503. assertTrue(insertResponse.ok());
  504. assertEquals(size, insertResponse.getEntityIds().size());
  505. }
  506. @Override
  507. public void onFailure(Throwable t) {
  508. System.out.println(t.getMessage());
  509. }
  510. }, MoreExecutors.directExecutor()
  511. );
  512. }
  513. @org.junit.jupiter.api.Test
  514. void insertBinary() {
  515. final int binaryDimension = 10000;
  516. String binaryCollectionName = generator.generate(10);
  517. CollectionMapping collectionMapping = CollectionMapping
  518. .create(binaryCollectionName)
  519. .addVectorField("binary_vec", DataType.VECTOR_BINARY, binaryDimension);
  520. client.createCollection(collectionMapping);
  521. List<List<Byte>> vectors = generateBinaryVectors(size, binaryDimension);
  522. InsertParam insertParam =
  523. new Builder(binaryCollectionName)
  524. .field(new FieldBuilder("binary_vec", DataType.VECTOR_BINARY)
  525. .values(vectors)
  526. .build())
  527. .build();
  528. InsertResponse insertResponse = client.insert(insertParam);
  529. assertTrue(insertResponse.ok());
  530. assertEquals(size, insertResponse.getEntityIds().size());
  531. Index index =
  532. new Index.Builder(binaryCollectionName, "binary_vec")
  533. .withParamsInJson(new JsonBuilder().param("index_type", "BIN_IVF_FLAT")
  534. .param("metric_type", "JACCARD")
  535. .indexParam("nlist", 100)
  536. .build())
  537. .build();
  538. Response createIndexResponse = client.createIndex(index);
  539. assertTrue(createIndexResponse.ok());
  540. // also test drop index here
  541. Response dropIndexResponse = client.dropIndex(binaryCollectionName, "binary_vec");
  542. assertTrue(dropIndexResponse.ok());
  543. assertTrue(client.dropCollection(binaryCollectionName).ok());
  544. }
  545. @org.junit.jupiter.api.Test
  546. void search() {
  547. List<Long> intValues = new ArrayList<>(size);
  548. List<Float> floatValues = new ArrayList<>(size);
  549. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  550. for (int i = 0; i < size; i++) {
  551. intValues.add((long) i);
  552. floatValues.add((float) i);
  553. }
  554. vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
  555. List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  556. InsertParam insertParam =
  557. new Builder(randomCollectionName)
  558. .field(new FieldBuilder("int64", DataType.INT64)
  559. .values(intValues)
  560. .build())
  561. .field(new FieldBuilder("float", DataType.FLOAT)
  562. .values(floatValues)
  563. .build())
  564. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  565. .values(vectors)
  566. .build())
  567. .withEntityIds(insertIds)
  568. .build();
  569. InsertResponse insertResponse = client.insert(insertParam);
  570. assertTrue(insertResponse.ok());
  571. List<Long> entityIds = insertResponse.getEntityIds();
  572. assertEquals(size, entityIds.size());
  573. assertTrue(client.flush(randomCollectionName).ok());
  574. final int searchSize = 5;
  575. List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
  576. final long topK = 10;
  577. SearchParam searchParam =
  578. new SearchParam.Builder(randomCollectionName)
  579. .withDSL(generateComplexDSL(topK, vectorsToSearch.toString()))
  580. .withParamsInJson(new JsonBuilder().param("fields",
  581. new ArrayList<>(Arrays.asList("int64", "float_vec"))).build())
  582. .build();
  583. SearchResponse searchResponse = client.search(searchParam);
  584. assertTrue(searchResponse.ok());
  585. List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
  586. assertEquals(searchSize, resultIdsList.size());
  587. List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
  588. assertEquals(searchSize, resultDistancesList.size());
  589. List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
  590. assertEquals(searchSize, queryResultsList.size());
  591. final double epsilon = 0.001;
  592. for (int i = 0; i < searchSize; i++) {
  593. SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
  594. assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
  595. assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
  596. assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
  597. assertTrue(Math.abs(resultDistancesList.get(i).get(0)) < epsilon);
  598. }
  599. }
  600. @org.junit.jupiter.api.Test
  601. void searchAsync() throws ExecutionException, InterruptedException {
  602. List<Long> intValues = new ArrayList<>(size);
  603. List<Float> floatValues = new ArrayList<>(size);
  604. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  605. for (int i = 0; i < size; i++) {
  606. intValues.add((long) i);
  607. floatValues.add((float) i);
  608. }
  609. vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
  610. List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  611. InsertParam insertParam =
  612. new Builder(randomCollectionName)
  613. .field(new FieldBuilder("int64", DataType.INT64)
  614. .values(intValues)
  615. .build())
  616. .field(new FieldBuilder("float", DataType.FLOAT)
  617. .values(floatValues)
  618. .build())
  619. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  620. .values(vectors)
  621. .build())
  622. .withEntityIds(insertIds)
  623. .build();
  624. InsertResponse insertResponse = client.insert(insertParam);
  625. assertTrue(insertResponse.ok());
  626. List<Long> entityIds = insertResponse.getEntityIds();
  627. assertEquals(size, entityIds.size());
  628. assertTrue(client.flush(randomCollectionName).ok());
  629. final int searchSize = 5;
  630. List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
  631. final long topK = 10;
  632. SearchParam searchParam =
  633. new SearchParam.Builder(randomCollectionName)
  634. .withDSL(generateComplexDSL(topK, vectorsToSearch.toString()))
  635. .withParamsInJson(new JsonBuilder().param("fields",
  636. new ArrayList<>(Arrays.asList("int64", "float"))).build())
  637. .build();
  638. ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
  639. SearchResponse searchResponse = searchResponseFuture.get();
  640. assertTrue(searchResponse.ok());
  641. List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
  642. assertEquals(searchSize, resultIdsList.size());
  643. List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
  644. assertEquals(searchSize, resultDistancesList.size());
  645. List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
  646. assertEquals(searchSize, queryResultsList.size());
  647. final double epsilon = 0.001;
  648. for (int i = 0; i < searchSize; i++) {
  649. SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
  650. assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
  651. assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
  652. assertTrue(Math.abs(firstQueryResult.getDistance()) < epsilon);
  653. assertTrue(Math.abs(resultDistancesList.get(i).get(0)) < epsilon);
  654. }
  655. }
  656. @org.junit.jupiter.api.Test
  657. void searchBinary() {
  658. final int binaryDimension = 64;
  659. String binaryCollectionName = generator.generate(10);
  660. CollectionMapping collectionMapping = CollectionMapping
  661. .create(binaryCollectionName)
  662. .addField("int64", DataType.INT64)
  663. .addField("float", DataType.FLOAT)
  664. .addVectorField("binary_vec", DataType.VECTOR_BINARY, binaryDimension);
  665. client.createCollection(collectionMapping);
  666. // field list for insert
  667. List<Long> intValues = new ArrayList<>(size);
  668. List<Float> floatValues = new ArrayList<>(size);
  669. for (int i = 0; i < size; i++) {
  670. intValues.add((long) i);
  671. floatValues.add((float) i);
  672. }
  673. List<List<Byte>> vectors = generateBinaryVectors(size, binaryDimension);
  674. InsertParam insertParam =
  675. new Builder(binaryCollectionName)
  676. .field(new FieldBuilder("int64", DataType.INT64)
  677. .values(intValues)
  678. .build())
  679. .field(new FieldBuilder("float", DataType.FLOAT)
  680. .values(floatValues)
  681. .build())
  682. .field(new FieldBuilder("binary_vec", DataType.VECTOR_BINARY)
  683. .values(vectors)
  684. .build())
  685. .build();
  686. InsertResponse insertResponse = client.insert(insertParam);
  687. assertTrue(insertResponse.ok());
  688. List<Long> entityIds = insertResponse.getEntityIds();
  689. assertEquals(size, entityIds.size());
  690. assertTrue(client.flush(binaryCollectionName).ok());
  691. final int searchSize = 5;
  692. List<List<Byte>> vectorsToSearch = vectors.subList(0, searchSize);
  693. final long topK = 10;
  694. SearchParam searchParam =
  695. new SearchParam.Builder(binaryCollectionName)
  696. .withDSL(generateComplexDSLBinary(topK, vectorsToSearch.toString()))
  697. .build();
  698. SearchResponse searchResponse = client.search(searchParam);
  699. assertTrue(searchResponse.ok());
  700. List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
  701. assertEquals(searchSize, resultIdsList.size());
  702. List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
  703. assertEquals(searchSize, resultDistancesList.size());
  704. List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
  705. assertEquals(searchSize, queryResultsList.size());
  706. for (int i = 0; i < searchSize; i++) {
  707. SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
  708. assertEquals(entityIds.get(i), firstQueryResult.getEntityId());
  709. assertEquals(entityIds.get(i), resultIdsList.get(i).get(0));
  710. }
  711. assertTrue(client.dropCollection(binaryCollectionName).ok());
  712. }
  713. @org.junit.jupiter.api.Test
  714. void getCollectionInfo() {
  715. GetCollectionInfoResponse getCollectionInfoResponse =
  716. client.getCollectionInfo(randomCollectionName);
  717. assertTrue(getCollectionInfoResponse.ok());
  718. assertTrue(getCollectionInfoResponse.getCollectionMapping().isPresent());
  719. assertEquals(
  720. getCollectionInfoResponse.getCollectionMapping().get().getCollectionName(),
  721. randomCollectionName);
  722. List<? extends Map<String, Object>> fields = getCollectionInfoResponse.getCollectionMapping()
  723. .get().getFields();
  724. for (Map<String, Object> field : fields) {
  725. if (field.get("field").equals("float_vec")) {
  726. JSONObject params = new JSONObject(field.get("params").toString());
  727. assertTrue(params.has("dim"));
  728. }
  729. }
  730. String nonExistingCollectionName = generator.generate(10);
  731. getCollectionInfoResponse = client.getCollectionInfo(nonExistingCollectionName);
  732. assertFalse(getCollectionInfoResponse.ok());
  733. assertFalse(getCollectionInfoResponse.getCollectionMapping().isPresent());
  734. }
  735. @org.junit.jupiter.api.Test
  736. void listCollections() {
  737. ListCollectionsResponse listCollectionsResponse = client.listCollections();
  738. assertTrue(listCollectionsResponse.ok());
  739. assertTrue(listCollectionsResponse.getCollectionNames().contains(randomCollectionName));
  740. }
  741. @org.junit.jupiter.api.Test
  742. void serverStatus() {
  743. Response serverStatusResponse = client.getServerStatus();
  744. assertTrue(serverStatusResponse.ok());
  745. }
  746. @org.junit.jupiter.api.Test
  747. void serverVersion() {
  748. Response serverVersionResponse = client.getServerVersion();
  749. assertTrue(serverVersionResponse.ok());
  750. }
  751. @org.junit.jupiter.api.Test
  752. void countEntities() {
  753. insert();
  754. assertTrue(client.flush(randomCollectionName).ok());
  755. CountEntitiesResponse countEntitiesResponse = client.countEntities(randomCollectionName);
  756. assertTrue(countEntitiesResponse.ok());
  757. assertEquals(size, countEntitiesResponse.getCollectionEntityCount());
  758. }
  759. @org.junit.jupiter.api.Test
  760. void loadCollection() {
  761. insert();
  762. assertTrue(client.flush(randomCollectionName).ok());
  763. Response loadCollectionResponse = client.loadCollection(randomCollectionName);
  764. assertTrue(loadCollectionResponse.ok());
  765. }
  766. @org.junit.jupiter.api.Test
  767. void getCollectionStats() {
  768. insert();
  769. assertTrue(client.flush(randomCollectionName).ok());
  770. Response getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
  771. assertTrue(getCollectionStatsResponse.ok());
  772. String jsonString = getCollectionStatsResponse.getMessage();
  773. JSONObject jsonInfo = new JSONObject(jsonString);
  774. assertEquals(jsonInfo.getInt("row_count"), size);
  775. JSONArray partitions = jsonInfo.getJSONArray("partitions");
  776. JSONObject partitionInfo = partitions.getJSONObject(0);
  777. assertEquals(partitionInfo.getString("tag"), "_default");
  778. assertEquals(partitionInfo.getInt("row_count"), size);
  779. }
  780. @org.junit.jupiter.api.Test
  781. @SuppressWarnings("unchecked")
  782. void getEntityByID() {
  783. List<Long> intValues = new ArrayList<>(size);
  784. List<Float> floatValues = new ArrayList<>(size);
  785. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  786. for (int i = 0; i < size; i++) {
  787. intValues.add((long) i);
  788. floatValues.add((float) i);
  789. }
  790. vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
  791. List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  792. InsertParam insertParam =
  793. new Builder(randomCollectionName)
  794. .field(new FieldBuilder("int64", DataType.INT64)
  795. .values(intValues)
  796. .build())
  797. .field(new FieldBuilder("float", DataType.FLOAT)
  798. .values(floatValues)
  799. .build())
  800. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  801. .values(vectors)
  802. .build())
  803. .withEntityIds(insertIds)
  804. .build();
  805. InsertResponse insertResponse = client.insert(insertParam);
  806. assertTrue(insertResponse.ok());
  807. List<Long> entityIds = insertResponse.getEntityIds();
  808. assertEquals(size, entityIds.size());
  809. assertTrue(client.flush(randomCollectionName).ok());
  810. GetEntityByIDResponse getEntityByIDResponse =
  811. client.getEntityByID(randomCollectionName, entityIds.subList(0, 100));
  812. assertTrue(getEntityByIDResponse.ok());
  813. int vecIndex = 0;
  814. List<Map<String, Object>> fieldsMap = getEntityByIDResponse.getFieldsMap();
  815. assertTrue(fieldsMap.get(vecIndex).get("float_vec") instanceof List);
  816. List<Float> first = (List<Float>) (fieldsMap.get(vecIndex).get("float_vec"));
  817. assertArrayEquals(first.toArray(), vectors.get(0).toArray());
  818. }
  819. @org.junit.jupiter.api.Test
  820. @SuppressWarnings("unchecked")
  821. void getEntityByIDBinary() {
  822. final int binaryDimension = 64;
  823. String binaryCollectionName = generator.generate(10);
  824. CollectionMapping collectionMapping = CollectionMapping
  825. .create(binaryCollectionName)
  826. .addVectorField("binary_vec", DataType.VECTOR_BINARY, binaryDimension)
  827. .setParamsInJson(new JsonBuilder().param("auto_id", false).build());
  828. client.createCollection(collectionMapping);
  829. List<List<Byte>> vectors = generateBinaryVectors(size, binaryDimension);
  830. List<Long> entityIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  831. InsertParam insertParam =
  832. new Builder(binaryCollectionName)
  833. .field(new FieldBuilder("binary_vec", DataType.VECTOR_BINARY)
  834. .values(vectors)
  835. .build())
  836. .withEntityIds(entityIds)
  837. .build();
  838. InsertResponse insertResponse = client.insert(insertParam);
  839. assertTrue(insertResponse.ok());
  840. assertEquals(size, insertResponse.getEntityIds().size());
  841. assertTrue(client.flush(binaryCollectionName).ok());
  842. GetEntityByIDResponse getEntityByIDResponse =
  843. client.getEntityByID(binaryCollectionName, entityIds.subList(0, 100));
  844. assertTrue(getEntityByIDResponse.ok());
  845. assertEquals(getEntityByIDResponse.getFieldsMap().size(), 100);
  846. List<Map<String, Object>> fieldsMap = getEntityByIDResponse.getFieldsMap();
  847. assertTrue(fieldsMap.get(0).get("binary_vec") instanceof List);
  848. List<Byte> first = (List<Byte>) (fieldsMap.get(0).get("binary_vec"));
  849. assertArrayEquals(first.toArray(), vectors.get(0).toArray());
  850. }
  851. @org.junit.jupiter.api.Test
  852. void getEntityIds() {
  853. insert();
  854. assertTrue(client.flush(randomCollectionName).ok());
  855. Response getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
  856. assertTrue(getCollectionStatsResponse.ok());
  857. JSONObject jsonInfo = new JSONObject(getCollectionStatsResponse.getMessage());
  858. JSONObject segmentInfo =
  859. jsonInfo
  860. .getJSONArray("partitions")
  861. .getJSONObject(0)
  862. .getJSONArray("segments")
  863. .getJSONObject(0);
  864. ListIDInSegmentResponse listIDInSegmentResponse =
  865. client.listIDInSegment(randomCollectionName, segmentInfo.getLong("id"));
  866. assertTrue(listIDInSegmentResponse.ok());
  867. assertFalse(listIDInSegmentResponse.getIds().isEmpty());
  868. }
  869. @org.junit.jupiter.api.Test
  870. void deleteEntityByID() {
  871. List<Long> intValues = new ArrayList<>(size);
  872. List<Float> floatValues = new ArrayList<>(size);
  873. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  874. for (int i = 0; i < size; i++) {
  875. intValues.add((long) i);
  876. floatValues.add((float) i);
  877. }
  878. vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
  879. List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  880. InsertParam insertParam =
  881. new Builder(randomCollectionName)
  882. .field(new FieldBuilder("int64", DataType.INT64)
  883. .values(intValues)
  884. .build())
  885. .field(new FieldBuilder("float", DataType.FLOAT)
  886. .values(floatValues)
  887. .build())
  888. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  889. .values(vectors)
  890. .build())
  891. .withEntityIds(insertIds)
  892. .build();
  893. InsertResponse insertResponse = client.insert(insertParam);
  894. assertTrue(insertResponse.ok());
  895. assertEquals(size, insertResponse.getEntityIds().size());
  896. assertTrue(client.flush(randomCollectionName).ok());
  897. assertTrue(client.deleteEntityByID(randomCollectionName,
  898. insertResponse.getEntityIds().subList(0, 100)).ok());
  899. assertTrue(client.flush(randomCollectionName).ok());
  900. assertEquals(client.countEntities(randomCollectionName).getCollectionEntityCount(), size - 100);
  901. }
  902. @org.junit.jupiter.api.Test
  903. void flush() {
  904. assertTrue(client.flush(randomCollectionName).ok());
  905. }
  906. @org.junit.jupiter.api.Test
  907. void flushAsync() throws ExecutionException, InterruptedException {
  908. assertTrue(client.flushAsync(randomCollectionName).get().ok());
  909. }
  910. @org.junit.jupiter.api.Test
  911. void compact() {
  912. List<Long> intValues = new ArrayList<>(size);
  913. List<Float> floatValues = new ArrayList<>(size);
  914. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  915. for (int i = 0; i < size; i++) {
  916. intValues.add((long) i);
  917. floatValues.add((float) i);
  918. }
  919. vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
  920. List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  921. InsertParam insertParam =
  922. new Builder(randomCollectionName)
  923. .field(new FieldBuilder("int64", DataType.INT64)
  924. .values(intValues)
  925. .build())
  926. .field(new FieldBuilder("float", DataType.FLOAT)
  927. .values(floatValues)
  928. .build())
  929. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  930. .values(vectors)
  931. .build())
  932. .withEntityIds(insertIds)
  933. .build();
  934. InsertResponse insertResponse = client.insert(insertParam);
  935. assertTrue(insertResponse.ok());
  936. assertEquals(size, insertResponse.getEntityIds().size());
  937. assertTrue(client.flush(randomCollectionName).ok());
  938. Response getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
  939. assertTrue(getCollectionStatsResponse.ok());
  940. JSONObject jsonInfo = new JSONObject(getCollectionStatsResponse.getMessage());
  941. long previousSegmentSize =
  942. jsonInfo
  943. .getJSONArray("partitions")
  944. .getJSONObject(0)
  945. .getLong("data_size");
  946. assertTrue(
  947. client.deleteEntityByID(randomCollectionName,
  948. insertResponse.getEntityIds().subList(0, size / 2)).ok());
  949. assertTrue(client.flush(randomCollectionName).ok());
  950. assertTrue(client.compact(
  951. new CompactParam.Builder(randomCollectionName).withThreshold(0.2).build()).ok());
  952. getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
  953. assertTrue(getCollectionStatsResponse.ok());
  954. jsonInfo = new JSONObject(getCollectionStatsResponse.getMessage());
  955. long currentSegmentSize =
  956. jsonInfo
  957. .getJSONArray("partitions")
  958. .getJSONObject(0)
  959. .getLong("data_size");
  960. assertTrue(currentSegmentSize < previousSegmentSize);
  961. }
  962. @org.junit.jupiter.api.Test
  963. void compactAsync() throws ExecutionException, InterruptedException {
  964. List<Long> intValues = new ArrayList<>(size);
  965. List<Float> floatValues = new ArrayList<>(size);
  966. List<List<Float>> vectors = generateFloatVectors(size, dimension);
  967. for (int i = 0; i < size; i++) {
  968. intValues.add((long) i);
  969. floatValues.add((float) i);
  970. }
  971. vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
  972. List<Long> insertIds = LongStream.range(0, size).boxed().collect(Collectors.toList());
  973. InsertParam insertParam =
  974. new Builder(randomCollectionName)
  975. .field(new FieldBuilder("int64", DataType.INT64)
  976. .values(intValues)
  977. .build())
  978. .field(new FieldBuilder("float", DataType.FLOAT)
  979. .values(floatValues)
  980. .build())
  981. .field(new FieldBuilder("float_vec", DataType.VECTOR_FLOAT)
  982. .values(vectors)
  983. .build())
  984. .withEntityIds(insertIds)
  985. .build();
  986. InsertResponse insertResponse = client.insert(insertParam);
  987. assertTrue(insertResponse.ok());
  988. assertEquals(size, insertResponse.getEntityIds().size());
  989. assertTrue(client.flush(randomCollectionName).ok());
  990. Response getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
  991. assertTrue(getCollectionStatsResponse.ok());
  992. JSONObject jsonInfo = new JSONObject(getCollectionStatsResponse.getMessage());
  993. JSONObject segmentInfo =
  994. jsonInfo
  995. .getJSONArray("partitions")
  996. .getJSONObject(0)
  997. .getJSONArray("segments")
  998. .getJSONObject(0);
  999. long previousSegmentSize = segmentInfo.getLong("data_size");
  1000. assertTrue(
  1001. client.deleteEntityByID(randomCollectionName,
  1002. insertResponse.getEntityIds().subList(0, size / 2)).ok());
  1003. assertTrue(client.flush(randomCollectionName).ok());
  1004. assertTrue(client.compactAsync(
  1005. new CompactParam.Builder(randomCollectionName).withThreshold(0.8).build()).get().ok());
  1006. getCollectionStatsResponse = client.getCollectionStats(randomCollectionName);
  1007. assertTrue(getCollectionStatsResponse.ok());
  1008. jsonInfo = new JSONObject(getCollectionStatsResponse.getMessage());
  1009. segmentInfo =
  1010. jsonInfo
  1011. .getJSONArray("partitions")
  1012. .getJSONObject(0)
  1013. .getJSONArray("segments")
  1014. .getJSONObject(0);
  1015. long currentSegmentSize = segmentInfo.getLong("data_size");
  1016. assertFalse(currentSegmentSize < previousSegmentSize); // threshold 0.8 > 0.5, no compact
  1017. }
  1018. }