Browse Source

Fix thread-safe insert/upsert bug of V2 (#1024)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 11 months ago
parent
commit
5e34460f58

+ 5 - 3
src/main/java/io/milvus/v2/service/vector/VectorService.java

@@ -33,6 +33,7 @@ import io.milvus.v2.service.collection.response.DescribeCollectionResp;
 import io.milvus.v2.service.index.IndexService;
 import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.response.*;
+import io.milvus.v2.utils.DataUtils;
 import io.milvus.v2.utils.RpcUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
@@ -103,8 +104,8 @@ public class VectorService extends BaseService {
 
         // TODO: set the database name
         DescribeCollectionResponse descResp = getCollectionInfo(blockingStub, "", request.getCollectionName());
-
-        MutationResult response = blockingStub.insert(dataUtils.convertGrpcInsertRequest(request, new DescCollResponseWrapper(descResp)));
+        DataUtils.InsertBuilderWrapper requestBuilder = new DataUtils.InsertBuilderWrapper();
+        MutationResult response = blockingStub.insert(requestBuilder.convertGrpcInsertRequest(request, new DescCollResponseWrapper(descResp)));
         cleanCacheIfFailed(response.getStatus(), "", request.getCollectionName());
         rpcUtils.handleResponse(title, response.getStatus());
         return InsertResp.builder()
@@ -117,7 +118,8 @@ public class VectorService extends BaseService {
 
         // TODO: set the database name
         DescribeCollectionResponse descResp = getCollectionInfo(blockingStub, "", request.getCollectionName());
-        MutationResult response = blockingStub.upsert(dataUtils.convertGrpcUpsertRequest(request, new DescCollResponseWrapper(descResp)));
+        DataUtils.InsertBuilderWrapper requestBuilder = new DataUtils.InsertBuilderWrapper();
+        MutationResult response = blockingStub.upsert(requestBuilder.convertGrpcUpsertRequest(request, new DescCollResponseWrapper(descResp)));
         cleanCacheIfFailed(response.getStatus(), "", request.getCollectionName());
         rpcUtils.handleResponse(title, response.getStatus());
         return UpsertResp.builder()

+ 138 - 140
src/main/java/io/milvus/v2/utils/DataUtils.java

@@ -21,182 +21,180 @@ package io.milvus.v2.utils;
 
 import com.google.gson.JsonElement;
 import com.google.gson.JsonObject;
-import com.google.common.collect.Lists;
-import com.google.protobuf.ByteString;
-import io.milvus.exception.IllegalResponseException;
 import io.milvus.exception.ParamException;
 import io.milvus.grpc.*;
 import io.milvus.param.Constant;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.collection.FieldType;
-import io.milvus.param.dml.InsertParam;
 import io.milvus.response.DescCollResponseWrapper;
 import io.milvus.v2.service.vector.request.InsertReq;
 import io.milvus.v2.service.vector.request.UpsertReq;
 import lombok.NonNull;
 
-import java.nio.ByteBuffer;
 import java.util.*;
-import java.util.stream.Collectors;
 
 public class DataUtils {
-    private InsertRequest.Builder insertBuilder;
-    private UpsertRequest.Builder upsertBuilder;
-
-    public InsertRequest convertGrpcInsertRequest(@NonNull InsertReq requestParam,
-                                                  DescCollResponseWrapper wrapper) {
-        String collectionName = requestParam.getCollectionName();
-
-        // generate insert request builder
-        MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
-        insertBuilder = InsertRequest.newBuilder()
-                .setCollectionName(collectionName)
-                .setBase(msgBase)
-                .setNumRows(requestParam.getData().size());
-        upsertBuilder = null;
-        fillFieldsData(requestParam, wrapper);
-        return insertBuilder.build();
-    }
-    public UpsertRequest convertGrpcUpsertRequest(@NonNull UpsertReq requestParam,
-                                                  DescCollResponseWrapper wrapper) {
-        String collectionName = requestParam.getCollectionName();
-
-        // currently, not allow to upsert for collection whose primary key is auto-generated
-        FieldType pk = wrapper.getPrimaryField();
-        if (pk.isAutoID()) {
-            throw new ParamException(String.format("Upsert don't support autoID==True, collection: %s",
-                    requestParam.getCollectionName()));
+
+    public static class InsertBuilderWrapper {
+        private InsertRequest.Builder insertBuilder;
+        private UpsertRequest.Builder upsertBuilder;
+
+        public InsertRequest convertGrpcInsertRequest(@NonNull InsertReq requestParam,
+                                                      DescCollResponseWrapper wrapper) {
+            String collectionName = requestParam.getCollectionName();
+
+            // generate insert request builder
+            MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
+            insertBuilder = InsertRequest.newBuilder()
+                    .setCollectionName(collectionName)
+                    .setBase(msgBase)
+                    .setNumRows(requestParam.getData().size());
+            upsertBuilder = null;
+            fillFieldsData(requestParam, wrapper);
+            return insertBuilder.build();
         }
 
-        // generate upsert request builder
-        MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
-        upsertBuilder = UpsertRequest.newBuilder()
-                .setCollectionName(collectionName)
-                .setBase(msgBase)
-                .setNumRows(requestParam.getData().size());
-        insertBuilder = null;
-        fillFieldsData(requestParam, wrapper);
-        return upsertBuilder.build();
-    }
+        public UpsertRequest convertGrpcUpsertRequest(@NonNull UpsertReq requestParam,
+                                                      DescCollResponseWrapper wrapper) {
+            String collectionName = requestParam.getCollectionName();
 
-    private void addFieldsData(io.milvus.grpc.FieldData value) {
-        if (insertBuilder != null) {
-            insertBuilder.addFieldsData(value);
-        } else if (upsertBuilder != null) {
-            upsertBuilder.addFieldsData(value);
-        }
-    }
+            // currently, not allow to upsert for collection whose primary key is auto-generated
+            FieldType pk = wrapper.getPrimaryField();
+            if (pk.isAutoID()) {
+                throw new ParamException(String.format("Upsert don't support autoID==True, collection: %s",
+                        requestParam.getCollectionName()));
+            }
 
-    private void setPartitionName(String value) {
-        if (insertBuilder != null) {
-            insertBuilder.setPartitionName(value);
-        } else if (upsertBuilder != null) {
-            upsertBuilder.setPartitionName(value);
+            // generate upsert request builder
+            MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
+            upsertBuilder = UpsertRequest.newBuilder()
+                    .setCollectionName(collectionName)
+                    .setBase(msgBase)
+                    .setNumRows(requestParam.getData().size());
+            insertBuilder = null;
+            fillFieldsData(requestParam, wrapper);
+            return upsertBuilder.build();
         }
-    }
 
-    private void fillFieldsData(UpsertReq requestParam, DescCollResponseWrapper wrapper) {
-        // set partition name only when there is no partition key field
-        String partitionName = requestParam.getPartitionName();
-        boolean isPartitionKeyEnabled = false;
-        for (FieldType fieldType : wrapper.getFields()) {
-            if (fieldType.isPartitionKey()) {
-                isPartitionKeyEnabled = true;
-                break;
+        private void addFieldsData(io.milvus.grpc.FieldData value) {
+            if (insertBuilder != null) {
+                insertBuilder.addFieldsData(value);
+            } else if (upsertBuilder != null) {
+                upsertBuilder.addFieldsData(value);
             }
         }
-        if (isPartitionKeyEnabled) {
-            if (partitionName != null && !partitionName.isEmpty()) {
-                String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name";
-                throw new ParamException(msg);
+
+        private void setPartitionName(String value) {
+            if (insertBuilder != null) {
+                insertBuilder.setPartitionName(value);
+            } else if (upsertBuilder != null) {
+                upsertBuilder.setPartitionName(value);
             }
-        } else if (partitionName != null) {
-            this.setPartitionName(partitionName);
         }
 
-        // convert insert data
-        List<JsonObject> rowFields = requestParam.getData();
-        checkAndSetRowData(wrapper, rowFields);
-    }
-
-    private void fillFieldsData(InsertReq requestParam, DescCollResponseWrapper wrapper) {
-        // set partition name only when there is no partition key field
-        String partitionName = requestParam.getPartitionName();
-        boolean isPartitionKeyEnabled = false;
-        for (FieldType fieldType : wrapper.getFields()) {
-            if (fieldType.isPartitionKey()) {
-                isPartitionKeyEnabled = true;
-                break;
+        private void fillFieldsData(UpsertReq requestParam, DescCollResponseWrapper wrapper) {
+            // set partition name only when there is no partition key field
+            String partitionName = requestParam.getPartitionName();
+            boolean isPartitionKeyEnabled = false;
+            for (FieldType fieldType : wrapper.getFields()) {
+                if (fieldType.isPartitionKey()) {
+                    isPartitionKeyEnabled = true;
+                    break;
+                }
             }
-        }
-        if (isPartitionKeyEnabled) {
-            if (partitionName != null && !partitionName.isEmpty()) {
-                String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name";
-                throw new ParamException(msg);
+            if (isPartitionKeyEnabled) {
+                if (partitionName != null && !partitionName.isEmpty()) {
+                    String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name";
+                    throw new ParamException(msg);
+                }
+            } else if (partitionName != null) {
+                this.setPartitionName(partitionName);
             }
-        } else if (partitionName != null) {
-            this.setPartitionName(partitionName);
+
+            // convert insert data
+            List<JsonObject> rowFields = requestParam.getData();
+            checkAndSetRowData(wrapper, rowFields);
         }
 
-        // convert insert data
-        List<JsonObject> rowFields = requestParam.getData();
-        checkAndSetRowData(wrapper, rowFields);
-    }
+        private void fillFieldsData(InsertReq requestParam, DescCollResponseWrapper wrapper) {
+            // set partition name only when there is no partition key field
+            String partitionName = requestParam.getPartitionName();
+            boolean isPartitionKeyEnabled = false;
+            for (FieldType fieldType : wrapper.getFields()) {
+                if (fieldType.isPartitionKey()) {
+                    isPartitionKeyEnabled = true;
+                    break;
+                }
+            }
+            if (isPartitionKeyEnabled) {
+                if (partitionName != null && !partitionName.isEmpty()) {
+                    String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name";
+                    throw new ParamException(msg);
+                }
+            } else if (partitionName != null) {
+                this.setPartitionName(partitionName);
+            }
 
-    private void checkAndSetRowData(DescCollResponseWrapper wrapper, List<JsonObject> rows) {
-        List<FieldType> fieldTypes = wrapper.getFields();
-
-        Map<String, ParamUtils.InsertDataInfo> nameInsertInfo = new HashMap<>();
-        ParamUtils.InsertDataInfo insertDynamicDataInfo = ParamUtils.InsertDataInfo.builder().fieldType(
-                        FieldType.newBuilder()
-                                .withName(Constant.DYNAMIC_FIELD_NAME)
-                                .withDataType(DataType.JSON)
-                                .withIsDynamic(true)
-                                .build())
-                .data(new LinkedList<>()).build();
-        for (JsonObject row : rows) {
-            for (FieldType fieldType : fieldTypes) {
-                String fieldName = fieldType.getName();
-                ParamUtils.InsertDataInfo insertDataInfo = nameInsertInfo.getOrDefault(fieldName, ParamUtils.InsertDataInfo.builder()
-                        .fieldType(fieldType).data(new LinkedList<>()).build());
-
-                // check normalField
-                JsonElement rowFieldData = row.get(fieldName);
-                if (rowFieldData != null) {
-                    if (fieldType.isAutoID()) {
-                        String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
-                        throw new ParamException(msg);
-                    }
-                    Object fieldValue = ParamUtils.checkFieldValue(fieldType, rowFieldData);
-                    insertDataInfo.getData().add(fieldValue);
-                    nameInsertInfo.put(fieldName, insertDataInfo);
-                } else {
-                    // check if autoId
-                    if (!fieldType.isAutoID()) {
-                        String msg = String.format("The field: %s is not provided.", fieldType.getName());
-                        throw new ParamException(msg);
+            // convert insert data
+            List<JsonObject> rowFields = requestParam.getData();
+            checkAndSetRowData(wrapper, rowFields);
+        }
+
+        private void checkAndSetRowData(DescCollResponseWrapper wrapper, List<JsonObject> rows) {
+            List<FieldType> fieldTypes = wrapper.getFields();
+
+            Map<String, ParamUtils.InsertDataInfo> nameInsertInfo = new HashMap<>();
+            ParamUtils.InsertDataInfo insertDynamicDataInfo = ParamUtils.InsertDataInfo.builder().fieldType(
+                            FieldType.newBuilder()
+                                    .withName(Constant.DYNAMIC_FIELD_NAME)
+                                    .withDataType(DataType.JSON)
+                                    .withIsDynamic(true)
+                                    .build())
+                    .data(new LinkedList<>()).build();
+            for (JsonObject row : rows) {
+                for (FieldType fieldType : fieldTypes) {
+                    String fieldName = fieldType.getName();
+                    ParamUtils.InsertDataInfo insertDataInfo = nameInsertInfo.getOrDefault(fieldName, ParamUtils.InsertDataInfo.builder()
+                            .fieldType(fieldType).data(new LinkedList<>()).build());
+
+                    // check normalField
+                    JsonElement rowFieldData = row.get(fieldName);
+                    if (rowFieldData != null) {
+                        if (fieldType.isAutoID()) {
+                            String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
+                            throw new ParamException(msg);
+                        }
+                        Object fieldValue = ParamUtils.checkFieldValue(fieldType, rowFieldData);
+                        insertDataInfo.getData().add(fieldValue);
+                        nameInsertInfo.put(fieldName, insertDataInfo);
+                    } else {
+                        // check if autoId
+                        if (!fieldType.isAutoID()) {
+                            String msg = String.format("The field: %s is not provided.", fieldType.getName());
+                            throw new ParamException(msg);
+                        }
                     }
                 }
-            }
 
-            // deal with dynamicField
-            if (wrapper.getEnableDynamicField()) {
-                JsonObject dynamicField = new JsonObject();
-                for (String rowFieldName : row.keySet()) {
-                    if (!nameInsertInfo.containsKey(rowFieldName)) {
-                        dynamicField.add(rowFieldName, row.get(rowFieldName));
+                // deal with dynamicField
+                if (wrapper.getEnableDynamicField()) {
+                    JsonObject dynamicField = new JsonObject();
+                    for (String rowFieldName : row.keySet()) {
+                        if (!nameInsertInfo.containsKey(rowFieldName)) {
+                            dynamicField.add(rowFieldName, row.get(rowFieldName));
+                        }
                     }
+                    insertDynamicDataInfo.getData().add(dynamicField);
                 }
-                insertDynamicDataInfo.getData().add(dynamicField);
             }
-        }
 
-        for (String fieldNameKey : nameInsertInfo.keySet()) {
-            ParamUtils.InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldNameKey);
-            this.addFieldsData(ParamUtils.genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData()));
-        }
-        if (wrapper.getEnableDynamicField()) {
-            this.addFieldsData(ParamUtils.genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), Boolean.TRUE));
+            for (String fieldNameKey : nameInsertInfo.keySet()) {
+                ParamUtils.InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldNameKey);
+                this.addFieldsData(ParamUtils.genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData()));
+            }
+            if (wrapper.getEnableDynamicField()) {
+                this.addFieldsData(ParamUtils.genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), Boolean.TRUE));
+            }
         }
     }
 }

+ 135 - 2
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -85,14 +85,18 @@ class MilvusClientV2DockerTest {
         }
     }
 
-    private List<Float> generateFolatVector() {
+    private List<Float> generateFolatVector(int dim) {
         List<Float> vector = new ArrayList<>();
-        for (int i = 0; i < dimension; ++i) {
+        for (int i = 0; i < dim; ++i) {
             vector.add(RANDOM.nextFloat());
         }
         return vector;
     }
 
+    private List<Float> generateFolatVector() {
+        return generateFolatVector(dimension);
+    }
+
     private List<List<Float>> generateFloatVectors(int count) {
         List<List<Float>> vectors = new ArrayList<>();
         for (int n = 0; n < count; ++n) {
@@ -1183,4 +1187,133 @@ class MilvusClientV2DockerTest {
             Assertions.fail(e.getMessage());
         }
     }
+
+    @Test
+    void testMultiThreadsInsert() {
+        String randomCollectionName = generator.generate(10);
+        int dim = 64;
+
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.VarChar)
+                .isPrimaryKey(Boolean.TRUE)
+                .maxLength(65535)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("vector")
+                .dataType(DataType.FloatVector)
+                .dimension(dim)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("dataTime")
+                .dataType(DataType.Int64)
+                .build());
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName("vector")
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.L2)
+                .build());
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexParams)
+                .build();
+        client.createCollection(requestCreate);
+        System.out.println("Collection created");
+
+        try {
+            Gson gson = new Gson();
+            Random rand = new Random();
+            List<Thread> threadList = new ArrayList<>();
+            for (int k = 0; k < 10; k++) {
+                Thread t = new Thread(() -> {
+                    for (int i = 0; i < 20; i++) {
+                        List<JsonObject> rows = new ArrayList<>();
+                        int cnt = rand.nextInt(100) + 100;
+                        for (int j = 0; j < cnt; j++) {
+                            JsonObject obj = new JsonObject();
+                            obj.addProperty("id", String.format("%d", i*cnt + j));
+                            List<Float> vector = generateFolatVector(dim);
+                            obj.add("vector", gson.toJsonTree(vector));
+                            obj.addProperty("dataTime", System.currentTimeMillis());
+                            rows.add(obj);
+                        }
+
+                        client.insert(InsertReq.builder()
+                                .collectionName(randomCollectionName)
+                                .data(rows)
+                                .build());
+                    }
+                });
+                t.start();
+                threadList.add(t);
+            }
+
+            for (Thread t : threadList) {
+                t.join();
+            }
+            System.out.println("Multi-thread insert done");
+
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .filter("")
+                    .collectionName(randomCollectionName)
+                    .outputFields(Collections.singletonList("count(*)"))
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .build());
+            System.out.println(queryResp.getQueryResults().get(0).getEntity().get("count(*)"));
+        } catch (Exception e) {
+            System.out.println(e.getMessage());
+            Assertions.fail(e.getMessage());
+        }
+
+        try {
+            Gson gson = new Gson();
+            Random rand = new Random();
+            List<Thread> threadList = new ArrayList<>();
+            for (int k = 0; k < 10; k++) {
+                Thread t = new Thread(() -> {
+                    for (int i = 0; i < 20; i++) {
+                        List<JsonObject> rows = new ArrayList<>();
+                        int cnt = rand.nextInt(100) + 100;
+                        for (int j = 0; j < cnt; j++) {
+                            JsonObject obj = new JsonObject();
+                            obj.addProperty("id", String.format("%d", i*cnt + j));
+                            List<Float> vector = generateFolatVector(dim);
+                            obj.add("vector", gson.toJsonTree(vector));
+                            obj.addProperty("dataTime", System.currentTimeMillis());
+                            rows.add(obj);
+                        }
+
+                        UpsertReq upsertReq = UpsertReq.builder()
+                                .collectionName(randomCollectionName)
+                                .data(rows)
+                                .build();
+                        client.upsert(upsertReq);
+                    }
+                });
+                t.start();
+                threadList.add(t);
+            }
+
+            for (Thread t : threadList) {
+                t.join();
+            }
+            System.out.println("Multi-thread upsert done");
+
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .filter("")
+                    .collectionName(randomCollectionName)
+                    .outputFields(Collections.singletonList("count(*)"))
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .build());
+            System.out.println(queryResp.getQueryResults().get(0).getEntity().get("count(*)"));
+        } catch (Exception e) {
+            System.out.println(e.getMessage());
+            Assertions.fail(e.getMessage());
+        }
+    }
 }