Browse Source

Verify index parameter (#542)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 2 years ago
parent
commit
aa669f5443

+ 39 - 0
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -1261,6 +1261,45 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
         logInfo(requestParam.toString());
 
         try {
+            // get collection schema to check input
+            DescribeCollectionParam.Builder descBuilder = DescribeCollectionParam.newBuilder()
+                    .withDatabaseName(requestParam.getDatabaseName())
+                    .withCollectionName(requestParam.getCollectionName());
+            R<DescribeCollectionResponse> descResp = describeCollection(descBuilder.build());
+
+            if (descResp.getStatus() != R.Status.Success.getCode()) {
+                logError("Failed to describe collection: {}", requestParam.getCollectionName());
+                return R.failed(R.Status.valueOf(descResp.getStatus()), descResp.getMessage());
+            }
+
+            DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
+            List<FieldType> fields = wrapper.getFields();
+            // check field existence and index_type/field_type must be matched
+            boolean fieldExists = false;
+            boolean validType = false;
+            for (FieldType field : fields) {
+                if (requestParam.getFieldName().equals(field.getName())) {
+                    fieldExists = true;
+                    if (ParamUtils.VerifyIndexType(requestParam.getIndexType(), field.getDataType())) {
+                        validType = true;
+                    }
+                    break;
+                }
+            }
+
+            if (!fieldExists) {
+                String msg = String.format("Field '%s' doesn't exist in the collection", requestParam.getFieldName());
+                logError("CreateIndexRequest failed! {}\n", msg);
+                return R.failed(R.Status.IllegalArgument, msg);
+            }
+            if (!validType) {
+                String msg = String.format("Index type '%s' doesn't match with data type of field '%s'",
+                        requestParam.getIndexType().name(), requestParam.getFieldName());
+                logError("CreateIndexRequest failed! {}\n", msg);
+                return R.failed(R.Status.IllegalArgument, msg);
+            }
+
+            // prepare index parameters
             CreateIndexRequest.Builder createIndexRequestBuilder = CreateIndexRequest.newBuilder();
             List<KeyValuePair> extraParamList = ParamUtils.AssembleKvPair(requestParam.getExtraParam());
             if (CollectionUtils.isNotEmpty(extraParamList)) {

+ 3 - 0
src/main/java/io/milvus/param/IndexType.java

@@ -25,6 +25,7 @@ package io.milvus.param;
  */
 public enum IndexType {
     INVALID,
+    //Only supported for float vectors
     FLAT,
     IVF_FLAT,
     IVF_SQ8,
@@ -42,5 +43,7 @@ public enum IndexType {
 
     //Only for varchar type field
     TRIE,
+    //Only for scalar type field
+    SORT,
     ;
 }

+ 21 - 2
src/main/java/io/milvus/param/ParamUtils.java

@@ -15,6 +15,7 @@ import io.milvus.param.dml.QueryParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.response.DescCollResponseWrapper;
 import lombok.Builder;
+import lombok.Data;
 import lombok.Getter;
 import lombok.NonNull;
 import org.apache.commons.collections4.CollectionUtils;
@@ -204,12 +205,30 @@ public class ParamUtils {
     }
 
     /**
-     * Checks if an index type is for vector.
+     * Checks if an index type is for vector field.
      *
      * @param idx index type
      */
     public static boolean IsVectorIndex(IndexType idx) {
-        return idx != IndexType.INVALID && idx != IndexType.TRIE;
+        return idx != IndexType.INVALID && idx != IndexType.TRIE && idx != IndexType.SORT;
+    }
+
+    /**
+     * Checks if an index type is matched with data type.
+     *
+     * @param indexType index type
+     * @param dataType data type
+     */
+    public static boolean VerifyIndexType(IndexType indexType, DataType dataType) {
+        if (dataType == DataType.FloatVector) {
+            return (IsVectorIndex(indexType) && (indexType != IndexType.BIN_FLAT) && (indexType != IndexType.BIN_IVF_FLAT));
+        } else if (dataType == DataType.BinaryVector) {
+            return indexType == IndexType.BIN_FLAT || indexType == IndexType.BIN_IVF_FLAT;
+        } else if (dataType == DataType.VarChar) {
+            return indexType == IndexType.TRIE;
+        } else {
+            return indexType == IndexType.SORT;
+        }
     }
 
     public static InsertRequest convertInsertParam(@NonNull InsertParam requestParam,

+ 2 - 0
src/main/java/io/milvus/param/index/CreateIndexParam.java

@@ -43,6 +43,7 @@ public class CreateIndexParam {
     private final String collectionName;
     private final String fieldName;
     private final String indexName;
+    private final IndexType indexType; // for easily get to check with field type
     private final Map<String, String> extraParam = new HashMap<>();
     private final boolean syncMode;
     private final long syncWaitingInterval;
@@ -53,6 +54,7 @@ public class CreateIndexParam {
         this.collectionName = builder.collectionName;
         this.fieldName = builder.fieldName;
         this.indexName = builder.indexName;
+        this.indexType = builder.indexType;
         if (builder.indexType != IndexType.INVALID) {
             this.extraParam.put(Constant.INDEX_TYPE, builder.indexType.name());
         }

+ 42 - 2
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -1325,11 +1325,51 @@ class MilvusServiceClientTest {
         MockMilvusServer server = startServer();
         MilvusServiceClient client = startClient();
 
+        // createIndex() calls describeCollection() to check input
+        CollectionSchema schema = CollectionSchema.newBuilder()
+                .addFields(FieldSchema.newBuilder()
+                        .setName("field1")
+                        .setDataType(DataType.FloatVector)
+                        .addTypeParams(KeyValuePair.newBuilder().setKey(Constant.VECTOR_DIM).setValue("256").build())
+                        .build())
+                .build();
+        mockServerImpl.setDescribeCollectionResponse(DescribeCollectionResponse.newBuilder().setSchema(schema).build());
+
         // test return ok for sync mode loading
         mockServerImpl.setDescribeIndexResponse(DescribeIndexResponse.newBuilder()
                 .addIndexDescriptions(IndexDescription.newBuilder().setState(IndexState.InProgress).build())
                 .build());
 
+        // field doesn't exist
+        CreateIndexParam param = CreateIndexParam.newBuilder()
+                .withCollectionName("collection1")
+                .withFieldName("aaa")
+                .withIndexType(IndexType.IVF_FLAT)
+                .withMetricType(MetricType.L2)
+                .withExtraParam("dummy")
+                .withSyncMode(Boolean.TRUE)
+                .withSyncWaitingInterval(500L)
+                .withSyncWaitingTimeout(2L)
+                .build();
+
+        R<RpcStatus> resp = client.createIndex(param);
+        assertNotEquals(R.Status.Success.getCode(), resp.getStatus());
+
+        // index type doesn't match with data type
+        param = CreateIndexParam.newBuilder()
+                .withCollectionName("collection1")
+                .withFieldName("field1")
+                .withIndexType(IndexType.BIN_IVF_FLAT)
+                .withMetricType(MetricType.L2)
+                .withExtraParam("dummy")
+                .withSyncMode(Boolean.TRUE)
+                .withSyncWaitingInterval(500L)
+                .withSyncWaitingTimeout(2L)
+                .build();
+
+        resp = client.createIndex(param);
+        assertNotEquals(R.Status.Success.getCode(), resp.getStatus());
+
         new Thread(() -> {
             try {
                 TimeUnit.SECONDS.sleep(1);
@@ -1343,7 +1383,7 @@ class MilvusServiceClientTest {
             }
         }, "RefreshIndexState").start();
 
-        CreateIndexParam param = CreateIndexParam.newBuilder()
+        param = CreateIndexParam.newBuilder()
                 .withCollectionName("collection1")
                 .withFieldName("field1")
                 .withIndexType(IndexType.IVF_FLAT)
@@ -1355,7 +1395,7 @@ class MilvusServiceClientTest {
                 .build();
 
         // test return ok with correct input
-        R<RpcStatus> resp = client.createIndex(param);
+        resp = client.createIndex(param);
         assertEquals(R.Status.Success.getCode(), resp.getStatus());
 
         // stop mock server