Browse Source

Fix metric type bug of search param (#330)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
d28936a773

+ 6 - 6
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -1331,7 +1331,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
             }
             }
 
 
             DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
             DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
-            InsertRequest insertRequest = ParamUtils.ConvertInsertParam(requestParam, wrapper.getFields());
+            InsertRequest insertRequest = ParamUtils.convertInsertParam(requestParam, wrapper.getFields());
             MutationResult response = blockingStub().insert(insertRequest);
             MutationResult response = blockingStub().insert(insertRequest);
 
 
             if (response.getStatus().getErrorCode() == ErrorCode.Success) {
             if (response.getStatus().getErrorCode() == ErrorCode.Success) {
@@ -1372,7 +1372,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
         }
         }
 
 
         DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
         DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
-        InsertRequest insertRequest = ParamUtils.ConvertInsertParam(requestParam, wrapper.getFields());
+        InsertRequest insertRequest = ParamUtils.convertInsertParam(requestParam, wrapper.getFields());
         ListenableFuture<MutationResult> response = futureStub().insert(insertRequest);
         ListenableFuture<MutationResult> response = futureStub().insert(insertRequest);
 
 
         Futures.addCallback(
         Futures.addCallback(
@@ -1418,7 +1418,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
         logInfo(requestParam.toString());
         logInfo(requestParam.toString());
 
 
         try {
         try {
-            SearchRequest searchRequest = ParamUtils.ConvertSearchParam(requestParam);
+            SearchRequest searchRequest = ParamUtils.convertSearchParam(requestParam);
             SearchResults response = this.blockingStub().search(searchRequest);
             SearchResults response = this.blockingStub().search(searchRequest);
 
 
             //TODO: truncate distance value by round decimal
             //TODO: truncate distance value by round decimal
@@ -1448,7 +1448,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
 
 
         logInfo(requestParam.toString());
         logInfo(requestParam.toString());
 
 
-        SearchRequest searchRequest = ParamUtils.ConvertSearchParam(requestParam);
+        SearchRequest searchRequest = ParamUtils.convertSearchParam(requestParam);
         ListenableFuture<SearchResults> response = this.futureStub().search(searchRequest);
         ListenableFuture<SearchResults> response = this.futureStub().search(searchRequest);
 
 
         Futures.addCallback(
         Futures.addCallback(
@@ -1494,7 +1494,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
         logInfo(requestParam.toString());
         logInfo(requestParam.toString());
 
 
         try {
         try {
-            QueryRequest queryRequest = ParamUtils.ConvertQueryParam(requestParam);
+            QueryRequest queryRequest = ParamUtils.convertQueryParam(requestParam);
             QueryResults response = this.blockingStub().query(queryRequest);
             QueryResults response = this.blockingStub().query(queryRequest);
             if (response.getStatus().getErrorCode() == ErrorCode.Success) {
             if (response.getStatus().getErrorCode() == ErrorCode.Success) {
                 logInfo("QueryRequest successfully!");
                 logInfo("QueryRequest successfully!");
@@ -1529,7 +1529,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
 
 
         logInfo(requestParam.toString());
         logInfo(requestParam.toString());
 
 
-        QueryRequest queryRequest = ParamUtils.ConvertQueryParam(requestParam);
+        QueryRequest queryRequest = ParamUtils.convertQueryParam(requestParam);
         ListenableFuture<QueryResults> response = this.futureStub().query(queryRequest);
         ListenableFuture<QueryResults> response = this.futureStub().query(queryRequest);
 
 
         Futures.addCallback(
         Futures.addCallback(

+ 1 - 1
src/main/java/io/milvus/param/ConnectParam.java

@@ -104,7 +104,7 @@ public class ConnectParam {
         private boolean keepAliveWithoutCalls = false;
         private boolean keepAliveWithoutCalls = false;
         private boolean secure = false;
         private boolean secure = false;
         private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
         private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
-        private String authorization = "";
+        private String authorization = Base64.getEncoder().encodeToString("root:milvus".getBytes(StandardCharsets.UTF_8));
 
 
         private Builder() {
         private Builder() {
         }
         }

+ 19 - 19
src/main/java/io/milvus/param/ParamUtils.java

@@ -34,13 +34,24 @@ public class ParamUtils {
     }
     }
 
 
     /**
     /**
-     * Convert {@link InsertParam} to proto type InsertRequest.
+     * Checks if a metric is for float vector.
      *
      *
-     * @param requestParam {@link InsertParam} object
-     * @param fieldTypes {@link FieldType} object to validate the requestParam
-     * @return a <code>InsertRequest</code> object
+     * @param metric metric type
      */
      */
-    public static InsertRequest ConvertInsertParam(@NonNull InsertParam requestParam,
+    public static boolean IsFloatMetric(MetricType metric) {
+        return metric == MetricType.L2 || metric == MetricType.IP;
+    }
+
+    /**
+     * Checks if a metric is for binary vector.
+     *
+     * @param metric metric type
+     */
+    public static boolean IsBinaryMetric(MetricType metric) {
+        return !IsFloatMetric(metric);
+    }
+
+    public static InsertRequest convertInsertParam(@NonNull InsertParam requestParam,
                                                    @NonNull List<FieldType> fieldTypes) {
                                                    @NonNull List<FieldType> fieldTypes) {
         String collectionName = requestParam.getCollectionName();
         String collectionName = requestParam.getCollectionName();
         String partitionName = requestParam.getPartitionName();
         String partitionName = requestParam.getPartitionName();
@@ -85,14 +96,8 @@ public class ParamUtils {
         return insertBuilder.build();
         return insertBuilder.build();
     }
     }
 
 
-    /**
-     * Convert {@link SearchParam} to proto type SearchRequest.
-     *
-     * @param requestParam {@link SearchParam} object
-     * @return a <code>SearchRequest</code> object
-     */
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
-    public static SearchRequest ConvertSearchParam(@NonNull SearchParam requestParam) throws ParamException {
+    public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam) throws ParamException {
         SearchRequest.Builder builder = SearchRequest.newBuilder()
         SearchRequest.Builder builder = SearchRequest.newBuilder()
                 .setDbName("")
                 .setDbName("")
                 .setCollectionName(requestParam.getCollectionName());
                 .setCollectionName(requestParam.getCollectionName());
@@ -188,13 +193,8 @@ public class ParamUtils {
 
 
         return builder.build();
         return builder.build();
     }
     }
-    /**
-     * Convert {@link QueryParam} to proto type QueryRequest.
-     *
-     * @param requestParam {@link QueryParam} object
-     * @return a <code>QueryRequest</code> object
-     */
-    public static QueryRequest ConvertQueryParam(@NonNull QueryParam requestParam) {
+
+    public static QueryRequest convertQueryParam(@NonNull QueryParam requestParam) {
         long guaranteeTimestamp = getGuaranteeTimestamp(requestParam.getConsistencyLevel(),
         long guaranteeTimestamp = getGuaranteeTimestamp(requestParam.getConsistencyLevel(),
                 requestParam.getGuaranteeTimestamp(), requestParam.getGracefulTime());
                 requestParam.getGuaranteeTimestamp(), requestParam.getGracefulTime());
         return QueryRequest.newBuilder()
         return QueryRequest.newBuilder()

+ 11 - 1
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -311,7 +311,7 @@ public class SearchParam {
             }
             }
 
 
             if (metricType == MetricType.INVALID) {
             if (metricType == MetricType.INVALID) {
-                throw new ParamException("Metric type is illegal");
+                throw new ParamException("Metric type is invalid");
             }
             }
 
 
             if (vectors == null || vectors.isEmpty()) {
             if (vectors == null || vectors.isEmpty()) {
@@ -332,6 +332,11 @@ public class SearchParam {
                         throw new ParamException("Target vector dimension must be equal");
                         throw new ParamException("Target vector dimension must be equal");
                     }
                     }
                 }
                 }
+
+                // check metric type
+                if (!ParamUtils.IsFloatMetric(metricType)) {
+                    throw new ParamException("Target vector is float but metric type is incorrect");
+                }
             } else if (vectors.get(0) instanceof ByteBuffer) {
             } else if (vectors.get(0) instanceof ByteBuffer) {
                 // binary vectors
                 // binary vectors
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
@@ -342,6 +347,11 @@ public class SearchParam {
                         throw new ParamException("Target vector dimension must be equal");
                         throw new ParamException("Target vector dimension must be equal");
                     }
                     }
                 }
                 }
+
+                // check metric type
+                if (!ParamUtils.IsBinaryMetric(metricType)) {
+                    throw new ParamException("Target vector is binary but metric type is incorrect");
+                }
             } else {
             } else {
                 throw new ParamException("Target vector type must be Lst<Float> or ByteBuffer");
                 throw new ParamException("Target vector type must be Lst<Float> or ByteBuffer");
             }
             }

+ 71 - 3
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -1685,6 +1685,8 @@ class MilvusServiceClientTest {
         List<String> partitions = Collections.singletonList("partition1");
         List<String> partitions = Collections.singletonList("partition1");
         List<String> outputFields = Collections.singletonList("field1");
         List<String> outputFields = Collections.singletonList("field1");
         List<List<Float>> vectors = new ArrayList<>();
         List<List<Float>> vectors = new ArrayList<>();
+
+        // target vector is empty
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -1699,6 +1701,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // travel timestamp must be greater than 0
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -1714,6 +1717,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // guarantee timestamp must be greater than 0
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -1729,6 +1733,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // collection name is empty
         List<Float> vector1 = Collections.singletonList(0.1F);
         List<Float> vector1 = Collections.singletonList(0.1F);
         vectors.add(vector1);
         vectors.add(vector1);
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
@@ -1744,6 +1749,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // target field name is empty
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -1757,6 +1763,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // metric type is invalid
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -1770,6 +1777,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // illegal topk value
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -1783,7 +1791,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
-        // vector type illegal
+        // target vector type must be Lst<Float> or ByteBuffer
         List<String> fakeVectors1 = Collections.singletonList("fake");
         List<String> fakeVectors1 = Collections.singletonList("fake");
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
@@ -1798,6 +1806,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // float vector field's value must be Lst<Float>
         List<List<String>> fakeVectors2 = Collections.singletonList(fakeVectors1);
         List<List<String>> fakeVectors2 = Collections.singletonList(fakeVectors1);
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
@@ -1812,7 +1821,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
-        // vector dimension not equal
+        // float vector dimension not equal
         List<Float> vector2 = Arrays.asList(0.1F, 0.2F);
         List<Float> vector2 = Arrays.asList(0.1F, 0.2F);
         vectors.add(vector2);
         vectors.add(vector2);
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
@@ -1828,6 +1837,7 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
+        // binary vector dimension not equal
         ByteBuffer buf1 = ByteBuffer.allocate(1);
         ByteBuffer buf1 = ByteBuffer.allocate(1);
         buf1.put((byte) 1);
         buf1.put((byte) 1);
         ByteBuffer buf2 = ByteBuffer.allocate(2);
         ByteBuffer buf2 = ByteBuffer.allocate(2);
@@ -1840,12 +1850,70 @@ class MilvusServiceClientTest {
                 .withParams("dummy")
                 .withParams("dummy")
                 .withOutFields(outputFields)
                 .withOutFields(outputFields)
                 .withVectorFieldName("field1")
                 .withVectorFieldName("field1")
-                .withMetricType(MetricType.IP)
+                .withMetricType(MetricType.HAMMING)
                 .withTopK(5)
                 .withTopK(5)
                 .withVectors(binVectors)
                 .withVectors(binVectors)
                 .withExpr("dummy")
                 .withExpr("dummy")
                 .build()
                 .build()
         );
         );
+
+        // float vector metric type is illegal
+        List<List<Float>> vectors2 = Arrays.asList(vector2);
+        assertThrows(ParamException.class, () -> SearchParam.newBuilder()
+                .withCollectionName("collection1")
+                .withPartitionNames(partitions)
+                .withParams("dummy")
+                .withOutFields(outputFields)
+                .withVectorFieldName("field1")
+                .withMetricType(MetricType.JACCARD)
+                .withTopK(5)
+                .withVectors(vectors2)
+                .withExpr("dummy")
+                .build()
+        );
+
+        // binary vector metric type is illegal
+        List<ByteBuffer> binVectors2 = Arrays.asList(buf2);
+        assertThrows(ParamException.class, () -> SearchParam.newBuilder()
+                .withCollectionName("collection1")
+                .withPartitionNames(partitions)
+                .withParams("dummy")
+                .withOutFields(outputFields)
+                .withVectorFieldName("field1")
+                .withMetricType(MetricType.IP)
+                .withTopK(5)
+                .withVectors(binVectors2)
+                .withExpr("dummy")
+                .build()
+        );
+
+        // succeed float vector case
+        assertDoesNotThrow(() -> SearchParam.newBuilder()
+                .withCollectionName("collection1")
+                .withPartitionNames(partitions)
+                .withParams("dummy")
+                .withOutFields(outputFields)
+                .withVectorFieldName("field1")
+                .withMetricType(MetricType.L2)
+                .withTopK(5)
+                .withVectors(vectors2)
+                .withExpr("dummy")
+                .build()
+        );
+
+        // succeed binary vector case
+        assertDoesNotThrow(() -> SearchParam.newBuilder()
+                .withCollectionName("collection1")
+                .withPartitionNames(partitions)
+                .withParams("dummy")
+                .withOutFields(outputFields)
+                .withVectorFieldName("field1")
+                .withMetricType(MetricType.HAMMING)
+                .withTopK(5)
+                .withVectors(binVectors2)
+                .withExpr("dummy")
+                .build()
+        );
     }
     }
 
 
     @Test
     @Test