Browse Source

Support partition key (#492)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 2 years ago
parent
commit
512552aac4

+ 6 - 1
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -391,6 +391,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
                         .setFieldID(fieldID)
                         .setName(fieldType.getName())
                         .setIsPrimaryKey(fieldType.isPrimaryKey())
+                        .setIsPartitionKey(fieldType.isPartitionKey())
                         .setDescription(fieldType.getDescription())
                         .setDataType(fieldType.getDataType())
                         .setAutoID(fieldType.isAutoID());
@@ -406,7 +407,11 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
             }
 
             // Construct CreateCollectionRequest
-            CreateCollectionRequest createCollectionRequest = CreateCollectionRequest.newBuilder()
+            CreateCollectionRequest.Builder createCollectionBuilder = CreateCollectionRequest.newBuilder();
+            if (requestParam.getPartitionsNum() > 0) {
+                createCollectionBuilder.setNumPartitions(requestParam.getPartitionsNum());
+            }
+            CreateCollectionRequest createCollectionRequest = createCollectionBuilder
                     .setCollectionName(requestParam.getCollectionName())
                     .setShardsNum(requestParam.getShardsNum())
                     .setSchema(collectionSchemaBuilder.build().toByteString())

+ 31 - 0
src/main/java/io/milvus/param/collection/CreateCollectionParam.java

@@ -36,12 +36,14 @@ public class CreateCollectionParam {
     private final int shardsNum;
     private final String description;
     private final List<FieldType> fieldTypes;
+    private final int partitionsNum;
 
     private CreateCollectionParam(@NonNull Builder builder) {
         this.collectionName = builder.collectionName;
         this.shardsNum = builder.shardsNum;
         this.description = builder.description;
         this.fieldTypes = builder.fieldTypes;
+        this.partitionsNum = builder.partitionsNum;
     }
 
     public static Builder newBuilder() {
@@ -56,6 +58,7 @@ public class CreateCollectionParam {
         private int shardsNum = 2;
         private String description = "";
         private final List<FieldType> fieldTypes = new ArrayList<>();
+        private int partitionsNum = 0;
 
         private Builder() {
         }
@@ -117,6 +120,20 @@ public class CreateCollectionParam {
             return this;
         }
 
+        /**
+         * Sets the partitions number if there is partition key field. The number must be greater than zero.
+         * The default value is 64(defined in server side). The upper limit is 4096(defined in server side).
+         * Not allow to set this value if none of field is partition key.
+         * Only one partition key field is allowed in a collection.
+         *
+         * @param partitionsNum partitions number
+         * @return <code>Builder</code>
+         */
+        public Builder withPartitionsNum(int partitionsNum) {
+            this.partitionsNum = partitionsNum;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link CreateCollectionParam} instance.
          *
@@ -133,10 +150,24 @@ public class CreateCollectionParam {
                 throw new ParamException("Field numbers must be larger than 0");
             }
 
+            boolean hasPartitionKey = false;
             for (FieldType fieldType : fieldTypes) {
                 if (fieldType == null) {
                     throw new ParamException("Collection field cannot be null");
                 }
+
+                if (fieldType.isPartitionKey()) {
+                    if (hasPartitionKey) {
+                        throw new ParamException("Only one partition key field is allowed in a collection");
+                    }
+                    hasPartitionKey = true;
+                }
+            }
+
+            if (partitionsNum > 0) {
+                if (!hasPartitionKey) {
+                    throw new ParamException("None of fields is partition key, not allow to define partition number");
+                }
             }
 
             return new CreateCollectionParam(this);

+ 30 - 1
src/main/java/io/milvus/param/collection/FieldType.java

@@ -41,6 +41,7 @@ public class FieldType {
     private final DataType dataType;
     private final Map<String,String> typeParams;
     private final boolean autoID;
+    private final boolean partitionKey;
 
     private FieldType(@NonNull Builder builder){
         this.name = builder.name;
@@ -49,6 +50,7 @@ public class FieldType {
         this.dataType = builder.dataType;
         this.typeParams = builder.typeParams;
         this.autoID = builder.autoID;
+        this.partitionKey = builder.partitionKey;
     }
 
     public int getDimension() {
@@ -81,6 +83,7 @@ public class FieldType {
         private DataType dataType;
         private final Map<String,String> typeParams = new HashMap<>();
         private boolean autoID = false;
+        private boolean partitionKey = false;
 
         private Builder() {
         }
@@ -184,6 +187,20 @@ public class FieldType {
             return this;
         }
 
+        /**
+         * Sets the field to be partition key.
+         * A partition key field's values are hashed and distributed to different logic partitions.
+         * Only int64 and varchar type field can be partition key.
+         * Primary key field cannot be partition key.
+         *
+         * @param partitionKey true is partition key, false is not
+         * @return <code>Builder</code>
+         */
+        public Builder withPartitionKey(boolean partitionKey) {
+            this.partitionKey = partitionKey;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link FieldType} instance.
          *
@@ -197,7 +214,7 @@ public class FieldType {
             }
 
             if (dataType == DataType.String) {
-                throw new ParamException("String type is not supported, use VarChar instead");
+                throw new ParamException("String type is not supported, use Varchar instead");
             }
 
             if (dataType == DataType.FloatVector || dataType == DataType.BinaryVector) {
@@ -230,6 +247,17 @@ public class FieldType {
                 }
             }
 
+            // verify partition key
+            if (partitionKey) {
+                if (primaryKey) {
+                    throw new ParamException("Primary key field can not be partition key");
+                }
+                if (dataType != DataType.Int64 && dataType != DataType.VarChar) {
+                    throw new ParamException("Only Int64 and Varchar type field can be partition key");
+                }
+            }
+
+
             return new FieldType(this);
         }
     }
@@ -245,6 +273,7 @@ public class FieldType {
                 "name='" + name + '\'' +
                 ", type='" + dataType.name() + '\'' +
                 ", primaryKey=" + primaryKey +
+                ", partitionKey=" + partitionKey +
                 ", autoID=" + autoID +
                 ", params=" + typeParams +
                 '}';

+ 1 - 1
src/main/milvus-proto

@@ -1 +1 @@
-Subproject commit 5bbe6698c2b017a4762e0352121f4df7a15c565d
+Subproject commit 2975bfe7a190d38d1d44d9f1379483e93b1fbb4a

+ 80 - 10
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -274,7 +274,7 @@ class MilvusServiceClientTest {
 
     @Test
     void createCollectionParam() {
-        // test throw exception with illegal input
+        // test throw exception with illegal input for FieldType
         assertThrows(ParamException.class, () ->
                 FieldType.newBuilder()
                         .withName("")
@@ -295,6 +295,54 @@ class MilvusServiceClientTest {
                         .build()
         );
 
+        assertThrows(ParamException.class, () ->
+                FieldType.newBuilder()
+                        .withName("userID")
+                        .withDataType(DataType.Int64)
+                        .withPrimaryKey(true)
+                        .withPartitionKey(true)
+                        .build()
+        );
+
+        assertThrows(ParamException.class, () ->
+                FieldType.newBuilder()
+                        .withName("userID")
+                        .withDataType(DataType.FloatVector)
+                        .withPartitionKey(true)
+                        .build()
+        );
+
+        assertDoesNotThrow(() ->
+                FieldType.newBuilder()
+                        .withName("partitionKey")
+                        .withDataType(DataType.Int64)
+                        .withPartitionKey(true)
+                        .build()
+        );
+
+        assertDoesNotThrow(() ->
+                FieldType.newBuilder()
+                        .withName("partitionKey")
+                        .withDataType(DataType.VarChar)
+                        .withMaxLength(120)
+                        .withPartitionKey(true)
+                        .build()
+        );
+
+        Map<String, String> params = new HashMap<>();
+        params.put("1", "1");
+        assertThrows(ParamException.class, () ->
+                FieldType.newBuilder()
+                        .withName("vec")
+                        .withDescription("desc")
+                        .withDataType(DataType.FloatVector)
+                        .withTypeParams(params)
+                        .addTypeParam("2", "2")
+                        .withDimension(-1)
+                        .build()
+        );
+
+        // test throw exception with illegal input for CreateCollectionParam
         assertThrows(ParamException.class, () ->
                 CreateCollectionParam
                         .newBuilder()
@@ -339,16 +387,38 @@ class MilvusServiceClientTest {
                         .build()
         );
 
-        Map<String, String> params = new HashMap<>();
-        params.put("1", "1");
         assertThrows(ParamException.class, () ->
-                FieldType.newBuilder()
-                        .withName("vec")
-                        .withDescription("desc")
-                        .withDataType(DataType.FloatVector)
-                        .withTypeParams(params)
-                        .addTypeParam("2", "2")
-                        .withDimension(-1)
+                CreateCollectionParam
+                        .newBuilder()
+                        .withCollectionName("collection1")
+                        .withShardsNum(0)
+                        .withPartitionsNum(10)
+                        .addFieldType(fieldType1)
+                        .build()
+        );
+
+        FieldType fieldType2 = FieldType.newBuilder()
+                .withName("partitionKey")
+                .withDataType(DataType.Int64)
+                .withPartitionKey(true)
+                .build();
+
+        assertDoesNotThrow(() ->
+                CreateCollectionParam
+                        .newBuilder()
+                        .withCollectionName("collection1")
+                        .addFieldType(fieldType1)
+                        .addFieldType(fieldType2)
+                        .build()
+        );
+
+        assertDoesNotThrow(() ->
+                CreateCollectionParam
+                        .newBuilder()
+                        .withCollectionName("collection1")
+                        .withPartitionsNum(100)
+                        .addFieldType(fieldType1)
+                        .addFieldType(fieldType2)
                         .build()
         );
     }