Browse Source

support GRPC target string and load-balancing policy

jianghua 4 years ago
parent
commit
34e5305041

+ 40 - 0
src/main/java/io/milvus/client/ConnectParam.java

@@ -19,13 +19,17 @@
 
 package io.milvus.client;
 
+import io.grpc.ManagedChannelBuilder;
+
 import javax.annotation.Nonnull;
 import java.util.concurrent.TimeUnit;
 
 /** Contains parameters for connecting to Milvus server */
 public class ConnectParam {
+  private final String target;
   private final String host;
   private final int port;
+  private final String defaultLoadBalancingPolicy;
   private final long connectTimeoutNanos;
   private final long keepAliveTimeNanos;
   private final long keepAliveTimeoutNanos;
@@ -33,8 +37,10 @@ public class ConnectParam {
   private final long idleTimeoutNanos;
 
   private ConnectParam(@Nonnull Builder builder) {
+    this.target = builder.target;
     this.host = builder.host;
     this.port = builder.port;
+    this.defaultLoadBalancingPolicy = builder.defaultLoadBalancingPolicy;
     this.connectTimeoutNanos = builder.connectTimeoutNanos;
     this.keepAliveTimeNanos = builder.keepAliveTimeNanos;
     this.keepAliveTimeoutNanos = builder.keepAliveTimeoutNanos;
@@ -42,6 +48,10 @@ public class ConnectParam {
     this.idleTimeoutNanos = builder.idleTimeoutNanos;
   }
 
+  public String getTarget() {
+    return target;
+  }
+
   public String getHost() {
     return host;
   }
@@ -50,6 +60,10 @@ public class ConnectParam {
     return port;
   }
 
+  public String getDefaultLoadBalancingPolicy() {
+    return defaultLoadBalancingPolicy;
+  }
+
   public long getConnectTimeout(@Nonnull TimeUnit timeUnit) {
     return timeUnit.convert(connectTimeoutNanos, TimeUnit.NANOSECONDS);
   }
@@ -73,14 +87,29 @@ public class ConnectParam {
   /** Builder for <code>ConnectParam</code> */
   public static class Builder {
     // Optional parameters - initialized to default values
+    private String target = null;
     private String host = "localhost";
     private int port = 19530;
+    private String defaultLoadBalancingPolicy = "round_robin";
     private long connectTimeoutNanos = TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
     private long keepAliveTimeNanos = Long.MAX_VALUE; // Disabling keepalive
     private long keepAliveTimeoutNanos = TimeUnit.NANOSECONDS.convert(20, TimeUnit.SECONDS);
     private boolean keepAliveWithoutCalls = false;
     private long idleTimeoutNanos = TimeUnit.NANOSECONDS.convert(24, TimeUnit.HOURS);
 
+    /**
+     * Optional. Defaults to null. Will be used in precedence to host and port.
+     *
+     * @param target a GRPC target string
+     * @return <code>Builder</code>
+     *
+     * @see ManagedChannelBuilder#forTarget(String)
+     */
+    public Builder withTarget(@Nonnull String target) {
+      this.target = target;
+      return this;
+    }
+
     /**
      * Optional. Defaults to "localhost".
      *
@@ -106,6 +135,17 @@ public class ConnectParam {
       return this;
     }
 
+    /**
+     * Optional. Defaults to "round_robin".
+     *
+     * @param defaultLoadBalancingPolicy the default load-balancing policy name
+     * @return <code>Builder</code>
+     */
+    public Builder withDefaultLoadBalancingPolicy(String defaultLoadBalancingPolicy) {
+      this.defaultLoadBalancingPolicy = defaultLoadBalancingPolicy;
+      return this;
+    }
+
     /**
      * Optional. Defaults to 10 seconds.
      *

+ 6 - 2
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -53,9 +53,13 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
   private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
 
   public MilvusGrpcClient(ConnectParam connectParam) {
-    channel = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
-        .usePlaintext()
+    ManagedChannelBuilder builder = connectParam.getTarget() != null
+        ? ManagedChannelBuilder.forTarget(connectParam.getTarget())
+        : ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort());
+
+    channel = builder.usePlaintext()
         .maxInboundMessageSize(Integer.MAX_VALUE)
+        .defaultLoadBalancingPolicy(connectParam.getDefaultLoadBalancingPolicy())
         .keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
         .keepAliveTimeout(connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())

+ 30 - 0
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -20,6 +20,8 @@
 package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import io.grpc.NameResolverProvider;
+import io.grpc.NameResolverRegistry;
 import io.milvus.client.exception.InitializationException;
 import io.milvus.client.exception.UnsupportedServerVersion;
 import org.apache.commons.text.RandomStringGenerator;
@@ -30,12 +32,14 @@ import org.testcontainers.containers.GenericContainer;
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
 
+import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 
 import static org.junit.jupiter.api.Assertions.*;
@@ -48,6 +52,11 @@ class ContainerMilvusClientTest extends MilvusClientTest {
       new GenericContainer("milvusdb/milvus:0.10.1-cpu-d072020-bd02b1")
           .withExposedPorts(19530);
 
+  @Container
+  private static GenericContainer milvusContainer2 =
+      new GenericContainer("milvusdb/milvus:0.10.0-cpu-d061620-5f3c00")
+          .withExposedPorts(19530);
+
   @Container
   private static GenericContainer unsupportedMilvusContainer =
       new GenericContainer("milvusdb/milvus:0.9.1-cpu-d052920-e04ed5")
@@ -63,6 +72,27 @@ class ContainerMilvusClientTest extends MilvusClientTest {
     ConnectParam connectParam = connectParamBuilder(unsupportedMilvusContainer).build();
     assertThrows(UnsupportedServerVersion.class, () -> new MilvusGrpcClient(connectParam));
   }
+
+  @org.junit.jupiter.api.Test
+  void loadBalancing() {
+    NameResolverProvider testNameResolverProvider = new StaticNameResolverProvider(
+        "test",
+        new InetSocketAddress(milvusContainer.getHost(), milvusContainer.getFirstMappedPort()),
+        new InetSocketAddress(milvusContainer2.getHost(), milvusContainer2.getFirstMappedPort()));
+
+    NameResolverRegistry.getDefaultRegistry().register(testNameResolverProvider);
+
+    ConnectParam connectParam = connectParamBuilder()
+        .withTarget(testNameResolverProvider.getDefaultScheme() + ":///test")
+        .build();
+
+    MilvusClient client = new MilvusGrpcClient(connectParam);
+    List<String> serverVersions = IntStream.range(0, 100)
+        .mapToObj(i -> client.getServerVersion().getMessage())
+        .collect(Collectors.toList());
+    assertTrue(serverVersions.stream().allMatch(version -> version.matches("0\\.10\\.[01]")));
+    assertEquals(50, serverVersions.stream().filter(version -> version.equals("0.10.0")).count());
+  }
 }
 
 @DisabledIfSystemProperty(named = "with-containers", matches = "true")

+ 68 - 0
src/test/java/io/milvus/client/StaticNameResolverProvider.java

@@ -0,0 +1,68 @@
+package io.milvus.client;
+
+import io.grpc.Attributes;
+import io.grpc.EquivalentAddressGroup;
+import io.grpc.NameResolver;
+import io.grpc.NameResolverProvider;
+
+import java.net.SocketAddress;
+import java.net.URI;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class StaticNameResolverProvider extends NameResolverProvider {
+  private String name;
+  private List<SocketAddress> addresses;
+
+  public StaticNameResolverProvider(String name, SocketAddress... addresses) {
+    this.name = name;
+    this.addresses = Arrays.asList(addresses);
+  }
+
+  @Override
+  public String getDefaultScheme() {
+    return "static";
+  }
+
+  @Override
+  protected boolean isAvailable() {
+    return true;
+  }
+
+  @Override
+  protected int priority() {
+    return 0;
+  }
+
+  @Override
+  public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
+    if (!getDefaultScheme().equals(targetUri.getScheme())) {
+      return null;
+    }
+    return new NameResolver() {
+      @Override
+      public String getServiceAuthority() {
+        return "localhost";
+      }
+
+      @Override
+      public void start(Listener2 listener) {
+        List<EquivalentAddressGroup> addrs = addresses.stream()
+            .map(addr -> new EquivalentAddressGroup(Collections.singletonList(addr)))
+            .collect(Collectors.toList());
+
+        listener.onResult(
+            ResolutionResult.newBuilder()
+                .setAddresses(addrs)
+                .setAttributes(Attributes.EMPTY)
+                .build());
+      }
+
+      @Override
+      public void shutdown() {
+      }
+    };
+  }
+}