Преглед на файлове

ClientPool supports different ConnectConfig for different key (#1594)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot преди 3 седмици
родител
ревизия
a8d9c0d663

+ 2 - 2
docker-compose.yml

@@ -3,7 +3,7 @@ version: '3.5'
 services:
   standalone:
     container_name: milvus-javasdk-standalone-1
-    image: milvusdb/milvus:v2.6.0
+    image: milvusdb/milvus:v2.6.1
     command: [ "milvus", "run", "standalone" ]
     environment:
       - COMMON_STORAGETYPE=local
@@ -24,7 +24,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-standalone-2
-    image: milvusdb/milvus:v2.6.0
+    image: milvusdb/milvus:v2.6.1
     command: [ "milvus", "run", "standalone" ]
     environment:
       - COMMON_STORAGETYPE=local

+ 214 - 76
examples/src/main/java/io/milvus/v1/ClientPoolExample.java

@@ -25,85 +25,135 @@ import io.milvus.client.MilvusClient;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.grpc.DataType;
 import io.milvus.grpc.MutationResult;
+import io.milvus.grpc.QueryResults;
 import io.milvus.grpc.SearchResults;
 import io.milvus.param.*;
-import io.milvus.param.collection.CreateCollectionParam;
-import io.milvus.param.collection.DropCollectionParam;
-import io.milvus.param.collection.FieldType;
-import io.milvus.param.collection.LoadCollectionParam;
+import io.milvus.param.collection.*;
 import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.param.index.CreateIndexParam;
 import io.milvus.pool.MilvusClientV1Pool;
 import io.milvus.pool.PoolConfig;
+import io.milvus.response.QueryResultsWrapper;
 
 import java.time.Duration;
 import java.util.*;
 
 public class ClientPoolExample {
-    public static String CollectionName = "java_sdk_example_pool_v2";
+    public static String serverUri = "http://localhost:19530";
+    public static String CollectionName = "java_sdk_example_pool_v1";
     public static String VectorFieldName = "vector";
     public static int DIM = 128;
+    public static List<String> dbNames = Arrays.asList("AA", "BB", "CC");
 
-    public static void createCollection(MilvusClientV1Pool pool) {
+    private static void printKeyClientNumber(MilvusClientV1Pool pool, String key) {
+        System.out.printf("Key '%s': %d idle clients and %d active clients%n",
+                key, pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+    }
+    private static void printClientNumber(MilvusClientV1Pool pool) {
+        System.out.println("======================================================================");
+        System.out.printf("Total %d idle clients and %d active clients%n",
+                pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
+        for (String dbName : dbNames) {
+            printKeyClientNumber(pool, dbName);
+        }
+        System.out.println("======================================================================");
+    }
+
+    public static void createDatabases(MilvusClientV1Pool pool) {
+        // get a client, the client uses the default config to connect milvus(to the default database)
         String tempKey = "temp";
         MilvusClient client = pool.getClient(tempKey);
         if (client == null) {
             throw new RuntimeException("Unable to create client");
         }
         try {
-            client.dropCollection(DropCollectionParam.newBuilder()
-                    .withCollectionName(CollectionName)
-                    .build());
-            List<FieldType> fieldsSchema = Arrays.asList(
-                    FieldType.newBuilder()
-                            .withName("id")
-                            .withDataType(DataType.Int64)
-                            .withPrimaryKey(true)
-                            .withAutoID(true)
-                            .build(),
-                    FieldType.newBuilder()
-                            .withName(VectorFieldName)
-                            .withDataType(DataType.FloatVector)
-                            .withDimension(DIM)
-                            .build()
-            );
-
-            // Create the collection with 3 fields
-            R<RpcStatus> ret = client.createCollection(CreateCollectionParam.newBuilder()
-                    .withCollectionName(CollectionName)
-                    .withFieldTypes(fieldsSchema)
-                    .build());
-            if (ret.getStatus() != R.Status.Success.getCode()) {
-                throw new RuntimeException("Failed to create collection, error: " + ret.getMessage());
-            }
-
-            ret = client.createIndex(CreateIndexParam.newBuilder()
-                    .withCollectionName(CollectionName)
-                    .withFieldName(VectorFieldName)
-                    .withIndexType(IndexType.FLAT)
-                    .withMetricType(MetricType.L2)
-                    .build());
-            if (ret.getStatus() != R.Status.Success.getCode()) {
-                throw new RuntimeException("Failed to create index on vector field, error: " + ret.getMessage());
+            for (String dbName : dbNames) {
+                client.createDatabase(CreateDatabaseParam.newBuilder()
+                        .withDatabaseName(dbName)
+                        .build());
+                System.out.println("Database created: " + dbName);
             }
-
-            client.loadCollection(LoadCollectionParam.newBuilder()
-                    .withCollectionName(CollectionName)
-                    .build());
-
-            System.out.printf("Collection '%s' created%n", CollectionName);
         } catch (Exception e) {
-            String msg = String.format("Failed to create collection, error: %s%n", e.getMessage());
+            String msg = String.format("Failed to create database, error: %s%n", e.getMessage());
             System.out.printf(msg);
             throw new RuntimeException(msg);
         } finally {
             pool.returnClient(tempKey, client);
-            pool.clear(tempKey);
+            pool.clear(tempKey); // this call will destroy the temp client
+        }
+
+        // predefine a connect config for each key(name of a database)
+        // the ClientPool will use different config to create client to connect to specific database
+        for (String dbName : dbNames) {
+            ConnectParam connectConfig = ConnectParam.newBuilder()
+                    .withUri(serverUri)
+                    .withDatabaseName(dbName)
+                    .build();
+            pool.configForKey(dbName, connectConfig);
         }
     }
 
-    public static Thread runInsertThread(MilvusClientV1Pool pool, String clientName, int repeatRequests) {
+    public static void createCollections(MilvusClientV1Pool pool) {
+        for (String dbName : dbNames) {
+            // this client connects to the database of dbName because we have predefined
+            // a ConnectConfig for this key
+            MilvusClient client = pool.getClient(dbName);
+            if (client == null) {
+                throw new RuntimeException("Unable to create client");
+            }
+            try {
+                client.dropCollection(DropCollectionParam.newBuilder()
+                        .withCollectionName(CollectionName)
+                        .build());
+                List<FieldType> fieldsSchema = Arrays.asList(
+                        FieldType.newBuilder()
+                                .withName("id")
+                                .withDataType(DataType.Int64)
+                                .withPrimaryKey(true)
+                                .withAutoID(true)
+                                .build(),
+                        FieldType.newBuilder()
+                                .withName(VectorFieldName)
+                                .withDataType(DataType.FloatVector)
+                                .withDimension(DIM)
+                                .build()
+                );
+                R<RpcStatus> ret = client.createCollection(CreateCollectionParam.newBuilder()
+                        .withCollectionName(CollectionName)
+                        .withFieldTypes(fieldsSchema)
+                        .build());
+                if (ret.getStatus() != R.Status.Success.getCode()) {
+                    throw new RuntimeException("Failed to create collection, error: " + ret.getMessage());
+                }
+
+                ret = client.createIndex(CreateIndexParam.newBuilder()
+                        .withCollectionName(CollectionName)
+                        .withFieldName(VectorFieldName)
+                        .withIndexType(IndexType.FLAT)
+                        .withMetricType(MetricType.L2)
+                        .build());
+                if (ret.getStatus() != R.Status.Success.getCode()) {
+                    throw new RuntimeException("Failed to create index on vector field, error: " + ret.getMessage());
+                }
+
+                client.loadCollection(LoadCollectionParam.newBuilder()
+                        .withCollectionName(CollectionName)
+                        .build());
+
+                System.out.printf("Collection '%s' created in database '%s'%n", CollectionName, dbName);
+            } catch (Exception e) {
+                String msg = String.format("Failed to create collection, error: %s%n", e.getMessage());
+                System.out.printf(msg);
+                throw new RuntimeException(msg);
+            } finally {
+                pool.returnClient(dbName, client);
+            }
+        }
+    }
+
+    public static Thread runInsertThread(MilvusClientV1Pool pool, String dbName, int repeatRequests) {
         Thread t = new Thread(() -> {
             Gson gson = new Gson();
             for (int i = 0; i < repeatRequests; i++) {
@@ -112,7 +162,7 @@ public class ClientPoolExample {
                     try {
                         // getClient() might exceeds the borrowMaxWaitMillis and throw exception
                         // retry to call until it return a client
-                        client = pool.getClient(clientName);
+                        client = pool.getClient(dbName);
                     } catch (Exception e) {
                         System.out.printf("Failed to get client, will retry, error: %s%n", e.getMessage());
                     }
@@ -133,20 +183,21 @@ public class ClientPoolExample {
                     if (insertRet.getStatus() != R.Status.Success.getCode()) {
                         throw new RuntimeException("Failed to insert, error: " + insertRet.getMessage());
                     }
-                    System.out.printf("%d rows inserted%n", rows.size());
+//                    System.out.printf("%d rows inserted%n", rows.size());
                 } catch (Exception e) {
                     System.out.printf("Failed to inserted, error: %s%n", e.getMessage());
                 } finally {
-                    pool.returnClient(clientName, client); // make sure the client is returned after use
+                    pool.returnClient(dbName, client); // make sure the client is returned after use
                 }
             }
             System.out.printf("Insert thread %s finished%n", Thread.currentThread().getName());
+            printKeyClientNumber(pool, dbName);
         });
         t.start();
         return t;
     }
 
-    public static Thread runSearchThread(MilvusClientV1Pool pool, String clientName, int repeatRequests) {
+    public static Thread runSearchThread(MilvusClientV1Pool pool, String dbName, int repeatRequests) {
         Thread t = new Thread(() -> {
             for (int i = 0; i < repeatRequests; i++) {
                 MilvusClient client = null;
@@ -154,7 +205,7 @@ public class ClientPoolExample {
                     try {
                         // getClient() might exceeds the borrowMaxWaitMillis and throw exception
                         // retry to call until it return a client
-                        client = pool.getClient(clientName);
+                        client = pool.getClient(dbName);
                     } catch (Exception e) {
                         System.out.printf("Failed to get client, will retry, error: %s%n", e.getMessage());
                     }
@@ -172,28 +223,109 @@ public class ClientPoolExample {
                     if (searchRet.getStatus() != R.Status.Success.getCode()) {
                         throw new RuntimeException("Failed to search, error: " + searchRet.getMessage());
                     }
-                    System.out.println("A search request completed");
+//                    System.out.println("A search request completed");
                 } catch (Exception e) {
                     System.out.printf("Failed to search, error: %s%n", e.getMessage());
                 } finally {
-                    pool.returnClient(clientName, client); // make sure the client is returned after use
+                    pool.returnClient(dbName, client); // make sure the client is returned after use
                 }
             }
             System.out.printf("Search thread %s finished%n", Thread.currentThread().getName());
+            printKeyClientNumber(pool, dbName);
         });
         t.start();
         return t;
     }
 
+    public static void verifyRowCount(MilvusClientV1Pool pool, long expectedCount) {
+        for (String dbName : dbNames) {
+            // this client connects to the database of dbName because we have predefined
+            // a ConnectConfig for this key
+            MilvusClient client = pool.getClient(dbName);
+            if (client == null) {
+                throw new RuntimeException("Unable to create client");
+            }
+            try {
+                R<QueryResults> queryRet = client.query(QueryParam.newBuilder()
+                        .withCollectionName(CollectionName)
+                        .withExpr("")
+                        .addOutField("count(*)")
+                        .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                        .build());
+                QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryRet.getData());
+                long rowCount = (long)queryWrapper.getFieldWrapper("count(*)").getFieldData().get(0);
+                System.out.printf("%d rows persisted in collection '%s' of database '%s'%n",
+                        rowCount, CollectionName, dbName);
+                if (rowCount != expectedCount) {
+                    throw new RuntimeException("The persisted row count is not equal to expected");
+                }
+            } catch (Exception e) {
+                String msg = String.format("Failed to get row count, error: %s%n", e.getMessage());
+                System.out.printf(msg);
+                throw new RuntimeException(msg);
+            } finally {
+                pool.returnClient(dbName, client);
+            }
+        }
+    }
+
+    public static void dropCollections(MilvusClientV1Pool pool) {
+        for (String dbName : dbNames) {
+            // this client connects to the database of dbName because we have predefined
+            // a ConnectConfig for this key
+            MilvusClient client = pool.getClient(dbName);
+            if (client == null) {
+                throw new RuntimeException("Unable to create client");
+            }
+            try {
+                client.dropCollection(DropCollectionParam.newBuilder()
+                        .withCollectionName(CollectionName)
+                        .build());
+                System.out.printf("Collection '%s' dropped in database '%s'%n", CollectionName, dbName);
+            } catch (Exception e) {
+                String msg = String.format("Failed to drop collection, error: %s%n", e.getMessage());
+                System.out.printf(msg);
+                throw new RuntimeException(msg);
+            } finally {
+                pool.returnClient(dbName, client);
+            }
+        }
+    }
+
+    public static void dropDatabases(MilvusClientV1Pool pool) {
+        // get a client, the client uses the default config to connect milvus(to the default database)
+        String tempKey = "temp";
+        MilvusClient client = pool.getClient(tempKey);
+        if (client == null) {
+            throw new RuntimeException("Unable to create client");
+        }
+        try {
+            for (String dbName : dbNames) {
+                client.dropDatabase(DropDatabaseParam.newBuilder()
+                        .withDatabaseName(dbName)
+                        .build());
+                System.out.println("Database dropped: " + dbName);
+            }
+        } catch (Exception e) {
+            String msg = String.format("Failed to drop database, error: %s%n", e.getMessage());
+            System.out.printf(msg);
+            throw new RuntimeException(msg);
+        } finally {
+            pool.returnClient(tempKey, client);
+            pool.clear(tempKey); // this call will destroy the temp client
+        }
+    }
+
     public static void main(String[] args) throws InterruptedException {
         ConnectParam connectConfig = ConnectParam.newBuilder()
-                .withHost("localhost")
-                .withPort(19530)
+                .withUri(serverUri)
                 .build();
+        // read this issue for more details about the pool configurations:
+        // https://github.com/milvus-io/milvus-sdk-java/issues/1577
         PoolConfig poolConfig = PoolConfig.builder()
                 .maxIdlePerKey(10) // max idle clients per key
-                .maxTotalPerKey(20) // max total(idle + active) clients per key
-                .maxTotal(100) // max total clients for all keys
+                .maxTotalPerKey(50) // max total(idle + active) clients per key
+                .maxTotal(1000) // max total clients for all keys
                 .maxBlockWaitDuration(Duration.ofSeconds(5L)) // getClient() will wait 5 seconds if no idle client available
                 .minEvictableIdleDuration(Duration.ofSeconds(10L)) // if number of idle clients is larger than maxIdlePerKey, redundant idle clients will be evicted after 10 seconds
                 .build();
@@ -205,35 +337,41 @@ public class ClientPoolExample {
             return;
         }
 
-        createCollection(pool);
+        // create some databases
+        createDatabases(pool);
+        // create a collection in each database
+        createCollections(pool);
 
         List<Thread> threadList = new ArrayList<>();
         int threadCount = 100;
         int repeatRequests = 100;
         long start = System.currentTimeMillis();
         for (int k = 0; k < threadCount; k++) {
-            threadList.add(runInsertThread(pool, "192.168.1.1", repeatRequests));
-            threadList.add(runInsertThread(pool, "192.168.1.2", repeatRequests));
-            threadList.add(runInsertThread(pool, "192.168.1.3", repeatRequests));
-
-            threadList.add(runSearchThread(pool, "192.168.1.1", repeatRequests));
-            threadList.add(runSearchThread(pool, "192.168.1.2", repeatRequests));
-            threadList.add(runSearchThread(pool, "192.168.1.3", repeatRequests));
-
-            System.out.printf("Total %d idle clients and %d active clients%n",
-                    pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
+            for (String dbName : dbNames) {
+                threadList.add(runInsertThread(pool, dbName, repeatRequests));
+                threadList.add(runSearchThread(pool, dbName, repeatRequests));
+            }
+            printClientNumber(pool);
         }
-
         for (Thread t : threadList) {
             t.join();
         }
+        printClientNumber(pool);
+
+        // check row count of each collection, there are threadCount*repeatRequests rows were inserted by multiple threads
+        verifyRowCount(pool, threadCount*repeatRequests);
+        // drop collections
+        dropCollections(pool);
+        // drop databases, only after database is empty, it is able to be dropped
+        dropDatabases(pool);
+
         long end = System.currentTimeMillis();
         System.out.printf("%d insert requests and %d search requests finished in %.3f seconds%n",
                 threadCount*repeatRequests*3, threadCount*repeatRequests*3, (end-start)*0.001);
-        System.out.printf("Total %d idle clients and %d active clients%n",
-                pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
+
+        printClientNumber(pool);
         pool.clear(); // clear idle clients
-        System.out.printf("After clear, total %d idle clients and %d active clients%n",
-                pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
+        printClientNumber(pool);
+        pool.close();
     }
 }

+ 190 - 45
examples/src/main/java/io/milvus/v2/ClientPoolExample.java

@@ -30,52 +30,109 @@ import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.database.request.CreateDatabaseReq;
+import io.milvus.v2.service.database.request.DropDatabaseReq;
 import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
 import io.milvus.v2.service.vector.request.SearchReq;
 import io.milvus.v2.service.vector.request.data.FloatVec;
 import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
 import io.milvus.v2.service.vector.response.SearchResp;
 
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
 public class ClientPoolExample {
+    public static String serverUri = "http://localhost:19530";
     public static String CollectionName = "java_sdk_example_pool_v2";
     public static String VectorFieldName = "vector";
     public static int DIM = 128;
+    public static List<String> dbNames = Arrays.asList("AA", "BB", "CC");
 
-    public static void createCollection(MilvusClientV2Pool pool) {
+    private static void printKeyClientNumber(MilvusClientV2Pool pool, String key) {
+        System.out.printf("Key '%s': %d idle clients and %d active clients%n",
+                key, pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+    }
+    private static void printClientNumber(MilvusClientV2Pool pool) {
+        System.out.println("======================================================================");
+        System.out.printf("Total %d idle clients and %d active clients%n",
+                pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
+        for (String dbName : dbNames) {
+            printKeyClientNumber(pool, dbName);
+        }
+        System.out.println("======================================================================");
+    }
+
+    public static void createDatabases(MilvusClientV2Pool pool) {
+        // get a client, the client uses the default config to connect milvus(to the default database)
         String tempKey = "temp";
         MilvusClientV2 client = pool.getClient(tempKey);
         if (client == null) {
             throw new RuntimeException("Unable to create client");
         }
         try {
-            client.dropCollection(DropCollectionReq.builder()
-                    .collectionName(CollectionName)
-                    .build());
-            client.createCollection(CreateCollectionReq.builder()
-                    .collectionName(CollectionName)
-                    .primaryFieldName("id")
-                    .idType(DataType.Int64)
-                    .autoID(Boolean.TRUE)
-                    .vectorFieldName(VectorFieldName)
-                    .dimension(DIM)
-                    .build());
-            System.out.printf("Collection '%s' created%n", CollectionName);
+            for (String dbName : dbNames) {
+                client.createDatabase(CreateDatabaseReq.builder()
+                        .databaseName(dbName)
+                        .build());
+                System.out.println("Database created: " + dbName);
+            }
         } catch (Exception e) {
-            String msg = String.format("Failed to create collection, error: %s%n", e.getMessage());
+            String msg = String.format("Failed to create database, error: %s%n", e.getMessage());
             System.out.printf(msg);
             throw new RuntimeException(msg);
         } finally {
             pool.returnClient(tempKey, client);
-            pool.clear(tempKey);
+            pool.clear(tempKey); // this call will destroy the temp client
+        }
+
+        // predefine a connect config for each key(name of a database)
+        // the ClientPool will use different config to create client to connect to specific database
+        for (String dbName : dbNames) {
+            ConnectConfig config = ConnectConfig.builder()
+                    .uri(serverUri)
+                    .dbName(dbName)
+                    .build();
+            pool.configForKey(dbName, config);
         }
     }
 
-    public static Thread runInsertThread(MilvusClientV2Pool pool, String clientName, int repeatRequests) {
+    public static void createCollections(MilvusClientV2Pool pool) {
+        for (String dbName : dbNames) {
+            // this client connects to the database of dbName because we have predefined
+            // a ConnectConfig for this key
+            MilvusClientV2 client = pool.getClient(dbName);
+            if (client == null) {
+                throw new RuntimeException("Unable to create client");
+            }
+            try {
+                client.dropCollection(DropCollectionReq.builder()
+                        .collectionName(CollectionName)
+                        .build());
+                client.createCollection(CreateCollectionReq.builder()
+                        .collectionName(CollectionName)
+                        .primaryFieldName("id")
+                        .idType(DataType.Int64)
+                        .autoID(Boolean.TRUE)
+                        .vectorFieldName(VectorFieldName)
+                        .dimension(DIM)
+                        .build());
+                System.out.printf("Collection '%s' created in database '%s'%n", CollectionName, dbName);
+            } catch (Exception e) {
+                String msg = String.format("Failed to create collection, error: %s%n", e.getMessage());
+                System.out.printf(msg);
+                throw new RuntimeException(msg);
+            } finally {
+                pool.returnClient(dbName, client);
+            }
+        }
+    }
+
+    public static Thread runInsertThread(MilvusClientV2Pool pool, String dbName, int repeatRequests) {
         Thread t = new Thread(() -> {
             Gson gson = new Gson();
             for (int i = 0; i < repeatRequests; i++) {
@@ -84,7 +141,7 @@ public class ClientPoolExample {
                     try {
                         // getClient() might exceeds the borrowMaxWaitMillis and throw exception
                         // retry to call until it return a client
-                        client = pool.getClient(clientName);
+                        client = pool.getClient(dbName);
                     } catch (Exception e) {
                         System.out.printf("Failed to get client, will retry, error: %s%n", e.getMessage());
                     }
@@ -101,20 +158,21 @@ public class ClientPoolExample {
                             .collectionName(CollectionName)
                             .data(rows)
                             .build());
-                    System.out.printf("%d rows inserted%n", rows.size());
+//                    System.out.printf("%d rows inserted%n", insertR.getInsertCnt());
                 } catch (Exception e) {
                     System.out.printf("Failed to inserted, error: %s%n", e.getMessage());
                 } finally {
-                    pool.returnClient(clientName, client); // make sure the client is returned after use
+                    pool.returnClient(dbName, client); // make sure the client is returned after use
                 }
             }
             System.out.printf("Insert thread %s finished%n", Thread.currentThread().getName());
+            printKeyClientNumber(pool, dbName);
         });
         t.start();
         return t;
     }
 
-    public static Thread runSearchThread(MilvusClientV2Pool pool, String clientName, int repeatRequests) {
+    public static Thread runSearchThread(MilvusClientV2Pool pool, String dbName, int repeatRequests) {
         Thread t = new Thread(() -> {
             for (int i = 0; i < repeatRequests; i++) {
                 MilvusClientV2 client = null;
@@ -122,7 +180,7 @@ public class ClientPoolExample {
                     try {
                         // getClient() might exceeds the borrowMaxWaitMillis and throw exception
                         // retry to call until it return a client
-                        client = pool.getClient(clientName);
+                        client = pool.getClient(dbName);
                     } catch (Exception e) {
                         System.out.printf("Failed to get client, will retry, error: %s%n", e.getMessage());
                     }
@@ -135,70 +193,157 @@ public class ClientPoolExample {
                             .limit(10)
                             .data(Collections.singletonList(new FloatVec(CommonUtils.generateFloatVector(DIM))))
                             .build());
-                    System.out.println("A search request completed");
+//                    System.out.printf("A search request returns %d items with nq %d%n",
+//                            result.getSearchResults().get(0).size(), result.getSearchResults().size());
                 } catch (Exception e) {
                     System.out.printf("Failed to search, error: %s%n", e.getMessage());
                 } finally {
-                    pool.returnClient(clientName, client); // make sure the client is returned after use
+                    pool.returnClient(dbName, client); // make sure the client is returned after use
                 }
             }
             System.out.printf("Search thread %s finished%n", Thread.currentThread().getName());
+            printKeyClientNumber(pool, dbName);
         });
         t.start();
         return t;
     }
 
+    public static void verifyRowCount(MilvusClientV2Pool pool, long expectedCount) {
+        for (String dbName : dbNames) {
+            // this client connects to the database of dbName because we have predefined
+            // a ConnectConfig for this key
+            MilvusClientV2 client = pool.getClient(dbName);
+            if (client == null) {
+                throw new RuntimeException("Unable to create client");
+            }
+            try {
+                QueryResp countR = client.query(QueryReq.builder()
+                        .collectionName(CollectionName)
+                        .outputFields(Collections.singletonList("count(*)"))
+                        .consistencyLevel(ConsistencyLevel.STRONG)
+                        .build());
+                long rowCount = (long)countR.getQueryResults().get(0).getEntity().get("count(*)");
+                System.out.printf("%d rows persisted in collection '%s' of database '%s'%n",
+                        rowCount, CollectionName, dbName);
+                if (rowCount != expectedCount) {
+                    throw new RuntimeException("The persisted row count is not equal to expected");
+                }
+            } catch (Exception e) {
+                String msg = String.format("Failed to get row count, error: %s%n", e.getMessage());
+                System.out.printf(msg);
+                throw new RuntimeException(msg);
+            } finally {
+                pool.returnClient(dbName, client);
+            }
+        }
+    }
+
+    public static void dropCollections(MilvusClientV2Pool pool) {
+        for (String dbName : dbNames) {
+            // this client connects to the database of dbName because we have predefined
+            // a ConnectConfig for this key
+            MilvusClientV2 client = pool.getClient(dbName);
+            if (client == null) {
+                throw new RuntimeException("Unable to create client");
+            }
+            try {
+                client.dropCollection(DropCollectionReq.builder()
+                        .collectionName(CollectionName)
+                        .build());
+                System.out.printf("Collection '%s' dropped in database '%s'%n", CollectionName, dbName);
+            } catch (Exception e) {
+                String msg = String.format("Failed to drop collection, error: %s%n", e.getMessage());
+                System.out.printf(msg);
+                throw new RuntimeException(msg);
+            } finally {
+                pool.returnClient(dbName, client);
+            }
+        }
+    }
+
+    public static void dropDatabases(MilvusClientV2Pool pool) {
+        // get a client, the client uses the default config to connect milvus(to the default database)
+        String tempKey = "temp";
+        MilvusClientV2 client = pool.getClient(tempKey);
+        if (client == null) {
+            throw new RuntimeException("Unable to create client");
+        }
+        try {
+            for (String dbName : dbNames) {
+                client.dropDatabase(DropDatabaseReq.builder()
+                        .databaseName(dbName)
+                        .build());
+                System.out.println("Database dropped: " + dbName);
+            }
+        } catch (Exception e) {
+            String msg = String.format("Failed to drop database, error: %s%n", e.getMessage());
+            System.out.printf(msg);
+            throw new RuntimeException(msg);
+        } finally {
+            pool.returnClient(tempKey, client);
+            pool.clear(tempKey); // this call will destroy the temp client
+        }
+    }
+
     public static void main(String[] args) throws InterruptedException {
-        ConnectConfig connectConfig = ConnectConfig.builder()
-                .uri("http://localhost:19530")
+        ConnectConfig defaultConfig = ConnectConfig.builder()
+                .uri(serverUri)
                 .build();
+        // read this issue for more details about the pool configurations:
+        // https://github.com/milvus-io/milvus-sdk-java/issues/1577
         PoolConfig poolConfig = PoolConfig.builder()
                 .maxIdlePerKey(10) // max idle clients per key
-                .maxTotalPerKey(20) // max total(idle + active) clients per key
-                .maxTotal(100) // max total clients for all keys
+                .maxTotalPerKey(50) // max total(idle + active) clients per key
+                .maxTotal(1000) // max total clients for all keys
                 .maxBlockWaitDuration(Duration.ofSeconds(5L)) // getClient() will wait 5 seconds if no idle client available
                 .minEvictableIdleDuration(Duration.ofSeconds(10L)) // if number of idle clients is larger than maxIdlePerKey, redundant idle clients will be evicted after 10 seconds
                 .build();
         MilvusClientV2Pool pool;
         try {
-            pool = new MilvusClientV2Pool(poolConfig, connectConfig);
+            pool = new MilvusClientV2Pool(poolConfig, defaultConfig);
         } catch (Exception e) {
             System.out.println(e.getMessage());
             return;
         }
 
-        createCollection(pool);
+        // create some databases
+        createDatabases(pool);
+        // create a collection in each database
+        createCollections(pool);
 
         List<Thread> threadList = new ArrayList<>();
         int threadCount = 100;
         int repeatRequests = 100;
         long start = System.currentTimeMillis();
+        // for each database, we create threadCount of threads to call insert() for repeatRequests times
+        // each insert request will insert one row
+        // for each database, we create threadCount of threads to call search() for repeatRequests times
         for (int k = 0; k < threadCount; k++) {
-            threadList.add(runInsertThread(pool, "192.168.1.1", repeatRequests));
-            threadList.add(runInsertThread(pool, "192.168.1.2", repeatRequests));
-            threadList.add(runInsertThread(pool, "192.168.1.3", repeatRequests));
-
-            threadList.add(runSearchThread(pool, "192.168.1.1", repeatRequests));
-            threadList.add(runSearchThread(pool, "192.168.1.2", repeatRequests));
-            threadList.add(runSearchThread(pool, "192.168.1.3", repeatRequests));
-
-            System.out.printf("Total %d idle clients and %d active clients%n",
-                    pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
+            for (String dbName : dbNames) {
+                threadList.add(runInsertThread(pool, dbName, repeatRequests));
+                threadList.add(runSearchThread(pool, dbName, repeatRequests));
+            }
+            printClientNumber(pool);
         }
-
         for (Thread t : threadList) {
             t.join();
         }
+        printClientNumber(pool);
+
+        // check row count of each collection, there are threadCount*repeatRequests rows were inserted by multiple threads
+        verifyRowCount(pool, threadCount*repeatRequests);
+        // drop collections
+        dropCollections(pool);
+        // drop databases, only after database is empty, it is able to be dropped
+        dropDatabases(pool);
+
         long end = System.currentTimeMillis();
         System.out.printf("%d insert requests and %d search requests finished in %.3f seconds%n",
                 threadCount*repeatRequests*3, threadCount*repeatRequests*3, (end-start)*0.001);
-        System.out.printf("Total %d idle clients and %d active clients%n",
-                pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
 
+        printClientNumber(pool);
         pool.clear(); // clear idle clients
-        System.out.printf("After clear, total %d idle clients and %d active clients%n",
-                pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
-
+        printClientNumber(pool);
         pool.close();
     }
 }

+ 4 - 2
sdk-core/src/main/java/io/milvus/pool/ClientPool.java

@@ -7,8 +7,6 @@ import org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.time.Duration;
-
 public class ClientPool<C, T> {
     protected static final Logger logger = LoggerFactory.getLogger(ClientPool.class);
     protected GenericKeyedObjectPool<String, T> clientPool;
@@ -40,6 +38,10 @@ public class ClientPool<C, T> {
         this.clientPool = new GenericKeyedObjectPool<String, T>(clientFactory, poolConfig);
     }
 
+    public void configForKey(String key, C config) {
+        this.clientFactory.configForKey(key, config);
+    }
+
     /**
      * Get a client object which is idle from the pool.
      * Once the client is hold by the caller, it will be marked as active state and cannot be fetched by other caller.

+ 17 - 6
sdk-core/src/main/java/io/milvus/pool/PoolClientFactory.java

@@ -10,19 +10,22 @@ import org.slf4j.LoggerFactory;
 
 import java.lang.reflect.Constructor;
 import java.lang.reflect.Method;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 
 public class PoolClientFactory<C, T> extends BaseKeyedPooledObjectFactory<String, T> {
     protected static final Logger logger = LoggerFactory.getLogger(PoolClientFactory.class);
-    private final C config;
+    private final C configDefault;
+    private ConcurrentMap<String, C> configForKeys = new ConcurrentHashMap<>();
     private Constructor<?> constructor;
     private Method closeMethod;
     private Method verifyMethod;
 
-    public PoolClientFactory(C config, String clientClassName) throws ClassNotFoundException, NoSuchMethodException {
-        this.config = config;
+    public PoolClientFactory(C configDefault, String clientClassName) throws ClassNotFoundException, NoSuchMethodException {
+        this.configDefault = configDefault;
         try {
             Class<?> clientCls = Class.forName(clientClassName);
-            Class<?> configCls = Class.forName(config.getClass().getName());
+            Class<?> configCls = Class.forName(configDefault.getClass().getName());
             constructor = clientCls.getConstructor(configCls);
             closeMethod = clientCls.getMethod("close", long.class);
             verifyMethod = clientCls.getMethod("clientIsReady");
@@ -32,11 +35,19 @@ public class PoolClientFactory<C, T> extends BaseKeyedPooledObjectFactory<String
         }
     }
 
+    public void configForKey(String key, C config) {
+        configForKeys.put(key, config);
+    }
+
     @Override
     public T create(String key) throws Exception {
         try {
-            T client = (T) constructor.newInstance(this.config);
-            return client;
+            C keyConfig = configForKeys.get(key);
+            if (keyConfig == null) {
+                return (T) constructor.newInstance(this.configDefault);
+            } else {
+                return (T) constructor.newInstance(keyConfig);
+            }
         } catch (Exception e) {
             logger.error("Failed to create client, exception: ", e);
             throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);

+ 3 - 3
sdk-core/src/main/java/io/milvus/pool/PoolConfig.java

@@ -10,13 +10,13 @@ import java.time.Duration;
 @SuperBuilder
 public class PoolConfig {
     @Builder.Default
-    private int maxIdlePerKey = 5;
+    private int maxIdlePerKey = 10;
     @Builder.Default
     private int minIdlePerKey = 0;
     @Builder.Default
-    private int maxTotalPerKey = 10;
+    private int maxTotalPerKey = 30;
     @Builder.Default
-    private int maxTotal = 50;
+    private int maxTotal = 1000;
     @Builder.Default
     private boolean blockWhenExhausted = true;
     @Builder.Default

+ 1 - 1
sdk-core/src/test/java/io/milvus/TestUtils.java

@@ -11,7 +11,7 @@ public class TestUtils {
     private int dimension = 256;
     private static final Random RANDOM = new Random();
 
-    public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0";
+    public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.1";
 
     public TestUtils(int dimension) {
         this.dimension = dimension;

+ 20 - 1
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -2112,7 +2112,14 @@ class MilvusClientV2DockerTest {
 
     @Test
     void testClientPool() {
+        // create a temp database
+        String dummyDb = "dummy_db";
+        client.createDatabase(CreateDatabaseReq.builder()
+                .databaseName(dummyDb)
+                .build());
+
         try {
+            // the default connection config will connect to default db
             ConnectConfig connectConfig = ConnectConfig.builder()
                     .uri(milvus.getEndpoint())
                     .rpcDeadlineMs(100L)
@@ -2121,16 +2128,24 @@ class MilvusClientV2DockerTest {
                     .build();
             MilvusClientV2Pool pool = new MilvusClientV2Pool(poolConfig, connectConfig);
 
+            // clients of the key "dummy_db" will connect to this db
+            pool.configForKey(dummyDb, ConnectConfig.builder()
+                    .uri(milvus.getEndpoint())
+                    .dbName(dummyDb)
+                    .rpcDeadlineMs(100L)
+                    .build());
+
             List<Thread> threadList = new ArrayList<>();
             int threadCount = 10;
             int requestPerThread = 10;
-            String key = "192.168.1.1";
+            String key = "default";
             for (int k = 0; k < threadCount; k++) {
                 Thread t = new Thread(() -> {
                     for (int i = 0; i < requestPerThread; i++) {
                         MilvusClientV2 client = pool.getClient(key);
                         String version = client.getServerVersion();
 //                            System.out.printf("%d, %s%n", i, version);
+                        Assertions.assertEquals(client.currentUsedDatabase(), "default");
                         System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
                         pool.returnClient(key, client);
                     }
@@ -2146,6 +2161,10 @@ class MilvusClientV2DockerTest {
 
             System.out.println(String.format("idle %d, active %d", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key)));
             System.out.println(String.format("total idle %d, total active %d", pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber()));
+
+            // get client connect to the dummy db
+            MilvusClientV2 dummyClient = pool.getClient(dummyDb);
+            Assertions.assertEquals(dummyClient.currentUsedDatabase(), dummyDb);
             pool.close();
         } catch (Exception e) {
             System.out.println(e.getMessage());