zhiru пре 5 година
родитељ
комит
f2ef83dd08
1 измењених фајлова са 26 додато и 16 уклоњено
  1. 26 16
      examples/src/main/java/MilvusClientExample.java

+ 26 - 16
examples/src/main/java/MilvusClientExample.java

@@ -19,12 +19,26 @@ import io.milvus.client.*;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Random;
+import java.util.SplittableRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
 
 public class MilvusClientExample {
 
+  // Helper function that generates random vectors
+  static List<List<Float>> generateRandomVectors(long vectorCount, long dimension) {
+    SplittableRandom splittableRandom = new SplittableRandom();
+    List<List<Float>> vectors = new ArrayList<>();
+    for (int i = 0; i < vectorCount; ++i) {
+      DoubleStream doubleStream = splittableRandom.doubles(dimension);
+      List<Float> vector =
+          doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
+      vectors.add(vector);
+    }
+    return vectors;
+  }
+
   // Helper function that normalizes a vector if you are using IP (Inner product) as your metric
   // type
   static List<Float> normalize(List<Float> vector) {
@@ -36,7 +50,7 @@ public class MilvusClientExample {
 
   public static void main(String[] args) throws InterruptedException {
 
-    final String host = "localhost";
+    final String host = "192.168.1.149";
     final String port = "19530";
 
     // Create Milvus client
@@ -77,17 +91,9 @@ public class MilvusClientExample {
     System.out.println(describeTableResponse);
 
     // Insert randomly generated vectors to table
-    final int vectorCount = 1024;
-    Random random = new Random();
-    List<List<Float>> vectors = new ArrayList<>();
-    for (int i = 0; i < vectorCount; ++i) {
-      List<Float> vector = new ArrayList<>();
-      for (int j = 0; j < dimension; ++j) {
-        vector.add(random.nextFloat());
-      }
-      vector = normalize(vector);
-      vectors.add(vector);
-    }
+    final int vectorCount = 100000;
+    List<List<Float>> vectors = generateRandomVectors(vectorCount, dimension);
+    vectors.forEach(MilvusClientExample::normalize);
     InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(10).build();
     InsertResponse insertResponse = client.insert(insertParam);
     System.out.println(insertResponse);
@@ -130,9 +136,13 @@ public class MilvusClientExample {
     final double epsilon = 0.001;
     for (int i = 0; i < searchSize; i++) {
       // Since we are searching for vector that is already present in the table,
-      // the first result vector should be itself and the distance should be less than epsilon
-      assert queryResultsList.get(i).get(0).getVectorId() == vectorIds.get(0);
-      assert queryResultsList.get(i).get(0).getDistance() < epsilon;
+      // the first result vector should be itself and the distance (inner product) should be
+      // very close to 1 (some precision is lost during the process)
+      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      if (firstQueryResult.getVectorId() != vectorIds.get(i)
+          || firstQueryResult.getDistance() <= (1 - epsilon)) {
+        throw new AssertionError();
+      }
     }
 
     // Drop index for the table