Jelajahi Sumber

Client pool for V1 V2 (#1016)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 9 bulan lalu
induk
melakukan
be3bf1c73c

+ 7 - 0
pom.xml

@@ -95,6 +95,7 @@
         <kotlin.version>1.9.10</kotlin.version>
         <mockito.version>4.11.0</mockito.version>
         <testcontainers.version>1.19.8</testcontainers.version>
+        <apache.commons.pool2.version>2.12.0</apache.commons.pool2.version>
 
         <hadoop.version>3.3.6</hadoop.version>
         <hbase.version>1.2.0</hbase.version>
@@ -109,6 +110,7 @@
         <minio-java-sdk.veresion>8.5.7</minio-java-sdk.veresion>
         <azure-java-blob-sdk.version>12.25.3</azure-java-blob-sdk.version>
         <azure-java-identity-sdk.version>1.10.1</azure-java-identity-sdk.version>
+
     </properties>
 
     <dependencyManagement>
@@ -389,6 +391,11 @@
                 </exclusion>
             </exclusions>
         </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-pool2</artifactId>
+            <version>${apache.commons.pool2.version}</version>
+        </dependency>
     </dependencies>
 
     <profiles>

+ 2 - 3
src/main/java/io/milvus/client/MilvusServiceClient.java

@@ -202,9 +202,8 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
     }
 
     @Override
-    protected boolean clientIsReady() {
-        ConnectivityState state = channel.getState(false);
-        return state != ConnectivityState.SHUTDOWN;
+    public boolean clientIsReady() {
+        return channel != null && !channel.isShutdown() && !channel.isTerminated();
     }
 
     @Override

+ 79 - 0
src/main/java/io/milvus/pool/ClientPool.java

@@ -0,0 +1,79 @@
+package io.milvus.pool;
+
+import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
+import org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig;
+
+import java.time.Duration;
+
+public class ClientPool<C, T> {
+    protected GenericKeyedObjectPool<String, T> clientPool;
+    protected PoolConfig config;
+    protected PoolClientFactory<C, T> clientFactory;
+
+    protected ClientPool() {
+
+    }
+
+    protected ClientPool(PoolConfig config, PoolClientFactory clientFactory) {
+        this.config = config;
+        this.clientFactory = clientFactory;
+
+        GenericKeyedObjectPoolConfig poolConfig = new GenericKeyedObjectPoolConfig();
+        poolConfig.setMaxIdlePerKey(config.getMaxIdlePerKey());
+        poolConfig.setMinIdlePerKey(config.getMinIdlePerKey());
+        poolConfig.setMaxTotal(config.getMaxTotal());
+        poolConfig.setMaxTotalPerKey(config.getMaxTotalPerKey());
+        poolConfig.setBlockWhenExhausted(config.isBlockWhenExhausted());
+        poolConfig.setMaxWait(config.getMaxBlockWaitDuration());
+        poolConfig.setTestOnBorrow(config.isTestOnBorrow());
+        poolConfig.setTestOnReturn(config.isTestOnReturn());
+        poolConfig.setTestOnCreate(false);
+        poolConfig.setTestWhileIdle(false);
+        poolConfig.setTimeBetweenEvictionRuns(config.getEvictionPollingInterval());
+        poolConfig.setNumTestsPerEvictionRun(5);
+        poolConfig.setMinEvictableIdleTime(config.getMinEvictableIdleDuration());
+        this.clientPool = new GenericKeyedObjectPool<String, T>(clientFactory, poolConfig);
+    }
+
+    public T getClient(String key) {
+        try {
+            return clientPool.borrowObject(key);
+        } catch (Exception e) {
+            System.out.println("Failed to get client, exception: " + e.getMessage());
+            return null;
+        }
+    }
+
+
+    public void returnClient(String key, T grpcClient) {
+        try {
+            clientPool.returnObject(key, grpcClient);
+        } catch (Exception e) {
+            System.out.println("Failed to return client, exception: " + e.getMessage());
+            throw e;
+        }
+    }
+
+    public void close() {
+        if (clientPool != null && !clientPool.isClosed()) {
+            clientPool.close();
+            clientPool = null;
+        }
+    }
+
+    public int getIdleClientNumber(String key) {
+        return clientPool.getNumIdle(key);
+    }
+
+    public int getActiveClientNumber(String key) {
+        return clientPool.getNumActive(key);
+    }
+
+    public int getTotalIdleClientNumber() {
+        return clientPool.getNumIdle();
+    }
+
+    public int getTotalActiveClientNumber() {
+        return clientPool.getNumActive();
+    }
+}

+ 11 - 0
src/main/java/io/milvus/pool/MilvusClientV1Pool.java

@@ -0,0 +1,11 @@
+package io.milvus.pool;
+
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.param.ConnectParam;
+
+public class MilvusClientV1Pool extends ClientPool<ConnectParam, MilvusClient> {
+    public MilvusClientV1Pool(PoolConfig poolConfig, ConnectParam connectParam) throws ClassNotFoundException, NoSuchMethodException {
+        super(poolConfig, new PoolClientFactory<ConnectParam, MilvusClient>(connectParam, MilvusServiceClient.class.getName()));
+    }
+}

+ 10 - 0
src/main/java/io/milvus/pool/MilvusClientV2Pool.java

@@ -0,0 +1,10 @@
+package io.milvus.pool;
+
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+
+public class MilvusClientV2Pool extends ClientPool<ConnectConfig, MilvusClientV2> {
+    public MilvusClientV2Pool(PoolConfig poolConfig, ConnectConfig connectConfig) throws ClassNotFoundException, NoSuchMethodException {
+        super(poolConfig, new PoolClientFactory<ConnectConfig, MilvusClientV2>(connectConfig, MilvusClientV2.class.getName()));
+    }
+}

+ 71 - 0
src/main/java/io/milvus/pool/PoolClientFactory.java

@@ -0,0 +1,71 @@
+package io.milvus.pool;
+
+import org.apache.commons.pool2.BaseKeyedPooledObjectFactory;
+import org.apache.commons.pool2.PooledObject;
+import org.apache.commons.pool2.impl.DefaultPooledObject;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Method;
+
+public class PoolClientFactory<C, T> extends BaseKeyedPooledObjectFactory<String, T> {
+    private final C config;
+    private Constructor<?> constructor;
+    private Method closeMethod;
+    private Method verifyMethod;
+
+    public PoolClientFactory(C config, String clientClassName) throws ClassNotFoundException, NoSuchMethodException {
+        this.config = config;
+        try {
+            Class<?> clientCls = Class.forName(clientClassName);
+            Class<?> configCls = Class.forName(config.getClass().getName());
+            constructor = clientCls.getConstructor(configCls);
+            closeMethod = clientCls.getMethod("close", long.class);
+            verifyMethod = clientCls.getMethod("clientIsReady");
+        } catch (Exception e) {
+            System.out.println("Failed to create client pool factory, exception: " + e.getMessage());
+            throw e;
+        }
+    }
+
+    @Override
+    public T create(String key) throws Exception {
+        try {
+            T client = (T) constructor.newInstance(this.config);
+            return client;
+        } catch (Exception e) {
+            return null;
+        }
+    }
+
+    @Override
+    public PooledObject<T> wrap(T client) {
+        return new DefaultPooledObject<>(client);
+    }
+
+    @Override
+    public void destroyObject(String key, PooledObject<T> p) throws Exception {
+        T client = p.getObject();
+        closeMethod.invoke(client, 3L);
+    }
+
+    @Override
+    public boolean validateObject(String key, PooledObject<T> p) {
+        try {
+            T client = p.getObject();
+            return (boolean) verifyMethod.invoke(client);
+        } catch (Exception e) {
+            System.out.println("Failed to validate client, exception: " + e.getMessage());
+            return true;
+        }
+    }
+
+    @Override
+    public void activateObject(String key, PooledObject<T> p) throws Exception {
+        super.activateObject(key, p);
+    }
+
+    @Override
+    public void passivateObject(String key, PooledObject<T> p) throws Exception {
+        super.passivateObject(key, p);
+    }
+}

+ 32 - 0
src/main/java/io/milvus/pool/PoolConfig.java

@@ -0,0 +1,32 @@
+package io.milvus.pool;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.experimental.SuperBuilder;
+
+import java.time.Duration;
+
+@Data
+@SuperBuilder
+public class PoolConfig {
+    @Builder.Default
+    private int maxIdlePerKey = 5;
+    @Builder.Default
+    private int minIdlePerKey = 0;
+    @Builder.Default
+    private int maxTotalPerKey = 10;
+    @Builder.Default
+    private int maxTotal = 50;
+    @Builder.Default
+    private boolean blockWhenExhausted = true;
+    @Builder.Default
+    private Duration maxBlockWaitDuration = Duration.ofSeconds(3L);
+    @Builder.Default
+    private Duration evictionPollingInterval = Duration.ofSeconds(60L);
+    @Builder.Default
+    private Duration minEvictableIdleDuration = Duration.ofSeconds(10L);
+    @Builder.Default
+    private boolean testOnBorrow = false;
+    @Builder.Default
+    private boolean testOnReturn = true;
+}

+ 4 - 0
src/main/java/io/milvus/v2/client/MilvusClientV2.java

@@ -750,4 +750,8 @@ public class MilvusClientV2 {
             channel.awaitTermination(maxWaitSeconds, TimeUnit.SECONDS);
         }
     }
+
+    public boolean clientIsReady() {
+        return channel != null && !channel.isShutdown() && !channel.isTerminated();
+    }
 }

+ 46 - 0
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -50,6 +50,8 @@ import io.milvus.param.index.DropIndexParam;
 import io.milvus.param.index.GetIndexStateParam;
 import io.milvus.param.partition.GetPartitionStatisticsParam;
 import io.milvus.param.partition.ShowPartitionsParam;
+import io.milvus.pool.MilvusClientV1Pool;
+import io.milvus.pool.PoolConfig;
 import io.milvus.response.*;
 
 import org.apache.avro.generic.GenericData;
@@ -3093,4 +3095,48 @@ class MilvusClientDockerTest {
                 .build());
         Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
     }
+
+    @Test
+    void testClientPool() {
+        try {
+            ConnectParam connectParam = ConnectParam.newBuilder()
+                    .withUri(milvus.getEndpoint())
+                    .build();
+            PoolConfig poolConfig = PoolConfig.builder()
+                    .build();
+            MilvusClientV1Pool pool = new MilvusClientV1Pool(poolConfig, connectParam);
+
+            List<Thread> threadList = new ArrayList<>();
+            int threadCount = 10;
+            int requestPerThread = 10;
+            String key = "192.168.1.1";
+            for (int k = 0; k < threadCount; k++) {
+                Thread t = new Thread(new Runnable() {
+                    @Override
+                    public void run() {
+                        for (int i = 0; i < requestPerThread; i++) {
+                            MilvusClient client = pool.getClient(key);
+                            R<GetVersionResponse> resp = client.getVersion();
+//                            System.out.printf("%d, %s%n", i, resp.getData().getVersion());
+                            System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+                            pool.returnClient(key, client);
+                        }
+                        System.out.println(String.format("Thread %s finished", Thread.currentThread().getName()));
+                    }
+                });
+                t.start();
+                threadList.add(t);
+            }
+
+            for (Thread t : threadList) {
+                t.join();
+            }
+
+            System.out.println(String.format("idle %d, active %d", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key)));
+            pool.close();
+        } catch (Exception e) {
+            System.out.println(e.getMessage());
+            Assertions.fail(e.getMessage());
+        }
+    }
 }

+ 47 - 3
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -23,15 +23,15 @@ import com.google.common.collect.Lists;
 import com.google.gson.*;
 
 import com.google.gson.reflect.TypeToken;
-import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.utils.Float16Utils;
 import io.milvus.orm.iterator.QueryIterator;
 import io.milvus.orm.iterator.SearchIterator;
 import io.milvus.param.Constant;
+import io.milvus.pool.MilvusClientV2Pool;
+import io.milvus.pool.PoolConfig;
 import io.milvus.response.QueryResultsWrapper;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
-import io.milvus.v2.common.IndexBuildState;
 import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.collection.request.*;
@@ -67,7 +67,6 @@ import org.testcontainers.milvus.MilvusContainer;
 
 import java.nio.ByteBuffer;
 import java.util.*;
-import java.util.concurrent.TimeUnit;
 
 @Testcontainers(disabledWithoutDocker = true)
 class MilvusClientV2DockerTest {
@@ -1587,4 +1586,49 @@ class MilvusClientV2DockerTest {
         dbNames = listDatabasesResp.getDatabaseNames();
         Assertions.assertFalse(dbNames.contains(tempDatabaseName));
     }
+
+    @Test
+    void testClientPool() {
+        try {
+            ConnectConfig connectConfig = ConnectConfig.builder()
+                    .uri(milvus.getEndpoint())
+                    .build();
+            PoolConfig poolConfig = PoolConfig.builder()
+                    .build();
+            MilvusClientV2Pool pool = new MilvusClientV2Pool(poolConfig, connectConfig);
+
+            List<Thread> threadList = new ArrayList<>();
+            int threadCount = 10;
+            int requestPerThread = 10;
+            String key = "192.168.1.1";
+            for (int k = 0; k < threadCount; k++) {
+                Thread t = new Thread(new Runnable() {
+                    @Override
+                    public void run() {
+                        for (int i = 0; i < requestPerThread; i++) {
+                            MilvusClientV2 client = pool.getClient(key);
+                            String version = client.getVersion();
+//                            System.out.printf("%d, %s%n", i, version);
+                            System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+                            pool.returnClient(key, client);
+                        }
+                        System.out.println(String.format("Thread %s finished", Thread.currentThread().getName()));
+                    }
+                });
+                t.start();
+                threadList.add(t);
+            }
+
+            for (Thread t : threadList) {
+                t.join();
+            }
+
+            System.out.println(String.format("idle %d, active %d", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key)));
+            System.out.println(String.format("total idle %d, total active %d", pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber()));
+            pool.close();
+        } catch (Exception e) {
+            System.out.println(e.getMessage());
+            Assertions.fail(e.getMessage());
+        }
+    }
 }