MilvusGrpcClient.java 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. package io.milvus.client;
  20. import com.google.common.util.concurrent.FutureCallback;
  21. import com.google.common.util.concurrent.Futures;
  22. import com.google.common.util.concurrent.ListenableFuture;
  23. import com.google.common.util.concurrent.MoreExecutors;
  24. import io.grpc.CallOptions;
  25. import io.grpc.Channel;
  26. import io.grpc.ClientCall;
  27. import io.grpc.ClientInterceptor;
  28. import io.grpc.ManagedChannel;
  29. import io.grpc.ManagedChannelBuilder;
  30. import io.grpc.MethodDescriptor;
  31. import io.grpc.StatusRuntimeException;
  32. import io.milvus.client.exception.ClientSideMilvusException;
  33. import io.milvus.client.exception.MilvusException;
  34. import io.milvus.client.exception.ServerSideMilvusException;
  35. import io.milvus.client.exception.UnsupportedServerVersion;
  36. import io.milvus.grpc.*;
  37. import org.apache.commons.lang3.ArrayUtils;
  38. import org.json.JSONObject;
  39. import org.slf4j.Logger;
  40. import org.slf4j.LoggerFactory;
  41. import javax.annotation.Nonnull;
  42. import java.nio.ByteBuffer;
  43. import java.util.ArrayList;
  44. import java.util.Arrays;
  45. import java.util.Collections;
  46. import java.util.HashMap;
  47. import java.util.Iterator;
  48. import java.util.List;
  49. import java.util.Map;
  50. import java.util.concurrent.TimeUnit;
  51. import java.util.function.Function;
  52. import java.util.function.Supplier;
  53. import java.util.stream.Collectors;
  54. /** Actual implementation of interface <code>MilvusClient</code> */
  55. public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
  56. private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
  57. private static final String SUPPORTED_SERVER_VERSION = "0.11";
  58. private final String target;
  59. private final ManagedChannel channel;
  60. private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
  61. private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
  62. public MilvusGrpcClient(ConnectParam connectParam) {
  63. target = connectParam.getTarget();
  64. channel = ManagedChannelBuilder
  65. .forTarget(connectParam.getTarget())
  66. .usePlaintext()
  67. .maxInboundMessageSize(Integer.MAX_VALUE)
  68. .defaultLoadBalancingPolicy(connectParam.getDefaultLoadBalancingPolicy())
  69. .keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
  70. .keepAliveTimeout(connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
  71. .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
  72. .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
  73. .build();
  74. blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
  75. futureStub = MilvusServiceGrpc.newFutureStub(channel);
  76. try {
  77. String serverVersion = getServerVersion();
  78. if (!serverVersion.matches("^" + SUPPORTED_SERVER_VERSION + "(\\..*)?$")) {
  79. throw new UnsupportedServerVersion(connectParam.getTarget(), SUPPORTED_SERVER_VERSION, serverVersion);
  80. }
  81. } catch (Throwable t) {
  82. channel.shutdownNow();
  83. throw t;
  84. }
  85. }
  86. @Override
  87. public String target() {
  88. return target;
  89. }
  90. @Override
  91. protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
  92. return blockingStub;
  93. }
  94. @Override
  95. protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
  96. return futureStub;
  97. }
  98. @Override
  99. protected boolean maybeAvailable() {
  100. switch (channel.getState(false)) {
  101. case IDLE:
  102. case CONNECTING:
  103. case READY:
  104. return true;
  105. default:
  106. return false;
  107. }
  108. }
  109. @Override
  110. public void close(long maxWaitSeconds) {
  111. channel.shutdown();
  112. try {
  113. channel.awaitTermination(maxWaitSeconds, TimeUnit.SECONDS);
  114. } catch (InterruptedException ex) {
  115. logger.warn("Milvus client close interrupted");
  116. channel.shutdownNow();
  117. Thread.currentThread().interrupt();
  118. }
  119. }
  120. public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
  121. final long timeoutMillis = timeoutUnit.toMillis(timeout);
  122. final TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(timeoutMillis);
  123. final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub =
  124. this.blockingStub.withInterceptors(timeoutInterceptor);
  125. final MilvusServiceGrpc.MilvusServiceFutureStub futureStub =
  126. this.futureStub.withInterceptors(timeoutInterceptor);
  127. return new AbstractMilvusGrpcClient() {
  128. @Override
  129. public String target() {
  130. return MilvusGrpcClient.this.target();
  131. }
  132. @Override
  133. protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
  134. return blockingStub;
  135. }
  136. @Override
  137. protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
  138. return futureStub;
  139. }
  140. @Override
  141. protected boolean maybeAvailable() {
  142. return MilvusGrpcClient.this.maybeAvailable();
  143. }
  144. @Override
  145. public void close(long maxWaitSeconds) {
  146. MilvusGrpcClient.this.close(maxWaitSeconds);
  147. }
  148. @Override
  149. public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
  150. return MilvusGrpcClient.this.withTimeout(timeout, timeoutUnit);
  151. }
  152. };
  153. }
  154. private static class TimeoutInterceptor implements ClientInterceptor {
  155. private long timeoutMillis;
  156. TimeoutInterceptor(long timeoutMillis) {
  157. this.timeoutMillis = timeoutMillis;
  158. }
  159. @Override
  160. public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
  161. MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
  162. return next.newCall(method, callOptions.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS));
  163. }
  164. }
  165. }
  166. abstract class AbstractMilvusGrpcClient implements MilvusClient {
  167. private static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
  168. protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();
  169. protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();
  170. protected abstract boolean maybeAvailable();
  171. private void translateExceptions(Runnable body) {
  172. translateExceptions(() -> {
  173. body.run();
  174. return null;
  175. });
  176. }
  177. @SuppressWarnings("unchecked")
  178. private <T> T translateExceptions(Supplier<T> body) {
  179. try {
  180. T result = body.get();
  181. if (result instanceof ListenableFuture) {
  182. ListenableFuture futureResult = (ListenableFuture) result;
  183. result = (T) Futures.catching(
  184. futureResult, Throwable.class, this::translate, MoreExecutors.directExecutor());
  185. }
  186. return result;
  187. } catch (Throwable e) {
  188. return translate(e);
  189. }
  190. }
  191. private <R> R translate(Throwable e) {
  192. if (e instanceof MilvusException) {
  193. throw (MilvusException) e;
  194. } else if (e.getCause() == null || e.getCause() == e) {
  195. throw new ClientSideMilvusException(target(), e);
  196. } else {
  197. return translate(e.getCause());
  198. }
  199. }
  200. private Void checkResponseStatus(Status status) {
  201. if (status.getErrorCode() != ErrorCode.SUCCESS) {
  202. throw new ServerSideMilvusException(target(), status);
  203. }
  204. return null;
  205. }
  206. @Override
  207. public void createCollection(@Nonnull CollectionMapping collectionMapping) {
  208. translateExceptions(() -> {
  209. Status response = blockingStub().createCollection(collectionMapping.grpc());
  210. checkResponseStatus(response);
  211. });
  212. }
  213. @Override
  214. public boolean hasCollection(@Nonnull String collectionName) {
  215. return translateExceptions(() -> {
  216. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  217. BoolReply response = blockingStub().hasCollection(request);
  218. checkResponseStatus(response.getStatus());
  219. return response.getBoolReply();
  220. });
  221. }
  222. @Override
  223. public void dropCollection(@Nonnull String collectionName) {
  224. translateExceptions(() -> {
  225. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  226. Status response = blockingStub().dropCollection(request);
  227. checkResponseStatus(response);
  228. });
  229. }
  230. @Override
  231. public void createIndex(@Nonnull Index index) {
  232. translateExceptions(() -> {
  233. Futures.getUnchecked(createIndexAsync(index));
  234. });
  235. }
  236. @Override
  237. public ListenableFuture<Void> createIndexAsync(@Nonnull Index index) {
  238. return translateExceptions(() -> {
  239. IndexParam request = index.grpc();
  240. ListenableFuture<Status> responseFuture = futureStub().createIndex(request);
  241. return Futures.transform(responseFuture, this::checkResponseStatus, MoreExecutors.directExecutor());
  242. });
  243. }
  244. @Override
  245. public void createPartition(String collectionName, String tag) {
  246. translateExceptions(() -> {
  247. PartitionParam request = PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();
  248. Status response = blockingStub().createPartition(request);
  249. checkResponseStatus(response);
  250. });
  251. }
  252. @Override
  253. public boolean hasPartition(String collectionName, String tag) {
  254. return translateExceptions(() -> {
  255. PartitionParam request = PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();
  256. BoolReply response = blockingStub().hasPartition(request);
  257. checkResponseStatus(response.getStatus());
  258. return response.getBoolReply();
  259. });
  260. }
  261. @Override
  262. public List<String> listPartitions(String collectionName) {
  263. return translateExceptions(() -> {
  264. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  265. PartitionList response = blockingStub().showPartitions(request);
  266. checkResponseStatus(response.getStatus());
  267. return response.getPartitionTagArrayList();
  268. });
  269. }
  270. @Override
  271. public void dropPartition(String collectionName, String tag) {
  272. translateExceptions(() -> {
  273. PartitionParam request =
  274. PartitionParam.newBuilder().setCollectionName(collectionName).setTag(tag).build();
  275. Status response = blockingStub().dropPartition(request);
  276. checkResponseStatus(response);
  277. });
  278. }
  279. @Override
  280. @SuppressWarnings("unchecked")
  281. public List<Long> insert(@Nonnull InsertParam insertParam) {
  282. return translateExceptions(() -> Futures.getUnchecked(insertAsync(insertParam)));
  283. }
  284. @Override
  285. @SuppressWarnings("unchecked")
  286. public ListenableFuture<List<Long>> insertAsync(@Nonnull InsertParam insertParam) {
  287. return translateExceptions(() -> {
  288. io.milvus.grpc.InsertParam request = insertParam.grpc();
  289. ListenableFuture<EntityIds> responseFuture = futureStub().insert(request);
  290. return Futures.transform(responseFuture, entityIds -> {
  291. checkResponseStatus(entityIds.getStatus());
  292. return entityIds.getEntityIdArrayList();
  293. }, MoreExecutors.directExecutor());
  294. });
  295. }
  296. @Override
  297. public SearchResult search(@Nonnull SearchParam searchParam) {
  298. return translateExceptions(() -> Futures.getUnchecked(searchAsync(searchParam)));
  299. }
  300. @Override
  301. public ListenableFuture<SearchResult> searchAsync(@Nonnull SearchParam searchParam) {
  302. return translateExceptions(() -> {
  303. io.milvus.grpc.SearchParam request = searchParam.grpc();
  304. ListenableFuture<QueryResult> responseFuture = futureStub().search(request);
  305. return Futures.transform(responseFuture, queryResult -> {
  306. checkResponseStatus(queryResult.getStatus());
  307. return buildSearchResponse(queryResult);
  308. }, MoreExecutors.directExecutor());
  309. });
  310. }
  311. @Override
  312. public CollectionMapping getCollectionInfo(@Nonnull String collectionName) {
  313. return translateExceptions(() -> {
  314. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  315. Mapping response = blockingStub().describeCollection(request);
  316. checkResponseStatus(response.getStatus());
  317. return new CollectionMapping(response);
  318. });
  319. }
  320. @Override
  321. public List<String> listCollections() {
  322. return translateExceptions(() -> {
  323. Command request = Command.newBuilder().setCmd("").build();
  324. CollectionNameList response = blockingStub().showCollections(request);
  325. checkResponseStatus(response.getStatus());
  326. return response.getCollectionNamesList();
  327. });
  328. }
  329. @Override
  330. public long countEntities(@Nonnull String collectionName) {
  331. return translateExceptions(() -> {
  332. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  333. CollectionRowCount response = blockingStub().countCollection(request);
  334. checkResponseStatus(response.getStatus());
  335. return response.getCollectionRowCount();
  336. });
  337. }
  338. @Override
  339. public String getServerStatus() {
  340. return command("status");
  341. }
  342. @Override
  343. public String getServerVersion() {
  344. return command("version");
  345. }
  346. public String command(@Nonnull String command) {
  347. return translateExceptions(() -> {
  348. Command request = Command.newBuilder().setCmd(command).build();
  349. StringReply response = blockingStub().cmd(request);
  350. checkResponseStatus(response.getStatus());
  351. return response.getStringReply();
  352. });
  353. }
  354. @Override
  355. public void loadCollection(@Nonnull String collectionName) {
  356. translateExceptions(() -> {
  357. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  358. Status response = blockingStub().preloadCollection(request);
  359. checkResponseStatus(response);
  360. });
  361. }
  362. @Override
  363. public void dropIndex(String collectionName, String fieldName) {
  364. translateExceptions(() -> {
  365. IndexParam request = IndexParam.newBuilder()
  366. .setCollectionName(collectionName)
  367. .setFieldName(fieldName)
  368. .build();
  369. Status response = blockingStub().dropIndex(request);
  370. checkResponseStatus(response);
  371. });
  372. }
  373. @Override
  374. public String getCollectionStats(String collectionName) {
  375. return translateExceptions(() -> {
  376. CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
  377. CollectionInfo response = blockingStub().showCollectionInfo(request);
  378. checkResponseStatus(response.getStatus());
  379. return response.getJsonInfo();
  380. });
  381. }
  382. @Override
  383. public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
  384. return translateExceptions(() -> {
  385. EntityIdentity request = EntityIdentity.newBuilder()
  386. .setCollectionName(collectionName)
  387. .addAllIdArray(ids)
  388. .addAllFieldNames(fieldNames)
  389. .build();
  390. Entities response = blockingStub().getEntityByID(request);
  391. checkResponseStatus(response.getStatus());
  392. Map<String, Iterator<?>> fieldIterators = response.getFieldsList()
  393. .stream()
  394. .collect(Collectors.toMap(FieldValue::getFieldName, this::fieldValueIterator));
  395. return response.getValidRowList().stream()
  396. .map(valid -> valid ? toMap(fieldIterators) : Collections.<String, Object>emptyMap())
  397. .collect(Collectors.toList());
  398. });
  399. }
  400. private Map<String, Object> toMap(Map<String, Iterator<?>> fieldIterators) {
  401. return fieldIterators.entrySet().stream()
  402. .collect(Collectors.toMap(
  403. entry -> entry.getKey(),
  404. entry -> entry.getValue().next()));
  405. }
  406. private Iterator<?> fieldValueIterator(FieldValue fieldValue) {
  407. if (fieldValue.hasAttrRecord()) {
  408. AttrRecord record = fieldValue.getAttrRecord();
  409. if (record.getInt32ValueCount() > 0) {
  410. return record.getInt32ValueList().iterator();
  411. } else if (record.getInt64ValueCount() > 0) {
  412. return record.getInt64ValueList().iterator();
  413. } else if (record.getFloatValueCount() > 0) {
  414. return record.getFloatValueList().iterator();
  415. } else if (record.getDoubleValueCount() > 0) {
  416. return record.getDoubleValueList().iterator();
  417. }
  418. }
  419. VectorRecord record = fieldValue.getVectorRecord();
  420. return record.getRecordsList().stream()
  421. .map(row -> row.getFloatDataCount() > 0 ? row.getFloatDataList() : row.getBinaryData().asReadOnlyByteBuffer())
  422. .iterator();
  423. }
  424. @Override
  425. public List<Map<String, Object>> getEntityByID(String collectionName, List<Long> ids) {
  426. return getEntityByID(collectionName, ids, Collections.emptyList());
  427. }
  428. @Override
  429. public List<Long> listIDInSegment(String collectionName, Long segmentId) {
  430. return translateExceptions(() -> {
  431. GetEntityIDsParam request = GetEntityIDsParam.newBuilder()
  432. .setCollectionName(collectionName)
  433. .setSegmentId(segmentId)
  434. .build();
  435. EntityIds response = blockingStub().getEntityIDs(request);
  436. checkResponseStatus(response.getStatus());
  437. return response.getEntityIdArrayList();
  438. });
  439. }
  440. @Override
  441. public void deleteEntityByID(String collectionName, List<Long> ids) {
  442. translateExceptions(() -> {
  443. DeleteByIDParam request = DeleteByIDParam.newBuilder()
  444. .setCollectionName(collectionName)
  445. .addAllIdArray(ids)
  446. .build();
  447. Status response = blockingStub().deleteByID(request);
  448. checkResponseStatus(response);
  449. });
  450. }
  451. @Override
  452. public void flush(List<String> collectionNames) {
  453. translateExceptions(() -> Futures.getUnchecked(flushAsync(collectionNames)));
  454. }
  455. @Override
  456. public ListenableFuture<Void> flushAsync(@Nonnull List<String> collectionNames) {
  457. return translateExceptions(() -> {
  458. FlushParam request = FlushParam.newBuilder().addAllCollectionNameArray(collectionNames).build();
  459. ListenableFuture<Status> response = futureStub().flush(request);
  460. return Futures.transform(response, this::checkResponseStatus, MoreExecutors.directExecutor());
  461. });
  462. }
  463. @Override
  464. public void flush(String collectionName) {
  465. flush(Collections.singletonList(collectionName));
  466. }
  467. @Override
  468. public ListenableFuture<Void> flushAsync(String collectionName) {
  469. return flushAsync(Collections.singletonList(collectionName));
  470. }
  471. @Override
  472. public Response compact(CompactParam compactParam) {
  473. if (!maybeAvailable()) {
  474. logWarning("You are not connected to Milvus server");
  475. return new Response(Response.Status.CLIENT_NOT_CONNECTED);
  476. }
  477. io.milvus.grpc.CompactParam request =
  478. io.milvus.grpc.CompactParam.newBuilder()
  479. .setCollectionName(compactParam.getCollectionName())
  480. .setThreshold(compactParam.getThreshold())
  481. .build();
  482. Status response;
  483. try {
  484. response = blockingStub().compact(request);
  485. if (response.getErrorCode() == ErrorCode.SUCCESS) {
  486. logInfo("Compacted collection `{}` successfully!", compactParam.getCollectionName());
  487. return new Response(Response.Status.SUCCESS);
  488. } else {
  489. logError("Compact collection `{}` failed:\n{}",
  490. compactParam.getCollectionName(), response.toString());
  491. return new Response(
  492. Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
  493. }
  494. } catch (StatusRuntimeException e) {
  495. logError("compact RPC failed:\n{}", e.getStatus().toString());
  496. return new Response(Response.Status.RPC_ERROR, e.toString());
  497. }
  498. }
  499. @Override
  500. public ListenableFuture<Response> compactAsync(@Nonnull CompactParam compactParam) {
  501. if (!maybeAvailable()) {
  502. logWarning("You are not connected to Milvus server");
  503. return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
  504. }
  505. io.milvus.grpc.CompactParam request =
  506. io.milvus.grpc.CompactParam.newBuilder()
  507. .setCollectionName(compactParam.getCollectionName())
  508. .setThreshold(compactParam.getThreshold())
  509. .build();
  510. ListenableFuture<Status> response;
  511. response = futureStub().compact(request);
  512. Futures.addCallback(
  513. response,
  514. new FutureCallback<Status>() {
  515. @Override
  516. public void onSuccess(Status result) {
  517. if (result.getErrorCode() == ErrorCode.SUCCESS) {
  518. logInfo("Compacted collection `{}` successfully!",
  519. compactParam.getCollectionName());
  520. } else {
  521. logError("Compact collection `{}` failed:\n{}",
  522. compactParam.getCollectionName(), result.toString());
  523. }
  524. }
  525. @Override
  526. public void onFailure(Throwable t) {
  527. logError("CompactAsync failed:\n{}", t.getMessage());
  528. }
  529. },
  530. MoreExecutors.directExecutor());
  531. return Futures.transform(
  532. response, transformStatusToResponseFunc::apply, MoreExecutors.directExecutor());
  533. }
  534. ///////////////////// Util Functions/////////////////////
  535. Function<Status, Response> transformStatusToResponseFunc =
  536. status -> {
  537. if (status.getErrorCode() == ErrorCode.SUCCESS) {
  538. return new Response(Response.Status.SUCCESS);
  539. } else {
  540. return new Response(
  541. Response.Status.valueOf(status.getErrorCodeValue()), status.getReason());
  542. }
  543. };
  544. private SearchResult buildSearchResponse(QueryResult topKQueryResult) {
  545. final int numQueries = (int) topKQueryResult.getRowNum();
  546. final int topK = numQueries == 0 ? 0 : topKQueryResult.getDistancesCount() / numQueries;
  547. List<List<Long>> resultIdsList = new ArrayList<>(numQueries);
  548. List<List<Float>> resultDistancesList = new ArrayList<>(numQueries);
  549. List<List<Map<String, Object>>> resultFieldsMap = new ArrayList<>(numQueries);
  550. Entities entities = topKQueryResult.getEntities();
  551. List<Long> queryIdsList = entities.getIdsList();
  552. List<Float> queryDistancesList = topKQueryResult.getDistancesList();
  553. // If fields specified, put it into searchResponse
  554. List<Map<String, Object>> fieldsMap = new ArrayList<>();
  555. for (int i = 0; i < queryIdsList.size(); i++) {
  556. fieldsMap.add(new HashMap<>());
  557. }
  558. if (entities.getValidRowCount() != 0) {
  559. List<FieldValue> fieldValueList = entities.getFieldsList();
  560. for (FieldValue fieldValue : fieldValueList) {
  561. String fieldName = fieldValue.getFieldName();
  562. for (int j = 0; j < queryIdsList.size(); j++) {
  563. if (fieldValue.getAttrRecord().getInt32ValueCount() > 0) {
  564. fieldsMap.get(j).put(fieldName, fieldValue.getAttrRecord().getInt32ValueList().get(j));
  565. } else if (fieldValue.getAttrRecord().getInt64ValueCount() > 0) {
  566. fieldsMap.get(j).put(fieldName, fieldValue.getAttrRecord().getInt64ValueList().get(j));
  567. } else if (fieldValue.getAttrRecord().getDoubleValueCount() > 0) {
  568. fieldsMap.get(j).put(fieldName, fieldValue.getAttrRecord().getDoubleValueList().get(j));
  569. } else if (fieldValue.getAttrRecord().getFloatValueCount() > 0) {
  570. fieldsMap.get(j).put(fieldName, fieldValue.getAttrRecord().getFloatValueList().get(j));
  571. } else {
  572. // the object is vector
  573. List<VectorRowRecord> vectorRowRecordList =
  574. fieldValue.getVectorRecord().getRecordsList();
  575. if (vectorRowRecordList.get(j).getFloatDataCount() > 0) {
  576. fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getFloatDataList());
  577. } else {
  578. fieldsMap.get(j).put(fieldName, vectorRowRecordList.get(j).getBinaryData().asReadOnlyByteBuffer());
  579. }
  580. }
  581. }
  582. }
  583. }
  584. if (topK > 0) {
  585. for (int i = 0; i < numQueries; i++) {
  586. // Process result of query i
  587. int pos = i * topK;
  588. while (pos < i * topK + topK && queryIdsList.get(pos) != -1) {
  589. pos++;
  590. }
  591. resultIdsList.add(queryIdsList.subList(i * topK, pos));
  592. resultDistancesList.add(queryDistancesList.subList(i * topK, pos));
  593. resultFieldsMap.add(fieldsMap.subList(i * topK, pos));
  594. }
  595. }
  596. return new SearchResult(numQueries, topK, resultIdsList, resultDistancesList, resultFieldsMap);
  597. }
  598. private String kvListToString(List<KeyValuePair> kv) {
  599. JSONObject jsonObject = new JSONObject();
  600. for (KeyValuePair keyValuePair : kv) {
  601. if (keyValuePair.getValue().equals("null")) continue;
  602. jsonObject.put(keyValuePair.getKey(), keyValuePair.getValue());
  603. }
  604. return jsonObject.toString();
  605. }
  606. ///////////////////// Log Functions//////////////////////
  607. private void logInfo(String msg, Object... params) {
  608. logger.info(msg, params);
  609. }
  610. private void logWarning(String msg, Object... params) {
  611. logger.warn(msg, params);
  612. }
  613. private void logError(String msg, Object... params) {
  614. logger.error(msg, params);
  615. }
  616. }