Browse Source

Support gRPC target string and load-balancing policy

jianghua 4 years ago
parent
commit
18406c3baa

+ 40 - 0
src/main/java/io/milvus/client/ConnectParam.java

@@ -19,13 +19,17 @@
 
 package io.milvus.client;
 
+import io.grpc.ManagedChannelBuilder;
+
 import javax.annotation.Nonnull;
 import java.util.concurrent.TimeUnit;
 
 /** Contains parameters for connecting to Milvus server */
 public class ConnectParam {
+  private final String target;
   private final String host;
   private final int port;
+  private final String defaultLoadBalancingPolicy;
   private final long connectTimeoutNanos;
   private final long keepAliveTimeNanos;
   private final long keepAliveTimeoutNanos;
@@ -33,8 +37,10 @@ public class ConnectParam {
   private final long idleTimeoutNanos;
 
   private ConnectParam(@Nonnull Builder builder) {
+    this.target = builder.target;
     this.host = builder.host;
     this.port = builder.port;
+    this.defaultLoadBalancingPolicy = builder.defaultLoadBalancingPolicy;
     this.connectTimeoutNanos = builder.connectTimeoutNanos;
     this.keepAliveTimeNanos = builder.keepAliveTimeNanos;
     this.keepAliveTimeoutNanos = builder.keepAliveTimeoutNanos;
@@ -42,6 +48,10 @@ public class ConnectParam {
     this.idleTimeoutNanos = builder.idleTimeoutNanos;
   }
 
+  public String getTarget() {
+    return target;
+  }
+
   public String getHost() {
     return host;
   }
@@ -50,6 +60,10 @@ public class ConnectParam {
     return port;
   }
 
+  public String getDefaultLoadBalancingPolicy() {
+    return defaultLoadBalancingPolicy;
+  }
+
   public long getConnectTimeout(@Nonnull TimeUnit timeUnit) {
     return timeUnit.convert(connectTimeoutNanos, TimeUnit.NANOSECONDS);
   }
@@ -73,14 +87,29 @@ public class ConnectParam {
   /** Builder for <code>ConnectParam</code> */
   public static class Builder {
     // Optional parameters - initialized to default values
+    private String target = null;
     private String host = "localhost";
     private int port = 19530;
+    private String defaultLoadBalancingPolicy = "round_robin";
     private long connectTimeoutNanos = TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
     private long keepAliveTimeNanos = Long.MAX_VALUE; // Disabling keepalive
     private long keepAliveTimeoutNanos = TimeUnit.NANOSECONDS.convert(20, TimeUnit.SECONDS);
     private boolean keepAliveWithoutCalls = false;
     private long idleTimeoutNanos = TimeUnit.NANOSECONDS.convert(24, TimeUnit.HOURS);
 
+    /**
+     * Optional. Defaults to null. Will be used in precedence to host and port.
+     *
+     * @param target a GRPC target string
+     * @return <code>Builder</code>
+     *
+     * @see ManagedChannelBuilder#forTarget(String)
+     */
+    public Builder withTarget(@Nonnull String target) {
+      this.target = target;
+      return this;
+    }
+
     /**
      * Optional. Defaults to "localhost".
      *
@@ -106,6 +135,17 @@ public class ConnectParam {
       return this;
     }
 
+    /**
+     * Optional. Defaults to "round_robin".
+     *
+     * @param defaultLoadBalancingPolicy the default load-balancing policy name
+     * @return <code>Builder</code>
+     */
+    public Builder withDefaultLoadBalancingPolicy(String defaultLoadBalancingPolicy) {
+      this.defaultLoadBalancingPolicy = defaultLoadBalancingPolicy;
+      return this;
+    }
+
     /**
      * Optional. Defaults to 10 seconds.
      *

+ 6 - 2
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -65,9 +65,13 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
   private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
 
   public MilvusGrpcClient(ConnectParam connectParam) {
-    channel = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
-        .usePlaintext()
+    ManagedChannelBuilder builder = connectParam.getTarget() != null
+        ? ManagedChannelBuilder.forTarget(connectParam.getTarget())
+        : ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort());
+
+    channel = builder.usePlaintext()
         .maxInboundMessageSize(Integer.MAX_VALUE)
+        .defaultLoadBalancingPolicy(connectParam.getDefaultLoadBalancingPolicy())
         .keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
         .keepAliveTimeout(connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())

+ 28 - 1
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -23,6 +23,8 @@ import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
+import io.grpc.NameResolverProvider;
+import io.grpc.NameResolverRegistry;
 import io.milvus.client.InsertParam.Builder;
 import io.milvus.client.Response.Status;
 import io.milvus.client.exception.InitializationException;
@@ -38,6 +40,7 @@ import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
 
+import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -50,6 +53,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
@@ -66,10 +70,33 @@ class ContainerMilvusClientTest extends MilvusClientTest {
       new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
           .withExposedPorts(19530);
 
+  @Container
+  private static GenericContainer milvusContainer2 =
+      new GenericContainer(System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu"))
+          .withExposedPorts(19530);
+
   @Override
   protected ConnectParam.Builder connectParamBuilder() {
     return connectParamBuilder(milvusContainer);
   }
+
+  @org.junit.jupiter.api.Test
+  void loadBalancing() {
+    NameResolverProvider testNameResolverProvider = new StaticNameResolverProvider(
+        new InetSocketAddress(milvusContainer.getHost(), milvusContainer.getFirstMappedPort()),
+        new InetSocketAddress(milvusContainer2.getHost(), milvusContainer2.getFirstMappedPort()));
+
+    NameResolverRegistry.getDefaultRegistry().register(testNameResolverProvider);
+
+    ConnectParam connectParam = connectParamBuilder()
+        .withTarget(testNameResolverProvider.getDefaultScheme() + ":///test")
+        .build();
+
+    MilvusClient loadBalancingClient = new MilvusGrpcClient(connectParam);
+    assertEquals(50, IntStream.range(0, 100)
+            .filter(i -> loadBalancingClient.hasCollection(randomCollectionName).hasCollection())
+            .count());
+  }
 }
 
 @Testcontainers
@@ -80,7 +107,7 @@ class MilvusClientTest {
 
   private RandomStringGenerator generator;
 
-  private String randomCollectionName;
+  protected String randomCollectionName;
   private int size;
   private int dimension;
 

+ 66 - 0
src/test/java/io/milvus/client/StaticNameResolverProvider.java

@@ -0,0 +1,66 @@
+package io.milvus.client;
+
+import io.grpc.Attributes;
+import io.grpc.EquivalentAddressGroup;
+import io.grpc.NameResolver;
+import io.grpc.NameResolverProvider;
+
+import java.net.SocketAddress;
+import java.net.URI;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class StaticNameResolverProvider extends NameResolverProvider {
+  private List<SocketAddress> addresses;
+
+  public StaticNameResolverProvider(SocketAddress... addresses) {
+    this.addresses = Arrays.asList(addresses);
+  }
+
+  @Override
+  public String getDefaultScheme() {
+    return "static";
+  }
+
+  @Override
+  protected boolean isAvailable() {
+    return true;
+  }
+
+  @Override
+  protected int priority() {
+    return 0;
+  }
+
+  @Override
+  public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
+    if (!getDefaultScheme().equals(targetUri.getScheme())) {
+      return null;
+    }
+    return new NameResolver() {
+      @Override
+      public String getServiceAuthority() {
+        return "localhost";
+      }
+
+      @Override
+      public void start(Listener2 listener) {
+        List<EquivalentAddressGroup> addrs = addresses.stream()
+            .map(addr -> new EquivalentAddressGroup(Collections.singletonList(addr)))
+            .collect(Collectors.toList());
+
+        listener.onResult(
+            ResolutionResult.newBuilder()
+                .setAddresses(addrs)
+                .setAttributes(Attributes.EMPTY)
+                .build());
+      }
+
+      @Override
+      public void shutdown() {
+      }
+    };
+  }
+}