Browse Source

Support import interface (#274)

Signed-off-by: groot <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
7d9739dbc8

+ 66 - 0
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -1186,6 +1186,72 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
         }
     }
 
+    @Override
+    public R<ImportResponse> importData(@NonNull ImportParam requestParam) {
+        if (!clientIsReady()) {
+            return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
+        }
+
+        logInfo(requestParam.toString());
+
+        try {
+            ImportRequest.Builder builder = ImportRequest.newBuilder();
+            builder.setCollectionName(requestParam.getCollectionName())
+                    .setPartitionName(requestParam.getPartitionName())
+                    .setRowBased(requestParam.isRowBased());
+            requestParam.getFiles().forEach(builder::addFiles);
+            List<KeyValuePair> options = assembleKvPair(requestParam.getOptions());
+            if (CollectionUtils.isNotEmpty(options)) {
+                options.forEach(builder::addOptions);
+            }
+
+            ImportRequest importRequest = builder.build();
+            ImportResponse response = blockingStub().import_(importRequest);
+
+            if (response.getStatus().getErrorCode() == ErrorCode.Success) {
+                logInfo("ImportRequest successfully!");
+                return R.success(response);
+            } else {
+                return failedStatus("ImportRequest", response.getStatus());
+            }
+        } catch (StatusRuntimeException e) {
+            logError("ImportRequest RPC failed:\n{}", e.getStatus().toString());
+            return R.failed(e);
+        } catch (Exception e) {
+            logError("ImportRequest failed:\n{}", e.getMessage());
+            return R.failed(e);
+        }
+    }
+
+    @Override
+    public R<GetImportStateResponse> getImportState(GetImportStateParam requestParam) {
+        if (!clientIsReady()) {
+            return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
+        }
+
+        logInfo(requestParam.toString());
+
+        try {
+            GetImportStateRequest importRequest = GetImportStateRequest.newBuilder()
+                    .setTask(requestParam.getTaskID())
+                    .build();
+            GetImportStateResponse response = blockingStub().getImportState(importRequest);
+
+            if (response.getStatus().getErrorCode() == ErrorCode.Success) {
+                logInfo("GetImportStateRequest successfully!");
+                return R.success(response);
+            } else {
+                return failedStatus("GetImportStateRequest", response.getStatus());
+            }
+        } catch (StatusRuntimeException e) {
+            logError("GetImportStateRequest RPC failed:\n{}", e.getStatus().toString());
+            return R.failed(e);
+        } catch (Exception e) {
+            logError("GetImportStateRequest failed:\n{}", e.getMessage());
+            return R.failed(e);
+        }
+    }
+
     @Override
     public R<MutationResult> insert(@NonNull InsertParam requestParam) {
         if (!clientIsReady()) {

+ 16 - 0
src/main/java/io/milvus/client/MilvusClient.java

@@ -285,6 +285,22 @@ public interface MilvusClient {
      */
     R<MutationResult> delete(DeleteParam requestParam);
 
+    /**
+     * Import data from external files, currently support JSON/Numpy format
+     *
+     * @param requestParam {@link ImportParam}
+     * @return {status:result code, data:RpcStatus{msg: result message}}
+     */
+    R<ImportResponse> importData(ImportParam requestParam);
+
+    /**
+     * Import data from external files, currently support JSON/Numpy format
+     *
+     * @param requestParam {@link GetImportStateParam}
+     * @return {status:result code, data:GetImportStateResponse{status,state}}
+     */
+    R<GetImportStateResponse> getImportState(GetImportStateParam requestParam);
+
     /**
      * Conducts ANN search on a vector field. Use expression to do filtering before search.
      *

+ 13 - 0
src/main/java/io/milvus/client/MilvusMultiServiceClient.java

@@ -284,6 +284,19 @@ public class MilvusMultiServiceClient implements MilvusClient {
         return handleResponse(response);
     }
 
+    @Override
+    public R<ImportResponse> importData(@NonNull ImportParam requestParam) {
+        List<R<ImportResponse>> response = this.clusterFactory.getAvailableServerSettings().stream()
+                .map(serverSetting -> serverSetting.getClient().importData(requestParam))
+                .collect(Collectors.toList());
+        return handleResponse(response);
+    }
+
+    @Override
+    public R<GetImportStateResponse> getImportState(GetImportStateParam requestParam) {
+        return this.clusterFactory.getMaster().getClient().getImportState(requestParam);
+    }
+
     @Override
     public R<SearchResults> search(SearchParam requestParam) {
         return this.clusterFactory.getMaster().getClient().search(requestParam);

+ 3 - 0
src/main/java/io/milvus/param/Constant.java

@@ -33,6 +33,9 @@ public class Constant {
     public static final String ROUND_DECIMAL = "round_decimal";
     public static final String PARAMS = "params";
     public static final String ROW_COUNT = "row_count";
+    public static final String BUCKET = "bucket";
+    public static final String FAILED_REASON = "failed_reason";
+    public static final String IMPORT_FILES = "files";
 
     // max value for waiting loading collection/partition interval, unit: millisecond
     public static final Long MAX_WAITING_LOADING_INTERVAL = 2000L;

+ 1 - 1
src/main/java/io/milvus/param/collection/FieldType.java

@@ -208,7 +208,7 @@ public class FieldType {
                 ", type='" + dataType.name() + '\'' +
                 ", primaryKey=" + primaryKey +
                 ", autoID=" + autoID +
-                ", params=" + typeParams.toString() +
+                ", params=" + typeParams +
                 '}';
     }
 }

+ 67 - 0
src/main/java/io/milvus/param/dml/GetImportStateParam.java

@@ -0,0 +1,67 @@
+package io.milvus.param.dml;
+
+import io.milvus.exception.ParamException;
+import lombok.Getter;
+import lombok.NonNull;
+
+/**
+ * Parameters for <code>getImportState</code> interface.
+ */
+@Getter
+public class GetImportStateParam {
+    private final long taskID;
+
+    private GetImportStateParam(@NonNull Builder builder) {
+        this.taskID = builder.taskID;
+    }
+
+    public static Builder newBuilder() {
+        return new Builder();
+    }
+
+    /**
+     * Builder for {@link GetImportStateParam} class.
+     */
+    public static class Builder {
+        private Long taskID;
+
+        private Builder() {
+        }
+
+        /**
+         * Sets an import task id. The id is returned from importData() interface.
+         *
+         * @param taskID id of the task
+         * @return <code>Builder</code>
+         */
+        public Builder withTaskID(@NonNull Long taskID) {
+            this.taskID = taskID;
+            return this;
+        }
+
+        /**
+         * Verifies parameters and creates a new {@link GetImportStateParam} instance.
+         *
+         * @return {@link GetImportStateParam}
+         */
+        public GetImportStateParam build() throws ParamException {
+            if (this.taskID == null) {
+                throw new ParamException("Task ID not specified");
+            }
+
+            return new GetImportStateParam(this);
+        }
+    }
+
+    /**
+     * Constructs a <code>String</code> by {@link GetImportStateParam} instance.
+     *
+     * @return <code>String</code>
+     */
+    @Override
+    public String toString() {
+        return "GetImportStateParam{" +
+                "taskID='" + taskID + '\'' +
+                '}';
+    }
+}

+ 166 - 0
src/main/java/io/milvus/param/dml/ImportParam.java

@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.param.dml;
+
+import com.google.common.collect.Lists;
+import io.milvus.exception.ParamException;
+import io.milvus.param.Constant;
+import io.milvus.param.ParamUtils;
+
+import lombok.Getter;
+import lombok.NonNull;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Parameters for <code>importData</code> interface.
+ */
+@Getter
+public class ImportParam {
+    private final String collectionName;
+    private final String partitionName;
+    private final boolean rowBased;
+    private final List<String> files;
+    private final Map<String, String> options = new HashMap<>();
+
+    private ImportParam(@NonNull Builder builder) {
+        this.collectionName = builder.collectionName;
+        this.partitionName = builder.partitionName;
+        this.rowBased = builder.rowBased;
+        this.files = builder.files;
+        this.options.put(Constant.BUCKET, builder.bucketName);
+    }
+
+    public static Builder newBuilder() {
+        return new Builder();
+    }
+
+    /**
+     * Builder for {@link ImportParam} class.
+     */
+    public static class Builder {
+        private String collectionName;
+        private String partitionName = "";
+        private Boolean rowBased = Boolean.TRUE;
+        private final List<String> files = Lists.newArrayList();
+        private String bucketName = "";
+
+        private Builder() {
+        }
+
+        /**
+         * Sets the collection name. Collection name cannot be empty or null.
+         *
+         * @param collectionName collection name
+         * @return <code>Builder</code>
+         */
+        public Builder withCollectionName(@NonNull String collectionName) {
+            this.collectionName = collectionName;
+            return this;
+        }
+
+        /**
+         * Sets the partition name (Optional).
+         *
+         * @param partitionName partition name
+         * @return <code>Builder</code>
+         */
+        public Builder withPartitionName(@NonNull String partitionName) {
+            this.partitionName = partitionName;
+            return this;
+        }
+
+        /**
+         * Row-based or column-based data
+         *
+         * @param rowBased true: row-based, false: column-based
+         * @return <code>Builder</code>
+         */
+        public Builder withRowBased(@NonNull Boolean rowBased) {
+            this.rowBased = rowBased;
+            return this;
+        }
+
+        /**
+         * Sets bucket name where the files come from MinIO/S3 storage.
+         * If bucket is not specified, the server will use the default bucket to explore.
+         *
+         * @param bucketName bucket name
+         * @return <code>Builder</code>
+         */
+        public Builder withBucket(@NonNull String bucketName) {
+            this.bucketName = bucketName;
+            return this;
+        }
+
+        /**
+         * Specifies file paths to import. Each path is a relative path to the target bucket.
+         *
+         * @param files file paths list
+         * @return <code>Builder</code>
+         */
+        public Builder withFiles(@NonNull List<String> files) {
+            files.forEach(this::addFile);
+            return this;
+        }
+
+        /**
+         * Specifies a file paths to import. The path is a relative path to the target bucket.
+         *
+         * @param filePath file relative path
+         * @return <code>Builder</code>
+         */
+        public Builder addFile(@NonNull String filePath) {
+            if (!this.files.contains(filePath)) {
+                this.files.add(filePath);
+            }
+            return this;
+        }
+
+        /**
+         * Verifies parameters and creates a new {@link ImportParam} instance.
+         *
+         * @return {@link ImportParam}
+         */
+        public ImportParam build() throws ParamException {
+            ParamUtils.CheckNullEmptyString(collectionName, "Collection name");
+
+            return new ImportParam(this);
+        }
+    }
+
+    /**
+     * Constructs a <code>String</code> by {@link ImportParam} instance.
+     *
+     * @return <code>String</code>
+     */
+    @Override
+    public String toString() {
+        return "ImportParam{" +
+                "collectionName='" + collectionName + '\'' +
+                ", partitionName='" + partitionName + '\'' +
+                ", files='" + files.toString() + '\'' +
+                ", rowBased='" + rowBased + '\'' +
+                ", options=" + options.toString() +
+                '}';
+    }
+}

+ 91 - 0
src/main/java/io/milvus/response/GetImportStateWrapper.java

@@ -0,0 +1,91 @@
+package io.milvus.response;
+
+import io.milvus.grpc.GetImportStateResponse;
+import io.milvus.grpc.ImportState;
+import io.milvus.grpc.KeyValuePair;
+import io.milvus.param.Constant;
+import lombok.NonNull;
+
+import java.util.List;
+
+/**
+ * Util class to wrap response of <code>getImportState</code> interface.
+ */
+public class GetImportStateWrapper {
+    private final GetImportStateResponse response;
+
+    public GetImportStateWrapper(@NonNull GetImportStateResponse response) {
+        this.response = response;
+    }
+
+    /**
+     * Gets the long ID array for auto-id primary key, generated by import task.
+     *
+     * @return List&lt;Long&gt; ID array returned by import task
+     */
+    public List<Long> getAutoGeneratedIDs() {
+        return response.getIdListList();
+    }
+
+    /**
+     * Gets state of the import task.
+     *
+     * @return ImportState state of the import task
+     */
+    public ImportState getState() {
+        return response.getState();
+    }
+
+    /**
+     * Gets how many rows were imported by the import task.
+     *
+     * @return Long how many rows were imported by the import task
+     */
+    public long getImportedCount() {
+        return response.getRowCount();
+    }
+
+    /**
+     * Gets failed reason of the import task.
+     *
+     * @return String failed reason of the import task
+     */
+    public String getFailedReason() {
+        return getInfo(Constant.FAILED_REASON);
+    }
+
+    /**
+     * Gets target files of the import task.
+     *
+     * @return String target files of the import task
+     */
+    public String getFiles() {
+        return getInfo(Constant.IMPORT_FILES);
+    }
+
+    private String getInfo(@NonNull String key) {
+        List<KeyValuePair> infos = response.getInfosList();
+        for (KeyValuePair kv : infos) {
+            if (kv.getKey().compareTo(key) == 0) {
+                return kv.getValue();
+            }
+        }
+
+        return "";
+    }
+
+    /**
+     * Construct a <code>String</code> by {@link DescCollResponseWrapper} instance.
+     *
+     * @return <code>String</code>
+     */
+    @Override
+    public String toString() {
+        return "Import task state{" +
+                ", autoGeneratedIDs:" + getAutoGeneratedIDs() +
+                ", state:" + getState().name() +
+                ", failed reason:" + getFailedReason() +
+                ", files:" + getFiles() +
+                '}';
+    }
+}

+ 1 - 1
src/main/milvus-proto

@@ -1 +1 @@
-Subproject commit 614d2acd1d5b31d2e377a5a8b282a431847f7be2
+Subproject commit 3ad2f54e8372aeb5b3763d3a1d270d3ab5864088

+ 123 - 68
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -248,31 +248,31 @@ class MilvusServiceClientTest {
     void createCollectionParam() {
         // test throw exception with illegal input
         assertThrows(ParamException.class, () ->
-            FieldType.newBuilder()
-                    .withName("")
-                    .withDataType(DataType.Int64)
-                    .build()
+                FieldType.newBuilder()
+                        .withName("")
+                        .withDataType(DataType.Int64)
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
-            FieldType.newBuilder()
-                    .withName("userID")
-                    .build()
+                FieldType.newBuilder()
+                        .withName("userID")
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
-            FieldType.newBuilder()
-                    .withName("userID")
-                    .withDataType(DataType.FloatVector)
-                    .build()
+                FieldType.newBuilder()
+                        .withName("userID")
+                        .withDataType(DataType.FloatVector)
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
-            CreateCollectionParam
-                    .newBuilder()
-                    .withCollectionName("collection1")
-                    .withShardsNum(2)
-                    .build()
+                CreateCollectionParam
+                        .newBuilder()
+                        .withCollectionName("collection1")
+                        .withShardsNum(2)
+                        .build()
         );
 
         FieldType fieldType1 = FieldType.newBuilder()
@@ -284,12 +284,12 @@ class MilvusServiceClientTest {
                 .build();
 
         assertThrows(ParamException.class, () ->
-            CreateCollectionParam
-                    .newBuilder()
-                    .withCollectionName("")
-                    .withShardsNum(2)
-                    .addFieldType(fieldType1)
-                    .build()
+                CreateCollectionParam
+                        .newBuilder()
+                        .withCollectionName("")
+                        .withShardsNum(2)
+                        .addFieldType(fieldType1)
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
@@ -298,7 +298,7 @@ class MilvusServiceClientTest {
                         .withCollectionName("collection1")
                         .withShardsNum(0)
                         .addFieldType(fieldType1)
-                    .build()
+                        .build()
         );
 
         List<FieldType> fields = Collections.singletonList(null);
@@ -370,9 +370,9 @@ class MilvusServiceClientTest {
     void dropCollectionParam() {
         // test throw exception with illegal input
         assertThrows(ParamException.class, () ->
-            DropCollectionParam.newBuilder()
-                    .withCollectionName("")
-                    .build()
+                DropCollectionParam.newBuilder()
+                        .withCollectionName("")
+                        .build()
         );
     }
 
@@ -389,9 +389,9 @@ class MilvusServiceClientTest {
     void getCollectionStatisticsParam() {
         // test throw exception with illegal input
         assertThrows(ParamException.class, () ->
-            GetCollectionStatisticsParam.newBuilder()
-                    .withCollectionName("")
-                    .build()
+                GetCollectionStatisticsParam.newBuilder()
+                        .withCollectionName("")
+                        .build()
         );
     }
 
@@ -425,7 +425,7 @@ class MilvusServiceClientTest {
             mockServerImpl.setGetFlushStateResponse(GetFlushStateResponse.newBuilder()
                     .setFlushed(true)
                     .build());
-        },"RefreshFlushState").start();
+        }, "RefreshFlushState").start();
 
         R<GetCollectionStatisticsResponse> resp = client.getCollectionStatistics(param);
         assertEquals(R.Status.Success.getCode(), resp.getStatus());
@@ -447,9 +447,9 @@ class MilvusServiceClientTest {
     void hasCollectionParam() {
         // test throw exception with illegal input
         assertThrows(ParamException.class, () ->
-            HasCollectionParam.newBuilder()
-                    .withCollectionName("")
-                    .build()
+                HasCollectionParam.newBuilder()
+                        .withCollectionName("")
+                        .build()
         );
     }
 
@@ -466,33 +466,33 @@ class MilvusServiceClientTest {
     void loadCollectionParam() {
         // test throw exception with illegal input
         assertThrows(ParamException.class, () ->
-            LoadCollectionParam.newBuilder()
-                    .withCollectionName("")
-                    .build()
+                LoadCollectionParam.newBuilder()
+                        .withCollectionName("")
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
-            LoadCollectionParam.newBuilder()
-                    .withCollectionName("collection1")
-                    .withSyncLoad(Boolean.TRUE)
-                    .withSyncLoadWaitingInterval(0L)
-                    .build()
+                LoadCollectionParam.newBuilder()
+                        .withCollectionName("collection1")
+                        .withSyncLoad(Boolean.TRUE)
+                        .withSyncLoadWaitingInterval(0L)
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
-            LoadCollectionParam.newBuilder()
-                    .withCollectionName("collection1")
-                    .withSyncLoad(Boolean.TRUE)
-                    .withSyncLoadWaitingInterval(-1L)
-                    .build()
+                LoadCollectionParam.newBuilder()
+                        .withCollectionName("collection1")
+                        .withSyncLoad(Boolean.TRUE)
+                        .withSyncLoadWaitingInterval(-1L)
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
-            LoadCollectionParam.newBuilder()
-                    .withCollectionName("collection1")
-                    .withSyncLoad(Boolean.TRUE)
-                    .withSyncLoadWaitingInterval(Constant.MAX_WAITING_LOADING_INTERVAL + 1)
-                    .build()
+                LoadCollectionParam.newBuilder()
+                        .withCollectionName("collection1")
+                        .withSyncLoad(Boolean.TRUE)
+                        .withSyncLoadWaitingInterval(Constant.MAX_WAITING_LOADING_INTERVAL + 1)
+                        .build()
         );
 
         assertThrows(ParamException.class, () ->
@@ -557,7 +557,7 @@ class MilvusServiceClientTest {
                         .addInMemoryPercentages(100)
                         .build());
             }
-        },"RefreshMemState").start();
+        }, "RefreshMemState").start();
 
         param = LoadCollectionParam.newBuilder()
                 .withCollectionName(collectionName)
@@ -584,9 +584,9 @@ class MilvusServiceClientTest {
     void releaseCollectionParam() {
         // test throw exception with illegal input
         assertThrows(ParamException.class, () ->
-            ReleaseCollectionParam.newBuilder()
-                    .withCollectionName("")
-                    .build()
+                ReleaseCollectionParam.newBuilder()
+                        .withCollectionName("")
+                        .build()
         );
     }
 
@@ -611,9 +611,9 @@ class MilvusServiceClientTest {
         );
 
         assertThrows(ParamException.class, () ->
-            ShowCollectionsParam.newBuilder()
-                    .addCollectionName("")
-                    .build()
+                ShowCollectionsParam.newBuilder()
+                        .addCollectionName("")
+                        .build()
         );
 
         // verify internal param
@@ -888,7 +888,7 @@ class MilvusServiceClientTest {
                     TimeUnit.MILLISECONDS.sleep(100);
                     mockServerImpl.setShowPartitionsResponse(ShowPartitionsResponse.newBuilder()
                             .addPartitionNames(partitionName)
-                            .addInMemoryPercentages(i*10)
+                            .addInMemoryPercentages(i * 10)
                             .build());
                 }
             } catch (InterruptedException e) {
@@ -897,7 +897,7 @@ class MilvusServiceClientTest {
                         .addInMemoryPercentages(100)
                         .build());
             }
-        },"RefreshMemState").start();
+        }, "RefreshMemState").start();
 
         param = LoadPartitionsParam.newBuilder()
                 .withCollectionName(collectionName)
@@ -1401,7 +1401,7 @@ class MilvusServiceClientTest {
 
         List<DataType> testTypes = Arrays.asList(DataType.Int64, DataType.Int32, DataType.Int16, DataType.Int8,
                 DataType.Float, DataType.Double, DataType.Bool, DataType.BinaryVector);
-        testTypes.forEach((tp)->{
+        testTypes.forEach((tp) -> {
             fields.clear();
             List<String> fakeVectors3 = Arrays.asList("1", "2", "3");
             fields.add(new InsertParam.Field("field3", tp, fakeVectors3));
@@ -1460,15 +1460,15 @@ class MilvusServiceClientTest {
         List<ByteBuffer> bVectors = new ArrayList<>();
         List<List<Float>> fVectors = new ArrayList<>();
         for (int i = 0; i < 3; ++i) {
-            ids.add((long)i);
+            ids.add((long) i);
             nVal.add(i);
             bVal.add(Boolean.TRUE);
             fVal.add(0.5f);
             dVal.add(1.0);
             sVal.add(String.valueOf(i));
             ByteBuffer buf = ByteBuffer.allocate(2);
-            buf.put((byte)1);
-            buf.put((byte)2);
+            buf.put((byte) 1);
+            buf.put((byte) 2);
             bVectors.add(buf);
             List<Float> vec = Arrays.asList(0.1f, 0.2f);
             fVectors.add(vec);
@@ -1603,6 +1603,30 @@ class MilvusServiceClientTest {
         testFuncByName("delete", param);
     }
 
+    @Test
+    void import_() {
+        List<String> files = Collections.singletonList("f1");
+        ImportParam param = ImportParam.newBuilder()
+                .withCollectionName("collection1")
+                .withPartitionName("partition1")
+                .withRowBased(true)
+                .addFile("dummy.json")
+                .withFiles(files)
+                .withBucket("myBucket")
+                .build();
+
+        testFuncByName("importData", param);
+    }
+
+    @Test
+    void getImportState() {
+        GetImportStateParam param = GetImportStateParam.newBuilder()
+                .withTaskID(100L)
+                .build();
+
+        testFuncByName("getImportState", param);
+    }
+
     @Test
     void searchParam() {
         // test throw exception with illegal input
@@ -1799,8 +1823,8 @@ class MilvusServiceClientTest {
 
         List<ByteBuffer> bVectors = new ArrayList<>();
         ByteBuffer buf = ByteBuffer.allocate(2);
-        buf.put((byte)1);
-        buf.put((byte)2);
+        buf.put((byte) 1);
+        buf.put((byte) 2);
         bVectors.add(buf);
         param = SearchParam.newBuilder()
                 .withCollectionName("collection1")
@@ -2442,7 +2466,7 @@ class MilvusServiceClientTest {
         SearchResultsWrapper intWrapper = new SearchResultsWrapper(results);
         assertThrows(ParamException.class, () -> intWrapper.getFieldData(fieldName, -1));
         assertThrows(ParamException.class, () -> intWrapper.getFieldData("invalid", 0));
-        assertEquals(topK, intWrapper.getFieldData(fieldName, (int)numQueries-1).size());
+        assertEquals(topK, intWrapper.getFieldData(fieldName, (int) numQueries - 1).size());
 
         List<SearchResultsWrapper.IDScore> idScores = intWrapper.getIDScore(1);
         assertFalse(idScores.toString().isEmpty());
@@ -2468,7 +2492,7 @@ class MilvusServiceClientTest {
         assertFalse(idScores.toString().isEmpty());
         assertEquals(topK, idScores.size());
 
-        idScores.forEach((score)->assertFalse(score.toString().isEmpty()));
+        idScores.forEach((score) -> assertFalse(score.toString().isEmpty()));
     }
 
     @Test
@@ -2526,4 +2550,35 @@ class MilvusServiceClientTest {
             assertFalse(info.toString().isEmpty());
         }
     }
+
+    @Test
+    void testGetImportStateWrapper() {
+        long count = 1000;
+        long id = 88;
+        ImportState state = ImportState.ImportStarted;
+        String reason = "unexpected error";
+        String files = "1.json";
+        GetImportStateResponse reso = GetImportStateResponse.newBuilder()
+                .setState(state)
+                .setRowCount(count)
+                .addIdList(id)
+                .addInfos(KeyValuePair.newBuilder()
+                        .setKey(Constant.FAILED_REASON)
+                        .setValue(reason)
+                        .build())
+                .addInfos(KeyValuePair.newBuilder()
+                        .setKey(Constant.IMPORT_FILES)
+                        .setValue(files)
+                        .build())
+                .build();
+
+        GetImportStateWrapper wrapper = new GetImportStateWrapper(reso);
+        assertEquals(count, wrapper.getImportedCount());
+        assertEquals(1, wrapper.getAutoGeneratedIDs().size());
+        assertEquals(id, wrapper.getAutoGeneratedIDs().get(0));
+        assertEquals(reason, wrapper.getFailedReason());
+        assertEquals(files, wrapper.getFiles());
+
+        assertFalse(wrapper.toString().isEmpty());
+    }
 }

+ 20 - 0
src/test/java/io/milvus/server/MockMilvusServerImpl.java

@@ -50,6 +50,8 @@ public class MockMilvusServerImpl extends MilvusServiceGrpc.MilvusServiceImplBas
     private io.milvus.grpc.Status respDropIndex;
     private io.milvus.grpc.MutationResult respInsert;
     private io.milvus.grpc.MutationResult respDelete;
+    private io.milvus.grpc.ImportResponse respImport;
+    private io.milvus.grpc.GetImportStateResponse respImportState;
     private io.milvus.grpc.SearchResults respSearch;
     private io.milvus.grpc.FlushResponse respFlush;
     private io.milvus.grpc.QueryResults respQuery;
@@ -388,6 +390,24 @@ public class MockMilvusServerImpl extends MilvusServiceGrpc.MilvusServiceImplBas
         responseObserver.onCompleted();
     }
 
+    @Override
+    public void import_(io.milvus.grpc.ImportRequest request,
+                       io.grpc.stub.StreamObserver<io.milvus.grpc.ImportResponse> responseObserver) {
+        logger.info("import() call");
+
+        responseObserver.onNext(respImport);
+        responseObserver.onCompleted();
+    }
+
+    @Override
+    public void getImportState(io.milvus.grpc.GetImportStateRequest request,
+                        io.grpc.stub.StreamObserver<io.milvus.grpc.GetImportStateResponse> responseObserver) {
+        logger.info("getImportState() call");
+
+        responseObserver.onNext(respImportState);
+        responseObserver.onCompleted();
+    }
+
     public void setDeleteResponse(io.milvus.grpc.MutationResult resp) {
         respDelete = resp;
     }