|
@@ -9,6 +9,8 @@ import org.apache.commons.text.RandomStringGenerator;
|
|
|
|
|
|
import java.util.*;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
+import java.util.function.Function;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
|
|
|
@@ -22,6 +24,7 @@ class MilvusGrpcClientTest {
|
|
|
private long size;
|
|
|
private long dimension;
|
|
|
private TableParam tableParam;
|
|
|
+ private TableSchema tableSchema;
|
|
|
|
|
|
@org.junit.jupiter.api.BeforeEach
|
|
|
void setUp() throws Exception {
|
|
@@ -39,9 +42,9 @@ class MilvusGrpcClientTest {
|
|
|
size = 100;
|
|
|
dimension = 128;
|
|
|
tableParam = new TableParam.Builder(randomTableName).build();
|
|
|
- TableSchema tableSchema = new TableSchema.Builder(randomTableName, dimension)
|
|
|
+ tableSchema = new TableSchema.Builder(randomTableName, dimension)
|
|
|
.withIndexFileSize(1024)
|
|
|
- .withMetricType(MetricType.L2)
|
|
|
+ .withMetricType(MetricType.IP)
|
|
|
.build();
|
|
|
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
|
|
|
|
@@ -114,6 +117,13 @@ class MilvusGrpcClientTest {
|
|
|
assertEquals(size, insertResponse.getVectorIds().size());
|
|
|
}
|
|
|
|
|
|
+ List<Float> normalize(List<Float> vector) {
|
|
|
+ float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
|
|
|
+ final float norm = (float) Math.sqrt(squareSum);
|
|
|
+ vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
|
|
|
+ return vector;
|
|
|
+ }
|
|
|
+
|
|
|
@org.junit.jupiter.api.Test
|
|
|
void search() throws InterruptedException {
|
|
|
Random random = new Random();
|
|
@@ -125,6 +135,9 @@ class MilvusGrpcClientTest {
|
|
|
for (int j = 0; j < dimension; ++j) {
|
|
|
vector.add(random.nextFloat());
|
|
|
}
|
|
|
+ if (tableSchema.getMetricType() == MetricType.IP) {
|
|
|
+ vector = normalize(vector);
|
|
|
+ }
|
|
|
vectors.add(vector);
|
|
|
if (i < searchSize) {
|
|
|
vectorsToSearch.add(vector);
|