فهرست منبع

Support string field (#294)

Signed-off-by: groot <yihua.mo@zilliz.com>
groot 3 سال پیش
والد
کامیت
60f955b567

+ 1 - 1
docker-compose.yml

@@ -31,7 +31,7 @@ services:
 
   standalone:
     container_name: milvus-javasdk-test-standalone
-    image: milvusdb/milvus:v2.0.0
+    image: milvusdb/milvus-dev:master-20220515-338edcd3
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcd:2379

+ 1 - 0
src/main/java/io/milvus/param/Constant.java

@@ -27,6 +27,7 @@ public class Constant {
     public static final String VECTOR_TAG = "$0";
     public static final String VECTOR_FIELD = "anns_field";
     public static final String VECTOR_DIM = "dim";
+    public static final String VARCHAR_MAX_LENGTH = "max_length_per_row";
     public static final String TOP_K = "topk";
     public static final String INDEX_TYPE = "index_type";
     public static final String METRIC_TYPE = "metric_type";

+ 1 - 0
src/main/java/io/milvus/param/ParamUtils.java

@@ -284,6 +284,7 @@ public class ParamUtils {
                     ScalarField scalarField5 = ScalarField.newBuilder().setDoubleData(doubleArray).build();
                     return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField5).build();
                 case String:
+                case VarChar:
                     List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
                     StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
                     ScalarField scalarField6 = ScalarField.newBuilder().setStringData(stringArray).build();

+ 34 - 0
src/main/java/io/milvus/param/collection/FieldType.java

@@ -59,6 +59,14 @@ public class FieldType {
         return 0;
     }
 
+    public int getMaxLength() { // max length configured for this field, or 0 when none is set
+        if (typeParams.containsKey(Constant.VARCHAR_MAX_LENGTH)) {
+            return Integer.parseInt(typeParams.get(Constant.VARCHAR_MAX_LENGTH)); // value is kept as a string in typeParams
+        }
+
+        return 0; // key absent: no max length was specified
+    }
+
     public static Builder newBuilder() {
         return new Builder();
     }
@@ -150,6 +158,17 @@ public class FieldType {
             return this;
         }
 
+        /**
+         * Sets the max length of a varchar field. The value must be greater than zero (not checked here; the builder rejects non-positive values before constructing the <code>FieldType</code>).
+         *
+         * @param maxLength max length of a varchar field, must not be null
+         * @return <code>Builder</code>
+         */
+        public Builder withMaxLength(@NonNull Integer maxLength) {
+            this.typeParams.put(Constant.VARCHAR_MAX_LENGTH, maxLength.toString()); // stored as a string under the "max_length_per_row" type parameter key
+            return this;
+        }
+
         /**
          * Enables auto-id function for the field. Note that the auto-id function can only be enabled on primary key field.
          * If auto-id function is enabled, Milvus will automatically generate unique ID for each entity,
@@ -192,6 +211,21 @@ public class FieldType {
                 }
             }
 
+            if (dataType == DataType.VarChar || dataType == DataType.String) {
+                if (!typeParams.containsKey(Constant.VARCHAR_MAX_LENGTH)) {
+                    throw new ParamException("Varchar field max length must be specified");
+                }
+
+                try {
+                    int len = Integer.parseInt(typeParams.get(Constant.VARCHAR_MAX_LENGTH));
+                    if (len <= 0) {
+                        throw new ParamException("Varchar field max length must be larger than zero");
+                    }
+                } catch (NumberFormatException e) {
+                    throw new ParamException("Varchar field max length must be an integer number");
+                }
+            }
+
             return new FieldType(this);
         }
     }

+ 1 - 1
src/main/java/io/milvus/param/dml/InsertParam.java

@@ -207,7 +207,7 @@ public class InsertParam {
                             throw new ParamException("Bool field value type must be Boolean");
                         }
                     }
-                } else if (field.getType() == DataType.String) {
+                } else if (field.getType() == DataType.String || field.getType() == DataType.VarChar) {
                     for (Object obj : values) {
                         if (!(obj instanceof String)) {
                             throw new ParamException("String field value type must be String");

+ 2 - 0
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -79,6 +79,7 @@ public class FieldDataWrapper {
                 return fieldData.getScalars().getFloatData().getDataList().size();
             case Double:
                 return fieldData.getScalars().getDoubleData().getDataList().size();
+            case VarChar:
             case String:
                 return fieldData.getScalars().getStringData().getDataList().size();
             default:
@@ -144,6 +145,7 @@ public class FieldDataWrapper {
                 return fieldData.getScalars().getFloatData().getDataList();
             case Double:
                 return fieldData.getScalars().getDoubleData().getDataList();
+            case VarChar:
             case String:
                 return fieldData.getScalars().getStringData().getDataList();
             default:

+ 1 - 1
src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -84,7 +84,7 @@ public class SearchResultsWrapper {
             }
         } else if (ids.hasStrId()) {
             StringArray strIDs = ids.getStrId();
-            if (offset + k >= strIDs.getDataCount()) {
+            if (offset + k > strIDs.getDataCount()) {
                 throw new IllegalResponseException("Result ids count is wrong");
             }
 

+ 168 - 0
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -764,4 +764,172 @@ class MilvusClientDockerTest {
                 .build());
         assertEquals(R.Status.Success.getCode(), deleteR.getStatus().intValue());*/
     }
+
+    @Test
+    void testStringField() { // end-to-end check of VarChar fields: create collection, insert, query, search
+        String randomCollectionName = generator.generate(10);
+
+        // collection schema: VarChar primary key, float vector, VarChar scalar, Int64 scalar
+        String field1Name = "str_id";
+        String field2Name = "vec_field";
+        String field3Name = "str_field";
+        String field4Name = "int_field";
+        List<FieldType> fieldsSchema = new ArrayList<>();
+        fieldsSchema.add(FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.VarChar)
+                .withName(field1Name)
+                .withMaxLength(32)
+                .withDescription("string identity")
+                .build());
+
+        fieldsSchema.add(FieldType.newBuilder()
+                .withDataType(DataType.FloatVector)
+                .withName(field2Name)
+                .withDescription("face")
+                .withDimension(dimension)
+                .build());
+
+        fieldsSchema.add(FieldType.newBuilder()
+                .withDataType(DataType.VarChar)
+                .withName(field3Name)
+                .withMaxLength(32)
+                .withDescription("comment")
+                .build());
+
+        fieldsSchema.add(FieldType.newBuilder()
+                .withDataType(DataType.Int64)
+                .withName(field4Name)
+                .withDescription("sequence")
+                .build());
+
+        // create collection
+        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withDescription("test")
+                .withFieldTypes(fieldsSchema)
+                .build();
+
+        R<RpcStatus> createR = client.createCollection(createParam);
+        assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        R<DescribeCollectionResponse> response = client.describeCollection(DescribeCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+
+        DescCollResponseWrapper desc = new DescCollResponseWrapper(response.getData());
+        System.out.println(desc.toString());
+
+        // insert data: generated ids and comments, plus sequence numbers 0..rowCount-1
+        int rowCount = 10000;
+        List<String> ids = new ArrayList<>();
+        List<String> comments = new ArrayList<>();
+        List<Long> sequences = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            ids.add(generator.generate(8));
+            comments.add(generator.generate(8));
+            sequences.add(i); // sequence value equals the row's index in every per-field list
+        }
+        List<List<Float>> vectors = generateFloatVectors(rowCount);
+
+        List<InsertParam.Field> fieldsInsert = new ArrayList<>();
+        fieldsInsert.add(new InsertParam.Field(field1Name, DataType.VarChar, ids));
+        fieldsInsert.add(new InsertParam.Field(field3Name, DataType.VarChar, comments));
+        fieldsInsert.add(new InsertParam.Field(field2Name, DataType.FloatVector, vectors));
+        fieldsInsert.add(new InsertParam.Field(field4Name, DataType.Int64, sequences));
+
+        InsertParam insertParam = InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFields(fieldsInsert)
+                .build();
+
+        R<MutationResult> insertR = client.withTimeout(10, TimeUnit.SECONDS).insert(insertParam);
+        assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
+
+        MutationResultWrapper insertResultWrapper = new MutationResultWrapper(insertR.getData());
+        System.out.println(insertResultWrapper.getInsertCount() + " rows inserted");
+
+        // get collection statistics (requested with the flush option enabled)
+        R<GetCollectionStatisticsResponse> statR = client.getCollectionStatistics(GetCollectionStatisticsParam
+                .newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFlush(true)
+                .build());
+        assertEquals(R.Status.Success.getCode(), statR.getStatus().intValue());
+
+        GetCollStatResponseWrapper stat = new GetCollStatResponseWrapper(statR.getData());
+        System.out.println("Collection row count: " + stat.getRowCount());
+
+        // load collection
+        R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+        assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
+
+        // query a random run of nq consecutive rows by int_field value to verify the string fields
+        List<Long> queryItems = new ArrayList<>();
+        List<String> queryIds = new ArrayList<>();
+        int nq = 5;
+        Random ran = new Random();
+        int randomIndex = ran.nextInt(rowCount - nq); // upper bound keeps randomIndex + nq within the inserted rows
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            queryIds.add(ids.get(i));
+            queryItems.add(sequences.get(i));
+        }
+        String expr = field4Name + " in " + queryItems.toString(); // e.g. "int_field in [7, 8, 9, 10, 11]"
+        List<String> outputFields = Arrays.asList(field1Name, field3Name);
+        QueryParam queryParam = QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr(expr)
+                .withOutFields(outputFields)
+                .build();
+
+        R<QueryResults> queryR = client.query(queryParam);
+        assertEquals(R.Status.Success.getCode(), queryR.getStatus().intValue());
+
+        // verify query result: each output field has nq rows, and every returned id is one of the queried ids
+        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(queryR.getData());
+        for (String fieldName : outputFields) {
+            FieldDataWrapper wrapper = queryResultsWrapper.getFieldWrapper(fieldName);
+            System.out.println("Query data of " + fieldName + ", row count: " + wrapper.getRowCount());
+            System.out.println(wrapper.getFieldData());
+            assertEquals(nq, wrapper.getFieldData().size());
+
+            if (fieldName.compareTo(field1Name) == 0) {
+                List<?> out = queryResultsWrapper.getFieldWrapper(field1Name).getFieldData();
+                assertEquals(nq, out.size());
+                for (Object o : out) {
+                    String id = (String) o;
+                    assertTrue(queryIds.contains(id)); // membership only; result order is not checked
+                }
+            }
+        }
+
+        // search with the vectors of the queried rows as targets
+        int topK = 5;
+        List<List<Float>> targetVectors = new ArrayList<>();
+        for (Long seq : queryItems) {
+            targetVectors.add(vectors.get(seq.intValue())); // sequence value doubles as the index into vectors
+        }
+        SearchParam searchParam = SearchParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withMetricType(MetricType.L2)
+                .withTopK(topK)
+                .withVectors(targetVectors)
+                .withVectorFieldName(field2Name)
+                .addOutField(field4Name)
+                .build();
+
+        R<SearchResults> searchR = client.search(searchParam);
+        assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+
+        // print the search result for each target vector (no assertions are made on the returned hits)
+        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        for (int i = 0; i < targetVectors.size(); ++i) {
+            List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
+            System.out.println("The result of No." + i + " target vector(ID = " + queryIds.get(i) + "):"); // queryIds and targetVectors share the same order
+            System.out.println(scores);
+        }
+    }
 }

+ 4 - 2
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -1479,7 +1479,7 @@ class MilvusServiceClientTest {
         fields.add(new InsertParam.Field("field3", DataType.Bool, bVal));
         fields.add(new InsertParam.Field("field4", DataType.Float, fVal));
         fields.add(new InsertParam.Field("field5", DataType.Double, dVal));
-        fields.add(new InsertParam.Field("field6", DataType.String, sVal));
+        fields.add(new InsertParam.Field("field6", DataType.VarChar, sVal));
         fields.add(new InsertParam.Field("field7", DataType.FloatVector, fVectors));
         fields.add(new InsertParam.Field("field8", DataType.BinaryVector, bVectors));
         InsertParam param = InsertParam.newBuilder()
@@ -1501,6 +1501,8 @@ class MilvusServiceClientTest {
                 builder.withDimension(16);
             } else if (field.getType() == DataType.FloatVector) {
                 builder.withDimension(2);
+            } else if (field.getType() == DataType.VarChar) {
+                builder.withMaxLength(20);
             }
 
             colBuilder.addFields(ParamUtils.ConvertField(builder.build()));
@@ -2449,7 +2451,7 @@ class MilvusServiceClientTest {
             strBuilder.addData(String.valueOf(i));
         }
         testScalarField(ScalarField.newBuilder().setStringData(strBuilder).build(),
-                DataType.String, dim);
+                DataType.VarChar, dim);
     }
 
     @Test