Browse Source

Change GRPC proto (and related code) to increase search result's transmission speed

Zhiru Zhu 5 years ago
parent
commit
b96790499a

+ 6 - 1
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.3.0-SNAPSHOT</version>
+    <version>0.3.0RC-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>
@@ -134,6 +134,11 @@
             <artifactId>commons-text</artifactId>
             <version>1.6</version>
         </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-collections4</artifactId>
+            <version>4.4</version>
+        </dependency>
     </dependencies>
 
     <profiles>

+ 4 - 2
src/main/java/io/milvus/client/MilvusClient.java

@@ -19,6 +19,8 @@
 
 package io.milvus.client;
 
+import java.io.IOException;
+
 /** The Milvus Client Interface */
 public interface MilvusClient {
 
@@ -175,7 +177,7 @@ public interface MilvusClient {
    * @see SearchResponse.QueryResult
    * @see Response
    */
-  SearchResponse search(SearchParam searchParam);
+  SearchResponse search(SearchParam searchParam) throws IOException;
 
   /**
    * Searches vectors in specific files specified by <code>searchInFilesParam</code>
@@ -202,7 +204,7 @@ public interface MilvusClient {
    * @see SearchResponse.QueryResult
    * @see Response
    */
-  SearchResponse searchInFiles(SearchInFilesParam searchInFilesParam);
+  SearchResponse searchInFiles(SearchInFilesParam searchInFilesParam) throws IOException;
 
   /**
    * Describes table

+ 48 - 41
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -23,6 +23,7 @@ import io.grpc.ConnectivityState;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.StatusRuntimeException;
+import org.apache.commons.collections4.ListUtils;
 
 import javax.annotation.Nonnull;
 import java.text.SimpleDateFormat;
@@ -321,8 +322,9 @@ public class MilvusGrpcClient implements MilvusClient {
 
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      return new SearchResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
+      return searchResponse;
     }
 
     List<io.milvus.grpc.RowRecord> queryRowRecordList = getQueryRowRecordList(searchParam);
@@ -338,29 +340,32 @@ public class MilvusGrpcClient implements MilvusClient {
             .setNprobe(searchParam.getNProbe())
             .build();
 
-    io.milvus.grpc.TopKQueryResultList response;
+    io.milvus.grpc.TopKQueryResult response;
 
     try {
       response = blockingStub.search(request);
 
       if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
-        List<List<SearchResponse.QueryResult>> queryResultsList = getQueryResultsList(response);
+        SearchResponse searchResponse = buildSearchResponse(response);
+        searchResponse.setResponse(new Response(Response.Status.SUCCESS));
         logInfo(
             "Search completed successfully! Returned results for {0} queries",
-            queryResultsList.size());
-        return new SearchResponse(new Response(Response.Status.SUCCESS), queryResultsList);
+            searchResponse.getNumQueries());
+        return searchResponse;
       } else {
         logSevere("Search failed:\n{0}", response.toString());
-        return new SearchResponse(
+        SearchResponse searchResponse = new SearchResponse();
+        searchResponse.setResponse(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            new ArrayList<>());
+                response.getStatus().getReason()));
+        return searchResponse;
       }
     } catch (StatusRuntimeException e) {
       logSevere("search RPC failed:\n{0}", e.getStatus().toString());
-      return new SearchResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
+      return searchResponse;
     }
   }
 
@@ -369,8 +374,9 @@ public class MilvusGrpcClient implements MilvusClient {
 
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
-      return new SearchResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
+      return searchResponse;
     }
 
     SearchParam searchParam = searchInFilesParam.getSearchParam();
@@ -394,30 +400,33 @@ public class MilvusGrpcClient implements MilvusClient {
             .setSearchParam(searchParamToSet)
             .build();
 
-    io.milvus.grpc.TopKQueryResultList response;
+    io.milvus.grpc.TopKQueryResult response;
 
     try {
       response = blockingStub.searchInFiles(request);
 
       if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
+        SearchResponse searchResponse = buildSearchResponse(response);
+        searchResponse.setResponse(new Response(Response.Status.SUCCESS));
         logInfo("Search in files {0} completed successfully!", searchInFilesParam.getFileIds());
-
-        List<List<SearchResponse.QueryResult>> queryResultsList = getQueryResultsList(response);
-        return new SearchResponse(new Response(Response.Status.SUCCESS), queryResultsList);
+        return searchResponse;
       } else {
         logSevere(
             "Search in files {0} failed:\n{1}",
             searchInFilesParam.getFileIds(), response.toString());
-        return new SearchResponse(
+
+        SearchResponse searchResponse = new SearchResponse();
+        searchResponse.setResponse(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            new ArrayList<>());
+                response.getStatus().getReason()));
+        return searchResponse;
       }
     } catch (StatusRuntimeException e) {
       logSevere("searchInFiles RPC failed:\n{0}", e.getStatus().toString());
-      return new SearchResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
+      return searchResponse;
     }
   }
 
@@ -728,25 +737,23 @@ public class MilvusGrpcClient implements MilvusClient {
         .build();
   }
 
-  private List<List<SearchResponse.QueryResult>> getQueryResultsList(
-      io.milvus.grpc.TopKQueryResultList searchResponse) {
-    // TODO: refactor
-    List<List<SearchResponse.QueryResult>> queryResultsList = new ArrayList<>();
-    Optional<List<io.milvus.grpc.TopKQueryResult>> topKQueryResultList =
-        Optional.ofNullable(searchResponse.getTopkQueryResultList());
-    if (topKQueryResultList.isPresent()) {
-      for (io.milvus.grpc.TopKQueryResult topKQueryResult : topKQueryResultList.get()) {
-        List<SearchResponse.QueryResult> responseQueryResults = new ArrayList<>();
-        List<io.milvus.grpc.QueryResult> queryResults = topKQueryResult.getQueryResultArraysList();
-        for (io.milvus.grpc.QueryResult queryResult : queryResults) {
-          SearchResponse.QueryResult responseQueryResult =
-              new SearchResponse.QueryResult(queryResult.getId(), queryResult.getDistance());
-          responseQueryResults.add(responseQueryResult);
-        }
-        queryResultsList.add(responseQueryResults);
-      }
-    }
-    return queryResultsList;
+  private SearchResponse buildSearchResponse(io.milvus.grpc.TopKQueryResult topKQueryResult) {
+
+    final int numQueries = (int) topKQueryResult.getRowNum();
+    final int topK =
+        topKQueryResult.getIdsCount() / numQueries; // Guaranteed to be disable from server side
+
+    List<List<Long>> resultIdsList = ListUtils.partition(topKQueryResult.getIdsList(), topK);
+    List<List<Float>> resultDistancesList =
+        ListUtils.partition(topKQueryResult.getDistancesList(), topK);
+
+    SearchResponse searchResponse = new SearchResponse();
+    searchResponse.setNumQueries(numQueries);
+    searchResponse.setTopK(topK);
+    searchResponse.setResultIdsList(resultIdsList);
+    searchResponse.setResultDistancesList(resultDistancesList);
+
+    return searchResponse;
   }
 
   private boolean channelIsReadyOrIdle() {

+ 50 - 29
src/main/java/io/milvus/client/SearchResponse.java

@@ -19,8 +19,10 @@
 
 package io.milvus.client;
 
-import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
 
 /**
  * Contains the returned <code>response</code> and <code>queryResultsList</code> for <code>search
@@ -28,12 +30,26 @@ import java.util.List;
  */
 public class SearchResponse {
 
-  private final Response response;
-  private final List<List<QueryResult>> queryResultsList;
+  private Response response;
+  private int numQueries;
+  private long topK;
+  private List<List<Long>> resultIdsList;
+  private List<List<Float>> resultDistancesList;
 
-  public SearchResponse(Response response, List<List<QueryResult>> queryResultsList) {
-    this.response = response;
-    this.queryResultsList = queryResultsList;
+  public int getNumQueries() {
+    return numQueries;
+  }
+
+  void setNumQueries(int numQueries) {
+    this.numQueries = numQueries;
+  }
+
+  public long getTopK() {
+    return topK;
+  }
+
+  void setTopK(long topK) {
+    this.topK = topK;
   }
 
   /**
@@ -41,7 +57,17 @@ public class SearchResponse {
    *     the query result of a vector.
    */
   public List<List<QueryResult>> getQueryResultsList() {
-    return queryResultsList;
+    return IntStream.range(0, numQueries)
+        .mapToObj(
+            i ->
+                LongStream.range(0, topK)
+                    .mapToObj(
+                        j ->
+                            new QueryResult(
+                                resultIdsList.get(i).get((int) j),
+                                resultDistancesList.get(i).get((int) j)))
+                    .collect(Collectors.toList()))
+        .collect(Collectors.toList());
   }
 
   /**
@@ -49,37 +75,33 @@ public class SearchResponse {
    *     of a vector.
    */
   public List<List<Long>> getResultIdsList() {
-    List<List<Long>> resultIdsList = new ArrayList<>();
-    for (List<QueryResult> queryResults : queryResultsList) {
-      List<Long> resultIds = new ArrayList<>();
-      for (QueryResult queryResult : queryResults) {
-        resultIds.add(queryResult.vectorId);
-      }
-      resultIdsList.add(resultIds);
-    }
     return resultIdsList;
   }
 
+  void setResultIdsList(List<List<Long>> resultIdsList) {
+    this.resultIdsList = resultIdsList;
+  }
+
   /**
    * @return @return a <code>List</code> of result distances. Each inner <code>List</code> contains
    *     the result distances of a vector.
    */
-  public List<List<Double>> getResultDistancesList() {
-    List<List<Double>> resultDistancesList = new ArrayList<>();
-    for (List<QueryResult> queryResults : queryResultsList) {
-      List<Double> resultDistances = new ArrayList<>();
-      for (QueryResult queryResult : queryResults) {
-        resultDistances.add(queryResult.distance);
-      }
-      resultDistancesList.add(resultDistances);
-    }
+  public List<List<Float>> getResultDistancesList() {
     return resultDistancesList;
   }
 
+  void setResultDistancesList(List<List<Float>> resultDistancesList) {
+    this.resultDistancesList = resultDistancesList;
+  }
+
   public Response getResponse() {
     return response;
   }
 
+  void setResponse(Response response) {
+    this.response = response;
+  }
+
   public boolean ok() {
     return response.ok();
   }
@@ -87,8 +109,7 @@ public class SearchResponse {
   @Override
   public String toString() {
     return String.format(
-        "SearchResponse {%s, returned results for %d queries}",
-        response.toString(), this.queryResultsList.size());
+        "SearchResponse {%s, returned results for %d queries}", response.toString(), numQueries);
   }
 
   /**
@@ -97,9 +118,9 @@ public class SearchResponse {
    */
   public static class QueryResult {
     private final long vectorId;
-    private final double distance;
+    private final float distance;
 
-    public QueryResult(long vectorId, double distance) {
+    public QueryResult(long vectorId, float distance) {
       this.vectorId = vectorId;
       this.distance = distance;
     }
@@ -108,7 +129,7 @@ public class SearchResponse {
       return vectorId;
     }
 
-    public double getDistance() {
+    public float getDistance() {
       return distance;
     }
   }

+ 272 - 31
src/main/proto/milvus.proto

@@ -9,319 +9,560 @@ option java_outer_classname = "MilvusProto";
 package milvus.grpc;
 
 /**
+
  * @brief Table Name
+
  */
+
 message TableName {
+
     string table_name = 1;
+
 }
 
 /**
+
  * @brief Table Name List
+
  */
+
 message TableNameList {
+
     Status status = 1;
+
     repeated string table_names = 2;
+
 }
 
 /**
+
  * @brief Table Schema
+
  */
+
 message TableSchema {
+
     Status status = 1;
+
     string table_name = 2;
+
     int64 dimension = 3;
+
     int64 index_file_size = 4;
+
     int32 metric_type = 5;
+
 }
 
 /**
+
  * @brief Range Schema
+
  */
+
 message Range {
+
     string start_value = 1;
+
     string end_value = 2;
+
 }
 
 /**
+
  * @brief Record inserted
+
  */
+
 message RowRecord {
-    repeated float vector_data = 1;             //binary vector data
+
+    repeated float vector_data = 1; //binary vector data
+
 }
 
 /**
+
  * @brief params to be inserted
+
  */
+
 message InsertParam {
+
     string table_name = 1;
+
     repeated RowRecord row_record_array = 2;
-    repeated int64 row_id_array = 3;            //optional
+
+    repeated int64 row_id_array = 3; //optional
+
 }
 
 /**
+
  * @brief Vector ids
+
  */
+
 message VectorIds {
+
     Status status = 1;
+
     repeated int64 vector_id_array = 2;
+
 }
 
 /**
+
  * @brief params for searching vector
+
  */
+
 message SearchParam {
+
     string table_name = 1;
+
     repeated RowRecord query_record_array = 2;
+
     repeated Range query_range_array = 3;
+
     int64 topk = 4;
+
     int64 nprobe = 5;
+
 }
 
 /**
+
  * @brief params for searching vector in files
+
  */
+
 message SearchInFilesParam {
+
     repeated string file_id_array = 1;
+
     SearchParam search_param = 2;
+
 }
 
 /**
+
  * @brief Query result params
- */
-message QueryResult {
-    int64 id = 1;
-    double distance = 2;
-}
 
-/**
- * @brief TopK query result
  */
+
 message TopKQueryResult {
-    repeated QueryResult query_result_arrays = 1;
-}
 
-/**
- * @brief List of topK query result
- */
-message TopKQueryResultList {
     Status status = 1;
-    repeated TopKQueryResult topk_query_result = 2;
+
+    int64 row_num = 2;
+
+    repeated int64 ids = 3;
+
+    repeated float distances = 4;
+
 }
 
 /**
+
  * @brief Server String Reply
+
  */
+
 message StringReply {
+
     Status status = 1;
+
     string string_reply = 2;
+
 }
 
 /**
+
  * @brief Server bool Reply
+
  */
+
 message BoolReply {
+
     Status status = 1;
+
     bool bool_reply = 2;
+
 }
 
 /**
+
  * @brief Return table row count
+
  */
+
 message TableRowCount {
+
     Status status = 1;
+
     int64 table_row_count = 2;
+
 }
 
 /**
+
  * @brief Give Server Command
+
  */
+
 message Command {
+
     string cmd = 1;
+
 }
 
 /**
+
  * @brief Index
+
  * @index_type: 0-invalid, 1-idmap, 2-ivflat, 3-ivfsq8, 4-nsgmix
+
  * @metric_type: 1-L2, 2-IP
+
  */
+
 message Index {
+
     int32 index_type = 1;
+
     int32 nlist = 2;
+
 }
 
 /**
+
  * @brief Index params
+
  */
+
 message IndexParam {
+
     Status status = 1;
+
     string table_name = 2;
+
     Index index = 3;
+
 }
 
 /**
+
  * @brief table name and range for DeleteByRange
+
  */
+
 message DeleteByRangeParam {
+
     Range range = 1;
+
     string table_name = 2;
+
 }
 
 service MilvusService {
+
     /**
+
      * @brief Create table method
+
      *
+
      * This method is used to create table
+
      *
+
      * @param param, use to provide table information to be created.
+
      *
+
      */
-    rpc CreateTable(TableSchema) returns (Status){}
+
+    rpc CreateTable (TableSchema) returns (Status) {
+    }
 
     /**
+
      * @brief Test table existence method
+
      *
+
      * This method is used to test table existence.
+
      *
+
      * @param table_name, table name is going to be tested.
+
      *
+
      */
-    rpc HasTable(TableName) returns (BoolReply) {}
+
+    rpc HasTable (TableName) returns (BoolReply) {
+    }
 
     /**
+
      * @brief Delete table method
+
      *
+
      * This method is used to delete table.
+
      *
+
      * @param table_name, table name is going to be deleted.
+
      *
+
      */
-    rpc DropTable(TableName) returns (Status) {}
+
+    rpc DropTable (TableName) returns (Status) {
+    }
 
     /**
+
      * @brief Build index by table method
+
      *
+
      * This method is used to build index by table in sync mode.
+
      *
+
      * @param table_name, table is going to be built index.
+
      *
+
      */
-    rpc CreateIndex(IndexParam) returns (Status) {}
+
+    rpc CreateIndex (IndexParam) returns (Status) {
+    }
 
     /**
+
      * @brief Add vector array to table
+
      *
+
      * This method is used to add vector array to table.
+
      *
+
      * @param table_name, table_name is inserted.
+
      * @param record_array, vector array is inserted.
+
      *
+
      * @return vector id array
+
      */
-    rpc Insert(InsertParam) returns (VectorIds) {}
+
+    rpc Insert (InsertParam) returns (VectorIds) {
+    }
 
     /**
+
      * @brief Query vector
+
      *
+
      * This method is used to query vector in table.
+
      *
+
      * @param table_name, table_name is queried.
+
      * @param query_record_array, all vector are going to be queried.
+
      * @param query_range_array, optional ranges for conditional search. If not specified, search whole table
+
      * @param topk, how many similarity vectors will be searched.
+
      *
+
      * @return query result array.
+
      */
-    rpc Search(SearchParam) returns (TopKQueryResultList) {}
+
+    rpc Search (SearchParam) returns (TopKQueryResult) {
+    }
 
     /**
+
      * @brief Internal use query interface
+
      *
+
      * This method is used to query vector in specified files.
+
      *
+
      * @param file_id_array, specified files id array, queried.
+
      * @param query_record_array, all vector are going to be queried.
+
      * @param query_range_array, optional ranges for conditional search. If not specified, search whole table
+
      * @param topk, how many similarity vectors will be searched.
+
      *
+
      * @return query result array.
+
      */
-    rpc SearchInFiles(SearchInFilesParam) returns (TopKQueryResultList) {}
+
+    rpc SearchInFiles (SearchInFilesParam) returns (TopKQueryResult) {
+    }
 
     /**
+
      * @brief Get table schema
+
      *
+
      * This method is used to get table schema.
+
      *
+
      * @param table_name, target table name.
+
      *
+
      * @return table schema
+
      */
-    rpc DescribeTable(TableName) returns (TableSchema) {}
+
+    rpc DescribeTable (TableName) returns (TableSchema) {
+    }
 
     /**
+
      * @brief Get table schema
+
      *
+
      * This method is used to get table schema.
+
      *
+
      * @param table_name, target table name.
+
      *
+
      * @return table schema
+
      */
-    rpc CountTable(TableName) returns (TableRowCount) {}
+
+    rpc CountTable (TableName) returns (TableRowCount) {
+    }
 
     /**
+
      * @brief List all tables in database
+
      *
+
      * This method is used to list all tables.
+
      *
+
      *
+
      * @return table names.
+
      */
-    rpc ShowTables(Command) returns (TableNameList) {}
+
+    rpc ShowTables (Command) returns (TableNameList) {
+    }
 
     /**
+
      * @brief Give the server status
+
      *
+
      * This method is used to give the server status.
+
      *
+
      * @return Server status.
+
      */
-    rpc Cmd(Command) returns (StringReply) {}
+
+    rpc Cmd (Command) returns (StringReply) {
+    }
 
     /**
+
     * @brief delete table by range
+
     *
+
     * This method is used to delete vector by range
+
     *
+
     * @return rpc status.
+
     */
-    rpc DeleteByRange(DeleteByRangeParam) returns (Status) {}
+
+    rpc DeleteByRange (DeleteByRangeParam) returns (Status) {
+    }
 
     /**
+
      * @brief preload table
+
      *
+
      * This method is used to preload table
+
      *
+
      * @return Status.
+
      */
-    rpc PreloadTable(TableName) returns (Status) {}
+
+    rpc PreloadTable (TableName) returns (Status) {
+    }
 
     /**
+
      * @brief describe index
+
      *
+
      * This method is used to describe index
+
      *
+
      * @return Status.
+
      */
-    rpc DescribeIndex(TableName) returns (IndexParam) {}
+
+    rpc DescribeIndex (TableName) returns (IndexParam) {
+    }
 
     /**
+
      * @brief drop index
+
      *
+
      * This method is used to drop index
+
      *
+
      * @return Status.
+
      */
-    rpc DropIndex(TableName) returns (Status) {}
+
+    rpc DropIndex (TableName) returns (Status) {
+    }
 
 }

+ 10 - 4
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -21,6 +21,7 @@ package io.milvus.client;
 
 import org.apache.commons.text.RandomStringGenerator;
 
+import java.io.IOException;
 import java.util.*;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
@@ -66,7 +67,7 @@ class MilvusClientTest {
 
     client = new MilvusGrpcClient();
     ConnectParam connectParam =
-        new ConnectParam.Builder().withHost("localhost").withPort("19530").build();
+        new ConnectParam.Builder().withHost("192.168.1.113").withPort("19530").build();
     client.connect(connectParam);
 
     generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
@@ -191,7 +192,7 @@ class MilvusClientTest {
   }
 
   @org.junit.jupiter.api.Test
-  void search() throws InterruptedException {
+  void search() throws InterruptedException, IOException {
     List<List<Float>> vectors = generateVectors(size, dimension);
     vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
     InsertParam insertParam = new InsertParam.Builder(randomTableName, vectors).build();
@@ -215,7 +216,7 @@ class MilvusClientTest {
     c.add(Calendar.DAY_OF_MONTH, 1);
     Date tomorrow = c.getTime();
     queryRanges.add(new DateRange(yesterday, tomorrow));
-    final long topK = 1000;
+    final long topK = 10;
     SearchParam searchParam =
         new SearchParam.Builder(randomTableName, vectorsToSearch)
             .withTopK(topK)
@@ -224,14 +225,19 @@ class MilvusClientTest {
             .build();
     SearchResponse searchResponse = client.search(searchParam);
     assertTrue(searchResponse.ok());
-    System.out.println(searchResponse);
+    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+    assertEquals(searchSize, resultIdsList.size());
+    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    assertEquals(searchSize, resultDistancesList.size());
     List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
     assertEquals(searchSize, queryResultsList.size());
     final double epsilon = 0.001;
     for (int i = 0; i < searchSize; i++) {
       SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
       assertEquals(vectorIds.get(i), firstQueryResult.getVectorId());
+      assertEquals(vectorIds.get(i), resultIdsList.get(i).get(0));
       assertTrue(Math.abs(1 - firstQueryResult.getDistance()) < epsilon);
+      assertTrue(Math.abs(1 - resultDistancesList.get(i).get(0)) < epsilon);
     }
   }