Browse Source

Add COSINE metric type (#616)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 year ago
parent
commit
70e108a7d5

+ 2 - 2
docker-compose.yml

@@ -32,7 +32,7 @@ services:
 
 
   standalone:
   standalone:
     container_name: milvus-javasdk-test-standalone
     container_name: milvus-javasdk-test-standalone
-    image: milvusdb/milvus:master-20230815-ec65a4e0
+    image: milvusdb/milvus:v2.3.0
     command: ["milvus", "run", "standalone"]
     command: ["milvus", "run", "standalone"]
     environment:
     environment:
       ETCD_ENDPOINTS: etcd:2379
       ETCD_ENDPOINTS: etcd:2379
@@ -77,7 +77,7 @@ services:
 
 
   standaloneslave:
   standaloneslave:
     container_name: milvus-javasdk-test-slave-standalone
     container_name: milvus-javasdk-test-slave-standalone
-    image: milvusdb/milvus:master-20230815-ec65a4e0
+    image: milvusdb/milvus:v2.3.0
     command: ["milvus", "run", "standalone"]
     command: ["milvus", "run", "standalone"]
     environment:
     environment:
       ETCD_ENDPOINTS: etcdslave:2379
       ETCD_ENDPOINTS: etcdslave:2379

+ 4 - 1
src/main/java/io/milvus/param/MetricType.java

@@ -25,9 +25,12 @@ package io.milvus.param;
  */
  */
 public enum MetricType {
 public enum MetricType {
     INVALID,
     INVALID,
+    // Only for float vectors
     L2,
     L2,
     IP,
     IP,
-    // Only supported for binary vectors
+    COSINE,
+
+    // Only for binary vectors
     HAMMING,
     HAMMING,
     JACCARD,
     JACCARD,
     ;
     ;

+ 2 - 2
src/main/java/io/milvus/param/ParamUtils.java

@@ -188,7 +188,7 @@ public class ParamUtils {
      * @param metric metric type
      * @param metric metric type
      */
      */
     public static boolean IsFloatMetric(MetricType metric) {
     public static boolean IsFloatMetric(MetricType metric) {
-        return metric == MetricType.L2 || metric == MetricType.IP;
+        return metric == MetricType.L2 || metric == MetricType.IP || metric == MetricType.COSINE;
     }
     }
 
 
     /**
     /**
@@ -197,7 +197,7 @@ public class ParamUtils {
      * @param metric metric type
      * @param metric metric type
      */
      */
     public static boolean IsBinaryMetric(MetricType metric) {
     public static boolean IsBinaryMetric(MetricType metric) {
-        return !IsFloatMetric(metric);
+        return metric != MetricType.INVALID && !IsFloatMetric(metric);
     }
     }
 
 
     /**
     /**

+ 8 - 4
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -1366,7 +1366,7 @@ class MilvusClientDockerTest {
                 .withFieldName(field2Name)
                 .withFieldName(field2Name)
                 .withIndexName("abv")
                 .withIndexName("abv")
                 .withIndexType(IndexType.FLAT)
                 .withIndexType(IndexType.FLAT)
-                .withMetricType(MetricType.L2)
+                .withMetricType(MetricType.COSINE)
                 .withExtraParam("{}")
                 .withExtraParam("{}")
                 .build();
                 .build();
 
 
@@ -1466,12 +1466,14 @@ class MilvusClientDockerTest {
             System.out.println("'extra_meta' is from dynamic field, value: " + extraMeta);
             System.out.println("'extra_meta' is from dynamic field, value: " + extraMeta);
         }
         }
 
 
-        // search
-        List<List<Float>> targetVectors = generateFloatVectors(2);
+        // search the No.11 and No.15
+        List<List<Float>> targetVectors = new ArrayList<>();
+        targetVectors.add(vectors.get(1));
+        targetVectors.add(vectors.get(5));
         int topK = 5;
         int topK = 5;
         SearchParam searchParam = SearchParam.newBuilder()
         SearchParam searchParam = SearchParam.newBuilder()
                 .withCollectionName(randomCollectionName)
                 .withCollectionName(randomCollectionName)
-                .withMetricType(MetricType.L2)
+                .withMetricType(MetricType.COSINE)
                 .withTopK(topK)
                 .withTopK(topK)
                 .withVectors(targetVectors)
                 .withVectors(targetVectors)
                 .withVectorFieldName(field2Name)
                 .withVectorFieldName(field2Name)
@@ -1495,6 +1497,8 @@ class MilvusClientDockerTest {
                 }
                 }
             }
             }
         }
         }
+        Assertions.assertEquals(results.getIDScore(0).get(0).getLongID(), 11L);
+        Assertions.assertEquals(results.getIDScore(1).get(0).getLongID(), 15L);
 
 
         // drop collection
         // drop collection
         R<RpcStatus> dropR = client.dropCollection(DropCollectionParam.newBuilder()
         R<RpcStatus> dropR = client.dropCollection(DropCollectionParam.newBuilder()