Browse Source

Add more getters in SearchResponse & add normalize method in unittest

zhiru 5 years ago
parent
commit
d7d91b0bb9

+ 16 - 1
src/main/java/io/milvus/client/MilvusClient.java

@@ -3,13 +3,28 @@ package io.milvus.client;
 public interface MilvusClient {
 public interface MilvusClient {
 
 
     String clientVersion = "0.1.0";
     String clientVersion = "0.1.0";
+
+    /**
+     * @return the current Milvus client version
+     */
     default String clientVersion() {
     default String clientVersion() {
         return clientVersion;
         return clientVersion;
     }
     }
 
 
     /**
     /**
-     * @param connectParam
+     * Connects to Milvus server
+     * @param connectParam the <code>ConnectParam</code> object
+     *                     <pre>
+     *                     example usage:
+     *                     <code>
+     *                         ConnectParam connectParam = new ConnectParam.Builder()
+     *                                                                     .withHost("localhost")
+     *                                                                     .withPort("19530")
+     *                                                                     .build();
+     *                     </code>
+     *                     </pre>
      * @return <code>Response</code>
      * @return <code>Response</code>
+     * @see ConnectParam
      */
      */
     Response connect(ConnectParam connectParam);
     Response connect(ConnectParam connectParam);
 
 

+ 23 - 1
src/main/java/io/milvus/client/SearchResponse.java

@@ -1,5 +1,6 @@
 package io.milvus.client;
 package io.milvus.client;
 
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.List;
 
 
 public class SearchResponse {
 public class SearchResponse {
@@ -35,8 +36,29 @@ public class SearchResponse {
         return queryResultsList;
         return queryResultsList;
     }
     }
 
 
-    //TODO: iterator
+    public List<List<Long>> getResultIdsList() {
+        List<List<Long>> resultIdsList = new ArrayList<>();
+        for (List<QueryResult> queryResults : queryResultsList) {
+            List<Long> resultIds = new ArrayList<>();
+            for (QueryResult queryResult : queryResults) {
+                resultIds.add(queryResult.vectorId);
+            }
+            resultIdsList.add(resultIds);
+        }
+        return resultIdsList;
+    }
 
 
+    public List<List<Double>> getResultDistancesList() {
+        List<List<Double>> resultDistancesList = new ArrayList<>();
+        for (List<QueryResult> queryResults : queryResultsList) {
+            List<Double> resultDistances = new ArrayList<>();
+            for (QueryResult queryResult : queryResults) {
+                resultDistances.add(queryResult.distance);
+            }
+            resultDistancesList.add(resultDistances);
+        }
+        return resultDistancesList;
+    }
 
 
     public Response getResponse() {
     public Response getResponse() {
         return response;
         return response;

+ 15 - 2
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -9,6 +9,8 @@ import org.apache.commons.text.RandomStringGenerator;
 
 
 import java.util.*;
 import java.util.*;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 
 import static org.junit.jupiter.api.Assertions.*;
 import static org.junit.jupiter.api.Assertions.*;
 
 
@@ -22,6 +24,7 @@ class MilvusGrpcClientTest {
     private long size;
     private long size;
     private long dimension;
     private long dimension;
     private TableParam tableParam;
     private TableParam tableParam;
+    private TableSchema tableSchema;
 
 
     @org.junit.jupiter.api.BeforeEach
     @org.junit.jupiter.api.BeforeEach
     void setUp() throws Exception {
     void setUp() throws Exception {
@@ -39,9 +42,9 @@ class MilvusGrpcClientTest {
         size = 100;
         size = 100;
         dimension = 128;
         dimension = 128;
         tableParam = new TableParam.Builder(randomTableName).build();
         tableParam = new TableParam.Builder(randomTableName).build();
-        TableSchema tableSchema = new TableSchema.Builder(randomTableName, dimension)
+        tableSchema = new TableSchema.Builder(randomTableName, dimension)
                                                     .withIndexFileSize(1024)
                                                     .withIndexFileSize(1024)
-                                                    .withMetricType(MetricType.L2)
+                                                    .withMetricType(MetricType.IP)
                                                     .build();
                                                     .build();
         TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
         TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
 
 
@@ -114,6 +117,13 @@ class MilvusGrpcClientTest {
         assertEquals(size, insertResponse.getVectorIds().size());
         assertEquals(size, insertResponse.getVectorIds().size());
     }
     }
 
 
+    List<Float> normalize(List<Float> vector) {
+        float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
+        final float norm = (float) Math.sqrt(squareSum);
+        vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
+        return vector;
+    }
+
     @org.junit.jupiter.api.Test
     @org.junit.jupiter.api.Test
     void search() throws InterruptedException {
     void search() throws InterruptedException {
         Random random = new Random();
         Random random = new Random();
@@ -125,6 +135,9 @@ class MilvusGrpcClientTest {
             for (int j = 0; j < dimension; ++j) {
             for (int j = 0; j < dimension; ++j) {
                 vector.add(random.nextFloat());
                 vector.add(random.nextFloat());
             }
             }
+            if (tableSchema.getMetricType() == MetricType.IP) {
+                vector = normalize(vector);
+            }
             vectors.add(vector);
             vectors.add(vector);
             if (i < searchSize) {
             if (i < searchSize) {
                 vectorsToSearch.add(vector);
                 vectorsToSearch.add(vector);