|
@@ -19,12 +19,26 @@ import io.milvus.client.*;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
-import java.util.Random;
|
|
|
+import java.util.SplittableRandom;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
import java.util.stream.Collectors;
|
|
|
+import java.util.stream.DoubleStream;
|
|
|
|
|
|
public class MilvusClientExample {
|
|
|
|
|
|
+ // Helper function that generates random vectors
|
|
|
+ static List<List<Float>> generateRandomVectors(long vectorCount, long dimension) {
|
|
|
+ SplittableRandom splittableRandom = new SplittableRandom();
|
|
|
+ List<List<Float>> vectors = new ArrayList<>();
|
|
|
+ for (int i = 0; i < vectorCount; ++i) {
|
|
|
+ DoubleStream doubleStream = splittableRandom.doubles(dimension);
|
|
|
+ List<Float> vector =
|
|
|
+ doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
|
|
|
+ vectors.add(vector);
|
|
|
+ }
|
|
|
+ return vectors;
|
|
|
+ }
|
|
|
+
|
|
|
// Helper function that normalizes a vector if you are using IP (Inner product) as your metric
|
|
|
// type
|
|
|
static List<Float> normalize(List<Float> vector) {
|
|
@@ -36,7 +50,7 @@ public class MilvusClientExample {
|
|
|
|
|
|
public static void main(String[] args) throws InterruptedException {
|
|
|
|
|
|
- final String host = "localhost";
|
|
|
+ final String host = "192.168.1.149";
|
|
|
final String port = "19530";
|
|
|
|
|
|
// Create Milvus client
|
|
@@ -77,17 +91,9 @@ public class MilvusClientExample {
|
|
|
System.out.println(describeTableResponse);
|
|
|
|
|
|
// Insert randomly generated vectors to table
|
|
|
- final int vectorCount = 1024;
|
|
|
- Random random = new Random();
|
|
|
- List<List<Float>> vectors = new ArrayList<>();
|
|
|
- for (int i = 0; i < vectorCount; ++i) {
|
|
|
- List<Float> vector = new ArrayList<>();
|
|
|
- for (int j = 0; j < dimension; ++j) {
|
|
|
- vector.add(random.nextFloat());
|
|
|
- }
|
|
|
- vector = normalize(vector);
|
|
|
- vectors.add(vector);
|
|
|
- }
|
|
|
+ final int vectorCount = 100000;
|
|
|
+ List<List<Float>> vectors = generateRandomVectors(vectorCount, dimension);
|
|
|
+ vectors.forEach(MilvusClientExample::normalize);
|
|
|
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(10).build();
|
|
|
InsertResponse insertResponse = client.insert(insertParam);
|
|
|
System.out.println(insertResponse);
|
|
@@ -130,9 +136,13 @@ public class MilvusClientExample {
|
|
|
final double epsilon = 0.001;
|
|
|
for (int i = 0; i < searchSize; i++) {
|
|
|
// Since we are searching for vector that is already present in the table,
|
|
|
- // the first result vector should be itself and the distance should be less than epsilon
|
|
|
- assert queryResultsList.get(i).get(0).getVectorId() == vectorIds.get(0);
|
|
|
- assert queryResultsList.get(i).get(0).getDistance() < epsilon;
|
|
|
+ // the first result vector should be itself and the distance (inner product) should be
|
|
|
+ // very close to 1 (some precision is lost during the process)
|
|
|
+ SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
|
|
|
+ if (firstQueryResult.getVectorId() != vectorIds.get(i)
|
|
|
+ || firstQueryResult.getDistance() <= (1 - epsilon)) {
|
|
|
+ throw new AssertionError();
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// Drop index for the table
|