浏览代码

Allows upsert and autoid=true for MilvusClientV1 (#1505)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 2 天之前
父节点
当前提交
5060e10426

+ 185 - 0
examples/src/main/java/io/milvus/v1/UpsertExample.java

@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v1;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.grpc.DataType;
+import io.milvus.grpc.MutationResult;
+import io.milvus.grpc.QueryResults;
+import io.milvus.param.*;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
+import io.milvus.param.dml.UpsertParam;
+import io.milvus.param.index.CreateIndexParam;
+import io.milvus.response.QueryResultsWrapper;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class UpsertExample {
+    private static final MilvusClient client;
+
+    static {
+        ConnectParam connectParam = ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .build();
+        client = new MilvusServiceClient(connectParam);
+    }
+    private static final String COLLECTION_NAME = "java_sdk_example_upsert_v1";
+    private static final String ID_FIELD = "pk";
+    private static final String VECTOR_FIELD = "vector";
+    private static final String TEXT_FIELD = "text";
+    private static final Integer VECTOR_DIM = 128;
+
+    private static void queryWithExpr(String expr) {
+        R<QueryResults> queryRet = client.query(QueryParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withExpr(expr)
+                .withOutFields(Arrays.asList(ID_FIELD, TEXT_FIELD))
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .build());
+        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryRet.getData());
+        System.out.println("\nQuery with expression: " + expr);
+        List<QueryResultsWrapper.RowRecord> records = queryWrapper.getRowRecords();
+        for (QueryResultsWrapper.RowRecord record : records) {
+            System.out.println(record);
+        }
+    }
+
+    private static List<Long> createCollection(boolean autoID) {
+        // Define fields
+        List<FieldType> fieldsSchema = Arrays.asList(
+                FieldType.newBuilder()
+                        .withName(ID_FIELD)
+                        .withDataType(DataType.Int64)
+                        .withPrimaryKey(true)
+                        .withAutoID(autoID)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(VECTOR_FIELD)
+                        .withDataType(DataType.FloatVector)
+                        .withDimension(VECTOR_DIM)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(TEXT_FIELD)
+                        .withDataType(DataType.VarChar)
+                        .withMaxLength(100)
+                        .build()
+        );
+
+        CollectionSchemaParam collectionSchemaParam = CollectionSchemaParam.newBuilder()
+                .withFieldTypes(fieldsSchema)
+                .build();
+
+        // Drop the collection if exists
+        client.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+
+        // Create the collection with 3 fields
+        R<RpcStatus> ret = client.createCollection(CreateCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withSchema(collectionSchemaParam)
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+
+        // Specify an index type on the vector field.
+        ret = client.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(VECTOR_FIELD)
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.L2)
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+
+        // Call loadCollection() to enable automatically loading data into memory for searching
+        client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("\nCollection created with autoID = " + autoID);
+
+        // insert rows
+        Gson gson = new Gson();
+        List<JsonObject> rows = new ArrayList<>();
+        for (int i = 0; i < 100; i++) {
+            JsonObject row = new JsonObject();
+            if (!autoID) {
+                row.addProperty(ID_FIELD, i);
+            }
+            List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            row.addProperty(TEXT_FIELD, String.format("text_%d", i));
+            rows.add(row);
+        }
+        R<MutationResult> resp = client.insert(InsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withRows(rows)
+                .build());
+        CommonUtils.handleResponseStatus(resp);
+        return resp.getData().getIDs().getIntId().getDataList();
+    }
+
+    private static void doUpsert(boolean autoID) {
+        // if autoID is true, the collection primary key is auto-generated by server
+        List<Long> ids = createCollection(autoID);
+
+        // query before upsert
+        Long testID = ids.get(1);
+        String filter = String.format("%s == %d", ID_FIELD, testID);
+        queryWithExpr(filter);
+
+        // upsert
+        // the server will return a new primary key, the old entity is deleted,
+        // and a new entity is created with the new primary key
+        Gson gson = new Gson();
+        JsonObject row = new JsonObject();
+        row.addProperty(ID_FIELD, testID);
+        List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+        row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+        row.addProperty(TEXT_FIELD, "this field has been updated");
+        R<MutationResult> upsertResp = client.upsert(UpsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withRows(Collections.singletonList(row))
+                .build());
+        CommonUtils.handleResponseStatus(upsertResp);
+        List<Long> newIds = upsertResp.getData().getIDs().getIntId().getDataList();
+        Long newID = newIds.get(0);
+        System.out.println("\nUpsert done");
+
+        // query after upsert
+        filter = String.format("%s == %d", ID_FIELD, newID);
+        queryWithExpr(filter);
+    }
+
+    public static void main(String[] args) {
+        doUpsert(true);
+        doUpsert(false);
+
+        client.close();
+    }
+}

+ 169 - 0
examples/src/main/java/io/milvus/v2/UpsertExample.java

@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.UpsertReq;
+import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.UpsertResp;
+
+import java.util.*;
+
+public class UpsertExample {
+    private static final MilvusClientV2 client;
+    static {
+        client = new MilvusClientV2(ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build());
+    }
+
+    private static final String COLLECTION_NAME = "java_sdk_example_upsert_v2";
+    private static final String ID_FIELD = "pk";
+    private static final String VECTOR_FIELD = "vector";
+    private static final String TEXT_FIELD = "text";
+    private static final Integer VECTOR_DIM = 128;
+
+    private static List<Object> createCollection(boolean autoID) {
+        // Drop collection if exists
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+
+        // Create collection
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(ID_FIELD)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .autoID(autoID)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(VECTOR_FIELD)
+                .dataType(DataType.FloatVector)
+                .dimension(VECTOR_DIM)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(TEXT_FIELD)
+                .dataType(DataType.VarChar)
+                .maxLength(100)
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        indexes.add(IndexParam.builder()
+                .fieldName(VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexes)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+        client.createCollection(requestCreate);
+        System.out.println("\nCollection created with autoID = " + autoID);
+
+        // Insert rows
+        Gson gson = new Gson();
+        List<JsonObject> rows = new ArrayList<>();
+        for (int i = 0; i < 100; i++) {
+            JsonObject row = new JsonObject();
+            if (!autoID) {
+                row.addProperty(ID_FIELD, i);
+            }
+            List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            row.addProperty(TEXT_FIELD, String.format("text_%d", i));
+            rows.add(row);
+        }
+        InsertResp resp = client.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(rows)
+                .build());
+        return resp.getPrimaryKeys();
+    }
+
+    private static void queryWithExpr(String expr) {
+        QueryResp queryRet = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .filter(expr)
+                .outputFields(Arrays.asList(ID_FIELD, TEXT_FIELD))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        System.out.println("\nQuery with expression: " + expr);
+        List<QueryResp.QueryResult> records = queryRet.getQueryResults();
+        for (QueryResp.QueryResult record : records) {
+            System.out.println(record.getEntity());
+        }
+    }
+
+    private static void doUpsert(boolean autoID) {
+        // if autoID is true, the collection primary key is auto-generated by server
+        List<Object> ids = createCollection(autoID);
+
+        // query before upsert
+        Long testID = (Long)ids.get(1);
+        String filter = String.format("%s == %d", ID_FIELD, testID);
+        queryWithExpr(filter);
+
+        // upsert
+        // the server will return a new primary key, the old entity is deleted,
+        // and a new entity is created with the new primary key
+        Gson gson = new Gson();
+        JsonObject row = new JsonObject();
+        row.addProperty(ID_FIELD, testID);
+        List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+        row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+        row.addProperty(TEXT_FIELD, "this field has been updated");
+        UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(Collections.singletonList(row))
+                .build());
+        List<Object> newIds = upsertResp.getPrimaryKeys();
+        Long newID = (Long)newIds.get(0);
+        System.out.println("\nUpsert done");
+
+        // query after upsert
+        filter = String.format("%s == %d", ID_FIELD, newID);
+        queryWithExpr(filter);
+    }
+
+    public static void main(String[] args) {
+        doUpsert(true);
+        doUpsert(false);
+
+        client.close();
+    }
+}

+ 4 - 9
sdk-core/src/main/java/io/milvus/param/ParamUtils.java

@@ -526,13 +526,6 @@ public class ParamUtils {
                                     DescCollResponseWrapper wrapper) {
             String collectionName = requestParam.getCollectionName();
 
-            // currently, not allow to upsert for collection whose primary key is auto-generated
-            FieldType pk = wrapper.getPrimaryField();
-            if (pk.isAutoID()) {
-                throw new ParamException(String.format("Upsert don't support autoID==True, collection: %s",
-                        requestParam.getCollectionName()));
-            }
-
             // generate upsert request builder
             MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
             upsertBuilder = UpsertRequest.newBuilder()
@@ -601,7 +594,8 @@ public class ParamUtils {
                 boolean found = false;
                 for (InsertParam.Field field : fields) {
                     if (field.getName().equals(fieldType.getName())) {
-                        if (fieldType.isAutoID()) {
+                        // from v2.4.10, milvus allows upsert for auto-id pk, no need to check for upsert action
+                        if (fieldType.isAutoID() && insertBuilder != null) {
                             String msg = String.format("The primary key: %s is auto generated, no need to input.",
                                     fieldType.getName());
                             throw new ParamException(msg);
@@ -669,7 +663,8 @@ public class ParamUtils {
                         rowFieldData = JsonNull.INSTANCE;
                     }
 
-                    if (fieldType.isAutoID()) {
+                    // from v2.4.10, milvus allows upsert for auto-id pk, no need to check for upsert action
+                    if (fieldType.isAutoID() && insertBuilder != null) {
                         String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
                         throw new ParamException(msg);
                     }

+ 18 - 11
sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java

@@ -148,19 +148,17 @@ public class VectorService extends BaseService {
         String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
         GTsDict.getInstance().updateCollectionTs(key, response.getTimestamp());
 
+        // handle integer pk or string pk
+        List<Object> ids = new ArrayList<>();
         if (response.getIDs().hasIntId()) {
-            List<Object> ids = new ArrayList<>(response.getIDs().getIntId().getDataList());
-            return InsertResp.builder()
-                    .InsertCnt(response.getInsertCnt())
-                    .primaryKeys(ids)
-                    .build();
-        } else {
-            List<Object> ids = new ArrayList<>(response.getIDs().getStrId().getDataList());
-            return InsertResp.builder()
-                    .InsertCnt(response.getInsertCnt())
-                    .primaryKeys(ids)
-                    .build();
+            ids = new ArrayList<>(response.getIDs().getIntId().getDataList());
+        } else if (response.getIDs().hasStrId()) {
+            ids = new ArrayList<>(response.getIDs().getStrId().getDataList());
         }
+        return InsertResp.builder()
+                .InsertCnt(response.getInsertCnt())
+                .primaryKeys(ids)
+                .build();
     }
 
     private UpsertRequest buildUpsertRequest(UpsertReq request, DescribeCollectionResponse descResp) {
@@ -207,8 +205,17 @@ public class VectorService extends BaseService {
         // update the last write timestamp for SESSION consistency
         String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
         GTsDict.getInstance().updateCollectionTs(key, response.getTimestamp());
+
+        // handle integer pk or string pk
+        List<Object> ids = new ArrayList<>();
+        if (response.getIDs().hasIntId()) {
+            ids = new ArrayList<>(response.getIDs().getIntId().getDataList());
+        } else if (response.getIDs().hasStrId()) {
+            ids = new ArrayList<>(response.getIDs().getStrId().getDataList());
+        }
         return UpsertResp.builder()
                 .upsertCnt(response.getUpsertCnt())
+                .primaryKeys(ids)
                 .build();
     }
 

+ 1 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/response/InsertResp.java

@@ -29,6 +29,7 @@ import java.util.List;
 @Data
 @SuperBuilder
 public class InsertResp {
+    // TODO: the first character should be lower case, add a new member and deprecate the old member
     private long InsertCnt;
     @Builder.Default
     private List<Object> primaryKeys = new ArrayList<>();

+ 10 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/response/UpsertResp.java

@@ -19,11 +19,21 @@
 
 package io.milvus.v2.service.vector.response;
 
+import lombok.Builder;
 import lombok.Data;
 import lombok.experimental.SuperBuilder;
 
+import java.util.ArrayList;
+import java.util.List;
+
 @Data
 @SuperBuilder
 public class UpsertResp {
     private long upsertCnt;
+
+    // From v2.4.10, milvus allows upsert for auto-id=true, the server will return a new pk.
+    // the new pk is not equal to the original pk, the original entity is deleted, and a new entity
+    // is created with this new pk. Here we return this new pk to user.
+    @Builder.Default
+    private List<Object> primaryKeys = new ArrayList<>();
 }

+ 1 - 0
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1164,6 +1164,7 @@ class MilvusClientV2DockerTest {
                 .data(dataUpdate)
                 .build());
         Assertions.assertEquals(2, upsertResp.getUpsertCnt());
+        Assertions.assertEquals(2, upsertResp.getPrimaryKeys().size());
 
         // get row count
         rowCount = getRowCount(randomCollectionName);