|
@@ -764,4 +764,172 @@ class MilvusClientDockerTest {
|
|
|
.build());
|
|
|
assertEquals(R.Status.Success.getCode(), deleteR.getStatus().intValue());*/
|
|
|
}
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testStringField() {
|
|
|
+ String randomCollectionName = generator.generate(10);
|
|
|
+
|
|
|
+ // collection schema
|
|
|
+ String field1Name = "str_id";
|
|
|
+ String field2Name = "vec_field";
|
|
|
+ String field3Name = "str_field";
|
|
|
+ String field4Name = "int_field";
|
|
|
+ List<FieldType> fieldsSchema = new ArrayList<>();
|
|
|
+ fieldsSchema.add(FieldType.newBuilder()
|
|
|
+ .withPrimaryKey(true)
|
|
|
+ .withAutoID(false)
|
|
|
+ .withDataType(DataType.VarChar)
|
|
|
+ .withName(field1Name)
|
|
|
+ .withMaxLength(32)
|
|
|
+ .withDescription("string identity")
|
|
|
+ .build());
|
|
|
+
|
|
|
+ fieldsSchema.add(FieldType.newBuilder()
|
|
|
+ .withDataType(DataType.FloatVector)
|
|
|
+ .withName(field2Name)
|
|
|
+ .withDescription("face")
|
|
|
+ .withDimension(dimension)
|
|
|
+ .build());
|
|
|
+
|
|
|
+ fieldsSchema.add(FieldType.newBuilder()
|
|
|
+ .withDataType(DataType.VarChar)
|
|
|
+ .withName(field3Name)
|
|
|
+ .withMaxLength(32)
|
|
|
+ .withDescription("comment")
|
|
|
+ .build());
|
|
|
+
|
|
|
+ fieldsSchema.add(FieldType.newBuilder()
|
|
|
+ .withDataType(DataType.Int64)
|
|
|
+ .withName(field4Name)
|
|
|
+ .withDescription("sequence")
|
|
|
+ .build());
|
|
|
+
|
|
|
+ // create collection
|
|
|
+ CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .withDescription("test")
|
|
|
+ .withFieldTypes(fieldsSchema)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ R<RpcStatus> createR = client.createCollection(createParam);
|
|
|
+ assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
|
|
|
+
|
|
|
+ R<DescribeCollectionResponse> response = client.describeCollection(DescribeCollectionParam.newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .build());
|
|
|
+
|
|
|
+ DescCollResponseWrapper desc = new DescCollResponseWrapper(response.getData());
|
|
|
+ System.out.println(desc.toString());
|
|
|
+
|
|
|
+ // insert data
|
|
|
+ int rowCount = 10000;
|
|
|
+ List<String> ids = new ArrayList<>();
|
|
|
+ List<String> comments = new ArrayList<>();
|
|
|
+ List<Long> sequences = new ArrayList<>();
|
|
|
+ for (long i = 0L; i < rowCount; ++i) {
|
|
|
+ ids.add(generator.generate(8));
|
|
|
+ comments.add(generator.generate(8));
|
|
|
+ sequences.add(i);
|
|
|
+ }
|
|
|
+ List<List<Float>> vectors = generateFloatVectors(rowCount);
|
|
|
+
|
|
|
+ List<InsertParam.Field> fieldsInsert = new ArrayList<>();
|
|
|
+ fieldsInsert.add(new InsertParam.Field(field1Name, DataType.VarChar, ids));
|
|
|
+ fieldsInsert.add(new InsertParam.Field(field3Name, DataType.VarChar, comments));
|
|
|
+ fieldsInsert.add(new InsertParam.Field(field2Name, DataType.FloatVector, vectors));
|
|
|
+ fieldsInsert.add(new InsertParam.Field(field4Name, DataType.Int64, sequences));
|
|
|
+
|
|
|
+ InsertParam insertParam = InsertParam.newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .withFields(fieldsInsert)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ R<MutationResult> insertR = client.withTimeout(10, TimeUnit.SECONDS).insert(insertParam);
|
|
|
+ assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
|
|
|
+
|
|
|
+ MutationResultWrapper insertResultWrapper = new MutationResultWrapper(insertR.getData());
|
|
|
+ System.out.println(insertResultWrapper.getInsertCount() + " rows inserted");
|
|
|
+
|
|
|
+ // get collection statistics
|
|
|
+ R<GetCollectionStatisticsResponse> statR = client.getCollectionStatistics(GetCollectionStatisticsParam
|
|
|
+ .newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .withFlush(true)
|
|
|
+ .build());
|
|
|
+ assertEquals(R.Status.Success.getCode(), statR.getStatus().intValue());
|
|
|
+
|
|
|
+ GetCollStatResponseWrapper stat = new GetCollStatResponseWrapper(statR.getData());
|
|
|
+ System.out.println("Collection row count: " + stat.getRowCount());
|
|
|
+
|
|
|
+ // load collection
|
|
|
+ R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .build());
|
|
|
+ assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
|
|
|
+
|
|
|
+ // query vectors to verify
|
|
|
+ List<Long> queryItems = new ArrayList<>();
|
|
|
+ List<String> queryIds = new ArrayList<>();
|
|
|
+ int nq = 5;
|
|
|
+ Random ran = new Random();
|
|
|
+ int randomIndex = ran.nextInt(rowCount - nq);
|
|
|
+ for (int i = randomIndex; i < randomIndex + nq; ++i) {
|
|
|
+ queryIds.add(ids.get(i));
|
|
|
+ queryItems.add(sequences.get(i));
|
|
|
+ }
|
|
|
+ String expr = field4Name + " in " + queryItems.toString();
|
|
|
+ List<String> outputFields = Arrays.asList(field1Name, field3Name);
|
|
|
+ QueryParam queryParam = QueryParam.newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .withExpr(expr)
|
|
|
+ .withOutFields(outputFields)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ R<QueryResults> queryR = client.query(queryParam);
|
|
|
+ assertEquals(R.Status.Success.getCode(), queryR.getStatus().intValue());
|
|
|
+
|
|
|
+ // verify query result
|
|
|
+ QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(queryR.getData());
|
|
|
+ for (String fieldName : outputFields) {
|
|
|
+ FieldDataWrapper wrapper = queryResultsWrapper.getFieldWrapper(fieldName);
|
|
|
+ System.out.println("Query data of " + fieldName + ", row count: " + wrapper.getRowCount());
|
|
|
+ System.out.println(wrapper.getFieldData());
|
|
|
+ assertEquals(nq, wrapper.getFieldData().size());
|
|
|
+
|
|
|
+ if (fieldName.compareTo(field1Name) == 0) {
|
|
|
+ List<?> out = queryResultsWrapper.getFieldWrapper(field1Name).getFieldData();
|
|
|
+ assertEquals(nq, out.size());
|
|
|
+ for (Object o : out) {
|
|
|
+ String id = (String) o;
|
|
|
+ assertTrue(queryIds.contains(id));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // search
|
|
|
+ int topK = 5;
|
|
|
+ List<List<Float>> targetVectors = new ArrayList<>();
|
|
|
+ for (Long seq : queryItems) {
|
|
|
+ targetVectors.add(vectors.get(seq.intValue()));
|
|
|
+ }
|
|
|
+ SearchParam searchParam = SearchParam.newBuilder()
|
|
|
+ .withCollectionName(randomCollectionName)
|
|
|
+ .withMetricType(MetricType.L2)
|
|
|
+ .withTopK(topK)
|
|
|
+ .withVectors(targetVectors)
|
|
|
+ .withVectorFieldName(field2Name)
|
|
|
+ .addOutField(field4Name)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ R<SearchResults> searchR = client.search(searchParam);
|
|
|
+ assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
|
|
|
+
|
|
|
+ // verify the search result
|
|
|
+ SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
|
|
|
+ for (int i = 0; i < targetVectors.size(); ++i) {
|
|
|
+ List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
|
|
|
+ System.out.println("The result of No." + i + " target vector(ID = " + queryIds.get(i) + "):");
|
|
|
+ System.out.println(scores);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|