Browse Source

changes to propagate tracid from client (#955) (#961)

Signed-off-by: Shreesha Srinath Madogaran <smadogaran@salesforce.com>
Co-authored-by: madogar <36537062+madogar@users.noreply.github.com>
Co-authored-by: Shreesha Srinath Madogaran <smadogaran@salesforce.com>
groot 1 year ago
parent
commit
7d321f8956

+ 24 - 3
src/main/java/io/milvus/client/MilvusServiceClient.java

@@ -51,6 +51,8 @@ import org.apache.commons.lang3.StringUtils;
 
 import java.io.File;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
 import java.net.InetAddress;
@@ -75,6 +77,25 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
             metadata.put(Metadata.Key.of("dbname", Metadata.ASCII_STRING_MARSHALLER), connectParam.getDatabaseName());
         }
 
+        List<ClientInterceptor> clientInterceptors = new ArrayList<>();
+        clientInterceptors.add(MetadataUtils.newAttachHeadersInterceptor(metadata));
+        //client interceptor used to fetch client_request_id from threadlocal variable and set it for every grpc request
+        clientInterceptors.add(new ClientInterceptor() {
+            @Override
+            public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+                return new ForwardingClientCall
+                    .SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
+                    @Override
+                    public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
+                        if(connectParam.getClientRequestId() != null && !StringUtils.isEmpty(connectParam.getClientRequestId().get())) {
+                            headers.put(Metadata.Key.of("client_request_id", Metadata.ASCII_STRING_MARSHALLER), connectParam.getClientRequestId().get());
+                        }
+                        super.start(responseListener, headers);
+                    }
+                };
+            }
+        });
+
         try {
             if (StringUtils.isNotEmpty(connectParam.getServerPemPath())) {
                 // one-way tls
@@ -90,7 +111,7 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
                         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
                         .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
-                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                        .intercept(clientInterceptors);
                 if(connectParam.isSecure()){
                     builder.useTransportSecurity();
                 }
@@ -111,7 +132,7 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
                         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
                         .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
-                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                        .intercept(clientInterceptors);
                 if(connectParam.isSecure()){
                     builder.useTransportSecurity();
                 }
@@ -128,7 +149,7 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
                         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
                         .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
-                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                        .intercept(clientInterceptors);
                 if(connectParam.isSecure()){
                     builder.useTransportSecurity();
                 }

+ 10 - 0
src/main/java/io/milvus/param/ConnectParam.java

@@ -58,6 +58,7 @@ public class ConnectParam {
     private final String serverPemPath;
     private final String serverName;
     private final String userName;
+    private final ThreadLocal<String> clientRequestId;
 
     protected ConnectParam(@NonNull Builder builder) {
         this.host = builder.host;
@@ -79,6 +80,7 @@ public class ConnectParam {
         this.serverPemPath = builder.serverPemPath;
         this.serverName = builder.serverName;
         this.userName = builder.userName;
+        this.clientRequestId = builder.clientRequestId;
     }
 
     public static Builder newBuilder() {
@@ -116,6 +118,9 @@ public class ConnectParam {
         // If the username is unknown, send it as an empty string.
         private String userName = "";
 
+        //used to set client_request_id in the grpc header uniquely for every request
+        private ThreadLocal<String> clientRequestId;
+
         protected Builder() {
         }
 
@@ -350,6 +355,11 @@ public class ConnectParam {
             return this;
         }
 
+        public Builder withClientRequestId(@NonNull ThreadLocal<String> clientRequestId) {
+            this.clientRequestId = clientRequestId;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link ConnectParam} instance.
          *

+ 30 - 0
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -48,6 +48,9 @@ import java.lang.reflect.Method;
 import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.jupiter.api.Assertions.*;
@@ -309,6 +312,33 @@ class MilvusServiceClientTest {
         server.stop();
     }
 
+    @Test
+    void testConnectWithClientRequestId() {
+        ThreadLocal<String> clientRequestId = new ThreadLocal<>();
+        clientRequestId.set("req1");
+        ConnectParam connectParam = ConnectParam.newBuilder()
+            .withHost("localhost")
+            .withPort(testPort)
+            .withConnectTimeout(10000, TimeUnit.MILLISECONDS)
+            .withClientRequestId(clientRequestId)
+            .build();
+        RetryParam retryParam = RetryParam.newBuilder()
+            .withMaxRetryTimes(2)
+            .build();
+
+        MockMilvusServer server = startServer();
+        MilvusServiceClient client = new MilvusServiceClient(connectParam);
+        client.withRetry(retryParam);
+        DescribeCollectionParam param = DescribeCollectionParam.newBuilder()
+            .withCollectionName("collection1")
+            .build();
+        R<DescribeCollectionResponse> response = client.describeCollection(param);
+
+        assertTrue(response.getStatus() == 0);
+
+        server.stop();
+    }
+
     @Test
     void createCollectionParam() {
         // test throw exception with illegal input for FieldType