Browse Source

Merge pull request #9 from youny626/branch-0.5.0

Add more getters in SearchResponse & add normalize method in unittest
Jin Hai 5 years ago
parent
commit
59758d888a

+ 0 - 4
src/main/java/io/milvus/client/MilvusClient.java

@@ -7,10 +7,6 @@ public interface MilvusClient {
         return clientVersion;
     }
 
-    /**
-     * @param connectParam
-     * @return <code>Response</code>
-     */
     Response connect(ConnectParam connectParam);
 
     boolean connected();

+ 23 - 1
src/main/java/io/milvus/client/SearchResponse.java

@@ -1,5 +1,6 @@
 package io.milvus.client;
 
+import java.util.ArrayList;
 import java.util.List;
 
 public class SearchResponse {
@@ -35,8 +36,29 @@ public class SearchResponse {
         return queryResultsList;
     }
 
-    //TODO: iterator
+    public List<List<Long>> getResultIdsList() {
+        List<List<Long>> resultIdsList = new ArrayList<>();
+        for (List<QueryResult> queryResults : queryResultsList) {
+            List<Long> resultIds = new ArrayList<>();
+            for (QueryResult queryResult : queryResults) {
+                resultIds.add(queryResult.vectorId);
+            }
+            resultIdsList.add(resultIds);
+        }
+        return resultIdsList;
+    }
 
+    public List<List<Double>> getResultDistancesList() {
+        List<List<Double>> resultDistancesList = new ArrayList<>();
+        for (List<QueryResult> queryResults : queryResultsList) {
+            List<Double> resultDistances = new ArrayList<>();
+            for (QueryResult queryResult : queryResults) {
+                resultDistances.add(queryResult.distance);
+            }
+            resultDistancesList.add(resultDistances);
+        }
+        return resultDistancesList;
+    }
 
     public Response getResponse() {
         return response;

+ 15 - 2
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -9,6 +9,8 @@ import org.apache.commons.text.RandomStringGenerator;
 
 import java.util.*;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.*;
 
@@ -22,6 +24,7 @@ class MilvusGrpcClientTest {
     private long size;
     private long dimension;
     private TableParam tableParam;
+    private TableSchema tableSchema;
 
     @org.junit.jupiter.api.BeforeEach
     void setUp() throws Exception {
@@ -39,9 +42,9 @@ class MilvusGrpcClientTest {
         size = 100;
         dimension = 128;
         tableParam = new TableParam.Builder(randomTableName).build();
-        TableSchema tableSchema = new TableSchema.Builder(randomTableName, dimension)
+        tableSchema = new TableSchema.Builder(randomTableName, dimension)
                                                     .withIndexFileSize(1024)
-                                                    .withMetricType(MetricType.L2)
+                                                    .withMetricType(MetricType.IP)
                                                     .build();
         TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
 
@@ -114,6 +117,13 @@ class MilvusGrpcClientTest {
         assertEquals(size, insertResponse.getVectorIds().size());
     }
 
+    List<Float> normalize(List<Float> vector) {
+        float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
+        final float norm = (float) Math.sqrt(squareSum);
+        vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
+        return vector;
+    }
+
     @org.junit.jupiter.api.Test
     void search() throws InterruptedException {
         Random random = new Random();
@@ -125,6 +135,9 @@ class MilvusGrpcClientTest {
             for (int j = 0; j < dimension; ++j) {
                 vector.add(random.nextFloat());
             }
+            if (tableSchema.getMetricType() == MetricType.IP) {
+                vector = normalize(vector);
+            }
             vectors.add(vector);
             if (i < searchSize) {
                 vectorsToSearch.add(vector);