Browse Source

search interface supprt not specify metricType (#777)

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>
xushuang.hu 1 year ago
parent
commit
876435badb

+ 1 - 1
src/main/java/io/milvus/param/MetricType.java

@@ -25,7 +25,7 @@ package io.milvus.param;
  */
  */
 public enum MetricType {
 public enum MetricType {
     None,
     None,
-    INVALID,
+
     // Only for float vectors
     // Only for float vectors
     L2,
     L2,
     IP,
     IP,

+ 8 - 23
src/main/java/io/milvus/param/ParamUtils.java

@@ -197,24 +197,6 @@ public class ParamUtils {
         }
         }
     }
     }
 
 
-    /**
-     * Checks if a metric is for float vector.
-     *
-     * @param metric metric type
-     */
-    public static boolean IsFloatMetric(MetricType metric) {
-        return metric == MetricType.L2 || metric == MetricType.IP || metric == MetricType.COSINE;
-    }
-
-    /**
-     * Checks if a metric is for binary vector.
-     *
-     * @param metric metric type
-     */
-    public static boolean IsBinaryMetric(MetricType metric) {
-        return metric != MetricType.INVALID && !IsFloatMetric(metric);
-    }
-
     public static class InsertBuilderWrapper {
     public static class InsertBuilderWrapper {
         private InsertRequest.Builder insertBuilder;
         private InsertRequest.Builder insertBuilder;
         private UpsertRequest.Builder upsertBuilder;
         private UpsertRequest.Builder upsertBuilder;
@@ -483,11 +465,6 @@ public class ParamUtils {
                                 .setKey(Constant.TOP_K)
                                 .setKey(Constant.TOP_K)
                                 .setValue(String.valueOf(requestParam.getTopK()))
                                 .setValue(String.valueOf(requestParam.getTopK()))
                                 .build())
                                 .build())
-                .addSearchParams(
-                        KeyValuePair.newBuilder()
-                                .setKey(Constant.METRIC_TYPE)
-                                .setValue(requestParam.getMetricType())
-                                .build())
                 .addSearchParams(
                 .addSearchParams(
                         KeyValuePair.newBuilder()
                         KeyValuePair.newBuilder()
                                 .setKey(Constant.ROUND_DECIMAL)
                                 .setKey(Constant.ROUND_DECIMAL)
@@ -499,6 +476,14 @@ public class ParamUtils {
                                 .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
                                 .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
                                 .build());
                                 .build());
 
 
+        if (!Objects.equals(requestParam.getMetricType(), MetricType.None.name())) {
+            builder.addSearchParams(
+                    KeyValuePair.newBuilder()
+                            .setKey(Constant.METRIC_TYPE)
+                            .setValue(requestParam.getMetricType())
+                            .build());
+        }
+
         if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) {
         if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) {
             try {
             try {
             Map<String, Object> paramMap = JacksonUtils.fromJson(requestParam.getParams(),Map.class);
             Map<String, Object> paramMap = JacksonUtils.fromJson(requestParam.getParams(),Map.class);

+ 0 - 4
src/main/java/io/milvus/param/QueryNodeSingleSearch.java

@@ -131,10 +131,6 @@ public class QueryNodeSingleSearch {
             ParamUtils.CheckNullEmptyString(collectionName, "Collection name");
             ParamUtils.CheckNullEmptyString(collectionName, "Collection name");
             ParamUtils.CheckNullEmptyString(vectorFieldName, "Target field name");
             ParamUtils.CheckNullEmptyString(vectorFieldName, "Target field name");
 
 
-            if (metricType == MetricType.INVALID) {
-                throw new ParamException("Metric type is illegal");
-            }
-
             if (vectors == null || vectors.isEmpty()) {
             if (vectors == null || vectors.isEmpty()) {
                 throw new ParamException("Target vectors can not be empty");
                 throw new ParamException("Target vectors can not be empty");
             }
             }

+ 1 - 15
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -82,7 +82,7 @@ public class SearchParam {
     public static class Builder {
     public static class Builder {
         private String collectionName;
         private String collectionName;
         private final List<String> partitionNames = Lists.newArrayList();
         private final List<String> partitionNames = Lists.newArrayList();
-        private MetricType metricType = MetricType.L2;
+        private MetricType metricType = MetricType.None;
         private String vectorFieldName;
         private String vectorFieldName;
         private Integer topK;
         private Integer topK;
         private String expr = "";
         private String expr = "";
@@ -287,10 +287,6 @@ public class SearchParam {
                 throw new ParamException("The guarantee timestamp must be greater than 0");
                 throw new ParamException("The guarantee timestamp must be greater than 0");
             }
             }
 
 
-            if (metricType == MetricType.INVALID) {
-                throw new ParamException("Metric type is invalid");
-            }
-
             if (vectors == null || vectors.isEmpty()) {
             if (vectors == null || vectors.isEmpty()) {
                 throw new ParamException("Target vectors can not be empty");
                 throw new ParamException("Target vectors can not be empty");
             }
             }
@@ -309,11 +305,6 @@ public class SearchParam {
                         throw new ParamException("Target vector dimension must be equal");
                         throw new ParamException("Target vector dimension must be equal");
                     }
                     }
                 }
                 }
-
-                // check metric type
-                if (!ParamUtils.IsFloatMetric(metricType)) {
-                    throw new ParamException("Target vector is float but metric type is incorrect");
-                }
             } else if (vectors.get(0) instanceof ByteBuffer) {
             } else if (vectors.get(0) instanceof ByteBuffer) {
                 // binary vectors
                 // binary vectors
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
@@ -324,11 +315,6 @@ public class SearchParam {
                         throw new ParamException("Target vector dimension must be equal");
                         throw new ParamException("Target vector dimension must be equal");
                     }
                     }
                 }
                 }
-
-                // check metric type
-                if (!ParamUtils.IsBinaryMetric(metricType)) {
-                    throw new ParamException("Target vector is binary but metric type is incorrect");
-                }
             } else {
             } else {
                 throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
                 throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
             }
             }

+ 1 - 1
src/main/java/io/milvus/response/DescIndexResponseWrapper.java

@@ -93,7 +93,7 @@ public class DescIndexResponseWrapper {
                 return MetricType.valueOf(params.get(Constant.METRIC_TYPE));
                 return MetricType.valueOf(params.get(Constant.METRIC_TYPE));
             }
             }
 
 
-            return MetricType.INVALID;
+            return MetricType.None;
         }
         }
 
 
         public String getExtraParam() {
         public String getExtraParam() {

+ 3 - 45
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -1973,20 +1973,6 @@ class MilvusServiceClientTest {
                 .build()
                 .build()
         );
         );
 
 
-        // metric type is invalid
-        assertThrows(ParamException.class, () -> SearchParam.newBuilder()
-                .withCollectionName("collection1")
-                .withPartitionNames(partitions)
-                .withParams("{}")
-                .withOutFields(outputFields)
-                .withVectorFieldName("field1")
-                .withMetricType(MetricType.INVALID)
-                .withTopK(5)
-                .withVectors(vectors)
-                .withExpr("dummy")
-                .build()
-        );
-
         // illegal topk value
         // illegal topk value
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
         assertThrows(ParamException.class, () -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
@@ -2066,38 +2052,9 @@ class MilvusServiceClientTest {
                 .withExpr("dummy")
                 .withExpr("dummy")
                 .build()
                 .build()
         );
         );
-
-        // float vector metric type is illegal
-        List<List<Float>> vectors2 = Collections.singletonList(vector2);
-        assertThrows(ParamException.class, () -> SearchParam.newBuilder()
-                .withCollectionName("collection1")
-                .withPartitionNames(partitions)
-                .withParams("{}")
-                .withOutFields(outputFields)
-                .withVectorFieldName("field1")
-                .withMetricType(MetricType.JACCARD)
-                .withTopK(5)
-                .withVectors(vectors2)
-                .withExpr("dummy")
-                .build()
-        );
-
-        // binary vector metric type is illegal
-        List<ByteBuffer> binVectors2 = Collections.singletonList(buf2);
-        assertThrows(ParamException.class, () -> SearchParam.newBuilder()
-                .withCollectionName("collection1")
-                .withPartitionNames(partitions)
-                .withParams("{}")
-                .withOutFields(outputFields)
-                .withVectorFieldName("field1")
-                .withMetricType(MetricType.IP)
-                .withTopK(5)
-                .withVectors(binVectors2)
-                .withExpr("dummy")
-                .build()
-        );
-
+        
         // succeed float vector case
         // succeed float vector case
+        List<List<Float>> vectors2 = Collections.singletonList(vector2);
         assertDoesNotThrow(() -> SearchParam.newBuilder()
         assertDoesNotThrow(() -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)
@@ -2113,6 +2070,7 @@ class MilvusServiceClientTest {
         );
         );
 
 
         // succeed binary vector case
         // succeed binary vector case
+        List<ByteBuffer> binVectors2 = Collections.singletonList(buf2);
         assertDoesNotThrow(() -> SearchParam.newBuilder()
         assertDoesNotThrow(() -> SearchParam.newBuilder()
                 .withCollectionName("collection1")
                 .withCollectionName("collection1")
                 .withPartitionNames(partitions)
                 .withPartitionNames(partitions)