Browse Source

Fix a critical bug where a partial upsert overrides field values with null (#1638)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 13 hours ago
parent
commit
c7790abbbd

+ 78 - 26
examples/src/main/java/io/milvus/v2/UpsertExample.java

@@ -51,7 +51,8 @@ public class UpsertExample {
     private static final String ID_FIELD = "pk";
     private static final String VECTOR_FIELD = "vector";
     private static final String TEXT_FIELD = "text";
-    private static final Integer VECTOR_DIM = 128;
+    private static final String NULLABLE_FIELD = "nullable";
+    private static final Integer VECTOR_DIM = 4;
 
     private static List<Object> createCollection(boolean autoID) {
         // Drop collection if exists
@@ -78,6 +79,11 @@ public class UpsertExample {
                 .dataType(DataType.VarChar)
                 .maxLength(100)
                 .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(NULLABLE_FIELD)
+                .dataType(DataType.Int32)
+                .isNullable(true)
+                .build());
 
         List<IndexParam> indexes = new ArrayList<>();
         indexes.add(IndexParam.builder()
@@ -106,6 +112,7 @@ public class UpsertExample {
             List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
             row.add(VECTOR_FIELD, gson.toJsonTree(vector));
             row.addProperty(TEXT_FIELD, String.format("text_%d", i));
+            row.addProperty(NULLABLE_FIELD, i);
             rows.add(row);
         }
         InsertResp resp = client.insert(InsertReq.builder()
@@ -119,7 +126,7 @@ public class UpsertExample {
         QueryResp queryRet = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .filter(expr)
-                .outputFields(Arrays.asList(ID_FIELD, TEXT_FIELD))
+                .outputFields(Arrays.asList(ID_FIELD, VECTOR_FIELD, TEXT_FIELD, NULLABLE_FIELD))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());
         System.out.println("\nQuery with expression: " + expr);
@@ -130,34 +137,79 @@ public class UpsertExample {
     }
 
     private static void doUpsert(boolean autoID) {
-        // if autoID is true, the collection primary key is auto-generated by server
+        System.out.printf("\n============================= autoID = %s =============================", autoID ? "true" : "false");
+        // If autoID is true, the collection primary key is auto-generated by server
         List<Object> ids = createCollection(autoID);
 
-        // query before upsert
-        Long testID = (Long)ids.get(1);
-        String filter = String.format("%s == %d", ID_FIELD, testID);
-        queryWithExpr(filter);
-
-        // upsert
-        // the server will return a new primary key, the old entity is deleted,
-        // and a new entity is created with the new primary key
         Gson gson = new Gson();
-        JsonObject row = new JsonObject();
-        row.addProperty(ID_FIELD, testID);
-        List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
-        row.add(VECTOR_FIELD, gson.toJsonTree(vector));
-        row.addProperty(TEXT_FIELD, "this field has been updated");
-        UpsertResp upsertResp = client.upsert(UpsertReq.builder()
-                .collectionName(COLLECTION_NAME)
-                .data(Collections.singletonList(row))
-                .build());
-        List<Object> newIds = upsertResp.getPrimaryKeys();
-        Long newID = (Long)newIds.get(0);
-        System.out.println("\nUpsert done");
+        {
+            // Query before upsert, get the No.2 primary key
+            Long oldID = (Long) ids.get(1);
+            String filter = String.format("%s == %d", ID_FIELD, oldID);
+            queryWithExpr(filter);
+
+            // Upsert, updating the values of all fields
+            // If autoID is true, the server will return a new primary key for the updated entity
+            JsonObject row = new JsonObject();
+            row.addProperty(ID_FIELD, oldID);
+            List<Float> vector = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            row.addProperty(TEXT_FIELD, "this field has been updated");
+            row.add(NULLABLE_FIELD, null); // update nullable field to null
+            UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                    .collectionName(COLLECTION_NAME)
+                    .data(Collections.singletonList(row))
+                    .build());
+            List<Object> newIds = upsertResp.getPrimaryKeys();
+            Long newID = (Long) newIds.get(0);
+            System.out.printf("\nUpsert done, primary key %d has been updated to %d%n", oldID, newID);
+
+            // Query after upsert, you will see the vector field is [1.0f, 1.0f, 1.0f, 1.0f],
+            // text field is "this field has been updated", nullable field is null
+            filter = String.format("%s == %d", ID_FIELD, newID);
+            queryWithExpr(filter);
+        }
 
-        // query after upsert
-        filter = String.format("%s == %d", ID_FIELD, newID);
-        queryWithExpr(filter);
+        {
+            // Query before upsert, get the No.5 and No.6 primary key
+            Long oldID1 = (Long)ids.get(4);
+            Long oldID2 = (Long)ids.get(5);
+            String filter = String.format("%s in [%d, %d]", ID_FIELD, oldID1, oldID2);
+            queryWithExpr(filter);
+
+            // Partial upsert, only update the specified field, other fields will keep old values
+            // If autoID is true, the server will return a new primary key for the updated entity
+            // Note: when doing a partial upsert on multiple entities, every row must
+            // update the same set of fields.
+            // For example, assume a collection has 2 fields: A and B
+            // it is legal to update the same fields as: client.upsert(data = [ {"A": 1}, {"A": 3}])
+            // it is illegal to update different fields as: client.upsert(data = [ {"A": 1}, {"B": 3}])
+            // Read the doc for more info: https://milvus.io/docs/upsert-entities.md
+            // Here we update the same field "text" for the two rows.
+            JsonObject row1 = new JsonObject();
+            row1.addProperty(ID_FIELD, oldID1);
+            row1.addProperty(TEXT_FIELD, "this row has been partially updated");
+
+            JsonObject row2 = new JsonObject();
+            row2.addProperty(ID_FIELD, oldID2);
+            row2.addProperty(TEXT_FIELD, "this row has been partially updated");
+
+            UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                    .collectionName(COLLECTION_NAME)
+                    .data(Arrays.asList(row1, row2))
+                    .partialUpdate(true)
+                    .build());
+            List<Object> newIds = upsertResp.getPrimaryKeys();
+            Long newID1 = (Long) newIds.get(0);
+            Long newID2 = (Long) newIds.get(1);
+            System.out.printf("\nPartial upsert done, primary key %d has been updated to %d, %d has been updated to %d%n",
+                    oldID1, newID1, oldID2, newID2);
+
+            // query after upsert, you will see the text field is "this row has been partially updated"
+            // the other fields keep old values
+            filter = String.format("%s in [%d, %d]", ID_FIELD, newID1, newID2);
+            queryWithExpr(filter);
+        }
     }
 
     public static void main(String[] args) {

+ 6 - 2
sdk-core/src/main/java/io/milvus/v2/utils/DataUtils.java

@@ -250,9 +250,13 @@ public class DataUtils {
                     return;
                 }
 
-                // if the field doesn't have default value, require user provide the value
                 // in v2.6.1 support partial update, user can input partial fields
-                if (!field.getIsNullable() && field.getDefaultValue() == null && !partialUpdate) {
+                if (partialUpdate) {
+                    return;
+                }
+
+                // if the field doesn't have default value, require user provide the value
+                if (!field.getIsNullable() && field.getDefaultValue() == null) {
                     String msg = String.format("The field: %s is not provided.", fieldName);
                     throw new MilvusClientException(ErrorCode.INVALID_PARAMS, msg);
                 }

+ 69 - 38
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1636,6 +1636,11 @@ class MilvusClientV2DockerTest {
                 .dataType(DataType.FloatVector)
                 .dimension(4)
                 .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("text")
+                .dataType(DataType.VarChar)
+                .maxLength(1024)
+                .build());
 
         List<IndexParam> indexParams = new ArrayList<>();
         indexParams.add(IndexParam.builder()
@@ -1657,7 +1662,8 @@ class MilvusClientV2DockerTest {
         List<JsonObject> data = new ArrayList<>();
         for (int i = 0; i < 10; i++) {
             JsonObject row = new JsonObject();
-            row.addProperty("pk", String.format("pk_%d", i));
+            row.addProperty("pk", "pk_" + i);
+            row.addProperty("text", "desc_" + i);
             row.add("float_vector", JsonUtils.toJsonTree(new float[]{(float)i, (float)(i + 1), (float)(i + 2), (float)(i + 3)}));
             data.add(row);
         }
@@ -1682,57 +1688,82 @@ class MilvusClientV2DockerTest {
         Assertions.assertEquals(8L, rowCount);
 
         // upsert
-        List<JsonObject> dataUpdate = new ArrayList<>();
-        JsonObject row1 = new JsonObject();
-        row1.addProperty("pk", "pk_5");
-        row1.add("float_vector", JsonUtils.toJsonTree(new float[]{5.0f, 5.0f, 5.0f, 5.0f}));
-        dataUpdate.add(row1);
-        JsonObject row2 = new JsonObject();
-        row2.addProperty("pk", "pk_2");
-        row2.add("float_vector", JsonUtils.toJsonTree(new float[]{2.0f, 2.0f, 2.0f, 2.0f}));
-        dataUpdate.add(row2);
-        UpsertResp upsertResp = client.upsert(UpsertReq.builder()
-                .databaseName(testDbName)
-                .collectionName(randomCollectionName)
-                .data(dataUpdate)
-                .build());
-        Assertions.assertEquals(2, upsertResp.getUpsertCnt());
-        Assertions.assertEquals(2, upsertResp.getPrimaryKeys().size());
+        // id=5 and id=8 have been deleted, so all fields must be provided
+        {
+            JsonObject row1 = new JsonObject();
+            row1.addProperty("pk", "pk_5");
+            row1.addProperty("text", "updated_5");
+            row1.add("float_vector", JsonUtils.toJsonTree(new float[]{5.0f, 5.0f, 5.0f, 5.0f}));
+
+            JsonObject row2 = new JsonObject();
+            row2.addProperty("pk", "pk_8");
+            row2.addProperty("text", "updated_8");
+            row2.add("float_vector", JsonUtils.toJsonTree(new float[]{5.0f, 5.0f, 5.0f, 5.0f}));
+
+            UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                    .databaseName(testDbName)
+                    .collectionName(randomCollectionName)
+                    .data(Arrays.asList(row1, row2))
+                    .build());
+            Assertions.assertEquals(2, upsertResp.getUpsertCnt());
+            Assertions.assertEquals(2, upsertResp.getPrimaryKeys().size());
+        }
+        // id=2 is a partial update, "text" old value is reused
+        {
+            JsonObject row = new JsonObject();
+            row.addProperty("pk", "pk_2");
+            row.add("float_vector", JsonUtils.toJsonTree(new float[]{5.0f, 5.0f, 5.0f, 5.0f}));
+
+            UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                    .databaseName(testDbName)
+                    .collectionName(randomCollectionName)
+                    .data(Collections.singletonList(row))
+                    .partialUpdate(true)
+                    .build());
+            Assertions.assertEquals(1, upsertResp.getUpsertCnt());
+            Assertions.assertEquals(1, upsertResp.getPrimaryKeys().size());
+        }
 
         // get row count
         rowCount = getRowCount(testDbName, randomCollectionName);
-        Assertions.assertEquals(9L, rowCount);
+        Assertions.assertEquals(10L, rowCount);
 
         // verify
         QueryResp queryResp = client.query(QueryReq.builder()
                 .databaseName(testDbName)
                 .collectionName(randomCollectionName)
-                .ids(Arrays.asList("pk_2", "pk_5"))
+                .ids(Arrays.asList("pk_2", "pk_5", "pk_8"))
                 .outputFields(Collections.singletonList("*"))
                 .build());
         List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
-        Assertions.assertEquals(2, queryResults.size());
+        Assertions.assertEquals(3, queryResults.size());
 
-        QueryResp.QueryResult result1 = queryResults.get(0);
-        Map<String, Object> entity1 = result1.getEntity();
-        Assertions.assertTrue(entity1.containsKey("pk"));
-        Assertions.assertEquals("pk_2", entity1.get("pk"));
-        Assertions.assertTrue(entity1.containsKey("float_vector"));
-        Assertions.assertTrue(entity1.get("float_vector") instanceof List);
-        List<Float> vector1 = (List<Float>) entity1.get("float_vector");
-        for (Float f : vector1) {
-            Assertions.assertEquals(2.0f, f);
+        {
+            QueryResp.QueryResult result = queryResults.get(0);
+            Map<String, Object> entity = result.getEntity();
+            Assertions.assertTrue(entity.containsKey("pk"));
+            Assertions.assertEquals("pk_2", entity.get("pk"));
+            Assertions.assertEquals("desc_2", entity.get("text"));
+            Assertions.assertTrue(entity.containsKey("float_vector"));
+            Assertions.assertTrue(entity.get("float_vector") instanceof List);
+            List<Float> vector1 = (List<Float>) entity.get("float_vector");
+            for (Float f : vector1) {
+                Assertions.assertEquals(5.0f, f);
+            }
         }
 
-        QueryResp.QueryResult result2 = queryResults.get(1);
-        Map<String, Object> entity2 = result2.getEntity();
-        Assertions.assertTrue(entity2.containsKey("pk"));
-        Assertions.assertEquals("pk_5", entity2.get("pk"));
-        Assertions.assertTrue(entity2.containsKey("float_vector"));
-        Assertions.assertTrue(entity2.get("float_vector") instanceof List);
-        List<Float> vector2 = (List<Float>) entity2.get("float_vector");
-        for (Float f : vector2) {
-            Assertions.assertEquals(5.0f, f);
+        {
+            QueryResp.QueryResult result = queryResults.get(1);
+            Map<String, Object> entity = result.getEntity();
+            Assertions.assertTrue(entity.containsKey("pk"));
+            Assertions.assertEquals("pk_5", entity.get("pk"));
+            Assertions.assertEquals("updated_5", entity.get("text"));
+            Assertions.assertTrue(entity.containsKey("float_vector"));
+            Assertions.assertTrue(entity.get("float_vector") instanceof List);
+            List<Float> vector2 = (List<Float>) entity.get("float_vector");
+            for (Float f : vector2) {
+                Assertions.assertEquals(5.0f, f);
+            }
         }
 
         client.dropCollection(DropCollectionReq.builder()