Browse Source

Fix a bug of index extra parameters (#1095)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 6 months ago
parent
commit
00ff0faa1f

+ 22 - 20
src/main/java/io/milvus/v2/service/index/IndexService.java

@@ -19,7 +19,9 @@
 
 package io.milvus.v2.service.index;
 
+import com.google.gson.JsonObject;
 import io.milvus.grpc.*;
+import io.milvus.param.Constant;
 import io.milvus.param.ParamUtils;
 import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.exception.ErrorCode;
@@ -32,6 +34,7 @@ import org.apache.commons.lang3.StringUtils;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 public class IndexService extends BaseService {
@@ -40,36 +43,35 @@ public class IndexService extends BaseService {
         for(IndexParam indexParam : request.getIndexParams()) {
             String title = String.format("CreateIndexRequest collectionName:%s, fieldName:%s",
                     request.getCollectionName(), indexParam.getFieldName());
-            CreateIndexRequest createIndexRequest = CreateIndexRequest.newBuilder()
-                    .setCollectionName(request.getCollectionName())
+            CreateIndexRequest.Builder builder = CreateIndexRequest.newBuilder();
+            builder.setCollectionName(request.getCollectionName())
                     .setIndexName(indexParam.getIndexName())
                     .setFieldName(indexParam.getFieldName())
                     .addExtraParams(KeyValuePair.newBuilder()
-                            .setKey("index_type")
+                            .setKey(Constant.INDEX_TYPE)
                             .setValue(indexParam.getIndexType().getName())
-                            .build())
-                    .build();
+                            .build());
             if(indexParam.getMetricType()!= null){
                 // only vector field has a metric type
-                createIndexRequest = createIndexRequest.toBuilder()
-                        .addExtraParams(KeyValuePair.newBuilder()
-                                .setKey("metric_type")
-                                .setValue(indexParam.getMetricType().name())
-                                .build())
-                        .build();
+                builder.addExtraParams(KeyValuePair.newBuilder()
+                        .setKey(Constant.METRIC_TYPE)
+                        .setValue(indexParam.getMetricType().name())
+                        .build());
             }
-            if (indexParam.getExtraParams() != null) {
-                for (String key : indexParam.getExtraParams().keySet()) {
-                    createIndexRequest = createIndexRequest.toBuilder()
-                            .addExtraParams(KeyValuePair.newBuilder()
-                                    .setKey(key)
-                                    .setValue(String.valueOf(indexParam.getExtraParams().get(key)))
-                                    .build())
-                            .build();
+            Map<String, Object> extraParams = indexParam.getExtraParams();
+            if (extraParams != null && !extraParams.isEmpty()) {
+                JsonObject params = new JsonObject();
+                for (String key : extraParams.keySet()) {
+                    params.addProperty(key, extraParams.get(key).toString());
                 }
+                // the extra params is a JSON format string like "{\"M\": 8, \"efConstruction\": 64}"
+                builder.addExtraParams(KeyValuePair.newBuilder()
+                        .setKey(Constant.PARAMS)
+                        .setValue(params.toString())
+                        .build());
             }
 
-            Status status = blockingStub.createIndex(createIndexRequest);
+            Status status = blockingStub.createIndex(builder.build());
             rpcUtils.handleResponse(title, status);
         }
 

+ 2 - 2
src/main/java/io/milvus/v2/service/index/response/DescribeIndexResp.java

@@ -19,8 +19,6 @@
 
 package io.milvus.v2.service.index.response;
 
-import io.milvus.grpc.IndexDescription;
-import io.milvus.response.DescIndexResponseWrapper;
 import io.milvus.v2.common.IndexBuildState;
 import io.milvus.v2.common.IndexParam;
 import lombok.Builder;
@@ -79,5 +77,7 @@ public class DescribeIndexResp {
         private IndexBuildState indexState = IndexBuildState.IndexStateNone;
         @Builder.Default
         String indexFailedReason = "";
+        @Builder.Default
+        private Map<String, String> properties = new HashMap<>();
     }
 }

+ 8 - 2
src/main/java/io/milvus/v2/utils/ConvertUtils.java

@@ -19,6 +19,8 @@
 
 package io.milvus.v2.utils;
 
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
 import io.milvus.grpc.*;
 import io.milvus.param.Constant;
 import io.milvus.param.ParamUtils;
@@ -86,6 +88,7 @@ public class ConvertUtils {
             List<KeyValuePair> params = description.getParamsList();
             IndexParam.IndexType indexType = IndexParam.IndexType.None;
             IndexParam.MetricType metricType = IndexParam.MetricType.INVALID;
+            Map<String, String> properties = new HashMap<>();
             for(KeyValuePair param : params) {
                 if (param.getKey().equals(Constant.INDEX_TYPE)) {
                     // may throw IllegalArgumentException
@@ -93,8 +96,10 @@ public class ConvertUtils {
                 } else if (param.getKey().equals(Constant.METRIC_TYPE)) {
                     // may throw IllegalArgumentException
                     metricType = IndexParam.MetricType.valueOf(param.getValue());
-                } else {
-                    extraParams.put(param.getKey(), param.getValue());
+                } else if (param.getKey().equals(Constant.MMAP_ENABLED)) {
+                    properties.put(param.getKey(), param.getValue());
+                } else if (param.getKey().equals(Constant.PARAMS)) {
+                    extraParams = new Gson().fromJson(param.getValue(), new TypeToken<Map<String, String>>() {}.getType());
                 }
             }
 
@@ -110,6 +115,7 @@ public class ConvertUtils {
                     .indexState(IndexBuildState.valueOf(description.getState().name()))
                     .indexFailedReason(description.getIndexStateFailReason())
                     .extraParams(extraParams)
+                    .properties(properties)
                     .build();
             descs.add(desc);
         }

+ 41 - 9
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1199,9 +1199,34 @@ class MilvusClientV2DockerTest {
     @Test
     void testIndex() {
         String randomCollectionName = generator.generate(10);
+
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("vector")
+                .dataType(DataType.FloatVector)
+                .dimension(dimension)
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        Map<String,Object> extra = new HashMap<>();
+        extra.put("M",8);
+        extra.put("efConstruction",64);
+        indexes.add(IndexParam.builder()
+                .fieldName("vector")
+                .indexType(IndexParam.IndexType.HNSW)
+                .metricType(IndexParam.MetricType.COSINE)
+                .extraParams(extra)
+                .build());
         client.createCollection(CreateCollectionReq.builder()
                 .collectionName(randomCollectionName)
-                .dimension(dimension)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexes)
                 .build());
 
         client.releaseCollection(ReleaseCollectionReq.builder()
@@ -1234,7 +1259,12 @@ class MilvusClientV2DockerTest {
         DescribeIndexResp.IndexDesc desc = descResp.getIndexDescByFieldName("vector");
         Assertions.assertEquals("vector", desc.getFieldName());
         Assertions.assertFalse(desc.getIndexName().isEmpty());
-        Assertions.assertEquals(IndexParam.IndexType.AUTOINDEX, desc.getIndexType());
+        Assertions.assertEquals(IndexParam.IndexType.HNSW, desc.getIndexType());
+        Map<String, String> extraParams = desc.getExtraParams();
+        Assertions.assertTrue(extraParams.containsKey("M"));
+        Assertions.assertEquals("8", extraParams.get("M"));
+        Assertions.assertTrue(extraParams.containsKey("efConstruction"));
+        Assertions.assertEquals("64", extraParams.get("efConstruction"));
 
         properties.clear();
         properties.put(Constant.MMAP_ENABLED, "false");
@@ -1249,7 +1279,7 @@ class MilvusClientV2DockerTest {
                 .fieldName("vector")
                 .build());
         desc = descResp.getIndexDescByFieldName("vector");
-        Map<String, String> indexProps = desc.getExtraParams();
+        Map<String, String> indexProps = desc.getProperties();
         Assertions.assertTrue(indexProps.containsKey(Constant.MMAP_ENABLED));
         Assertions.assertEquals("false", indexProps.get(Constant.MMAP_ENABLED));
 
@@ -1261,9 +1291,9 @@ class MilvusClientV2DockerTest {
         IndexParam param = IndexParam.builder()
                 .fieldName("vector")
                 .indexName("XXX")
-                .indexType(IndexParam.IndexType.IVF_FLAT)
+                .indexType(IndexParam.IndexType.HNSW)
                 .metricType(IndexParam.MetricType.COSINE)
-                .extraParams(new HashMap<String,Object>(){{put("nlist", 64);}})
+                .extraParams(extra)
                 .build();
 
         client.createIndex(CreateIndexReq.builder()
@@ -1279,11 +1309,13 @@ class MilvusClientV2DockerTest {
         desc = descResp.getIndexDescByFieldName("vector");
         Assertions.assertEquals("vector", desc.getFieldName());
         Assertions.assertEquals("XXX", desc.getIndexName());
-        Assertions.assertEquals(IndexParam.IndexType.IVF_FLAT, desc.getIndexType());
+        Assertions.assertEquals(IndexParam.IndexType.HNSW, desc.getIndexType());
         Assertions.assertEquals(IndexParam.MetricType.COSINE, desc.getMetricType());
-        Map<String, String> extraParams = desc.getExtraParams();
-        Assertions.assertTrue(extraParams.containsKey("nlist"));
-        Assertions.assertEquals("64", extraParams.get("nlist"));
+        extraParams = desc.getExtraParams();
+        Assertions.assertTrue(extraParams.containsKey("M"));
+        Assertions.assertEquals("8", extraParams.get("M"));
+        Assertions.assertTrue(extraParams.containsKey("efConstruction"));
+        Assertions.assertEquals("64", extraParams.get("efConstruction"));
     }
 
     @Test