Browse Source

Log request and response messages

jianghua 4 years ago
parent
commit
a021e5336d

+ 1 - 1
examples/src/main/java/MilvusClientExample.java

@@ -73,7 +73,7 @@ public class MilvusClientExample {
 
 
   public static void run(ConnectParam connectParam) {
   public static void run(ConnectParam connectParam) {
     // Create Milvus client
     // Create Milvus client
-    MilvusClient client = new MilvusGrpcClient(connectParam);
+    MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
 
 
     // Create a collection with the following collection mapping
     // Create a collection with the following collection mapping
     final String collectionName = "example"; // collection name
     final String collectionName = "example"; // collection name

+ 80 - 0
src/main/java/io/milvus/client/LoggingAdapter.java

@@ -0,0 +1,80 @@
+package io.milvus.client;
+
+import com.google.protobuf.Descriptors;
+import com.google.protobuf.MessageOrBuilder;
+import com.google.protobuf.TextFormat;
+import io.grpc.MethodDescriptor;
+import org.slf4j.Logger;
+
+import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
+
+public class LoggingAdapter {
+  public static final LoggingAdapter DEFAULT_LOGGING_ADAPTER = new LoggingAdapter();
+  private static final AtomicLong traceId = new AtomicLong(0);
+
+  protected LoggingAdapter() {
+  }
+
+  protected String getTraceId() {
+    return Long.toHexString(traceId.getAndIncrement());
+  }
+
+  protected void logRequest(Logger logger, String traceId, MethodDescriptor method, Object message) {
+    if (logger.isTraceEnabled()) {
+      logger.trace("TraceId: {}, Method: {}, Request: {}", traceId, method.getFullMethodName(), trace(message));
+    } else if (logger.isInfoEnabled()) {
+      logger.info("TraceId: {}, Method: {}, Request: {}", traceId, method.getFullMethodName(), info(message));
+    }
+  }
+
+  protected void logResponse(Logger logger, String traceId, MethodDescriptor method, Object message) {
+    if (logger.isTraceEnabled()) {
+      logger.trace("TraceId: {}, Method: {}, Response: {}", traceId, method.getFullMethodName(), trace(message));
+    } else if (logger.isInfoEnabled()) {
+      logger.info("TraceId: {}, Method: {}, Response: {}", traceId, method.getFullMethodName(), info(message));
+    }
+  }
+
+  protected String info(Object message) {
+    if (message instanceof MessageOrBuilder) {
+      MessageOrBuilder msg = (MessageOrBuilder) message;
+      StringBuilder output = new StringBuilder(msg.getDescriptorForType().getName());
+      write((MessageOrBuilder) message, output);
+      return output.toString();
+    }
+    return message.toString();
+  }
+
+  protected String trace(Object message) {
+    if (message instanceof MessageOrBuilder) {
+      return TextFormat.printer().printToString((MessageOrBuilder) message);
+    }
+    return message.toString();
+  }
+
+  protected void write(MessageOrBuilder message, StringBuilder output) {
+    output.append(" { ");
+    message.getAllFields().entrySet().stream().forEach(e -> {
+      if (e.getKey().isRepeated()) {
+        output.append(e.getKey().getName())
+            .append(" [ ")
+            .append(((List<?>) e.getValue()).size())
+            .append(" items ], ");
+      } else if (e.getKey().isMapField()) {
+        output.append(e.getKey().getName())
+            .append(" { ")
+            .append(((List<?>) e.getValue()).size())
+            .append(" entries }, ");
+      } else if (e.getKey().getJavaType() == Descriptors.FieldDescriptor.JavaType.MESSAGE) {
+        output.append(e.getKey().getName());
+        write((MessageOrBuilder) e.getValue(), output);
+      } else {
+        output.append(TextFormat.printer().shortDebugString(e.getKey(), e.getValue()))
+            .append(", ");
+      }
+    });
+    output.setLength(output.length() - 2);
+    output.append(" } ");
+  }
+}

+ 50 - 2
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -26,8 +26,11 @@ import io.grpc.CallOptions;
 import io.grpc.Channel;
 import io.grpc.Channel;
 import io.grpc.ClientCall;
 import io.grpc.ClientCall;
 import io.grpc.ClientInterceptor;
 import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall;
+import io.grpc.ForwardingClientCallListener;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.Metadata;
 import io.grpc.MethodDescriptor;
 import io.grpc.MethodDescriptor;
 import io.milvus.client.exception.ClientSideMilvusException;
 import io.milvus.client.exception.ClientSideMilvusException;
 import io.milvus.client.exception.MilvusException;
 import io.milvus.client.exception.MilvusException;
@@ -111,13 +114,25 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
     }
     }
   }
   }
 
 
+  public MilvusClient withLogging() {
+    return withLogging(LoggingAdapter.DEFAULT_LOGGING_ADAPTER);
+  }
+
+  public MilvusClient withLogging(LoggingAdapter loggingAdapter) {
+    return withInterceptors(new LoggingInterceptor(loggingAdapter));
+  }
+
   public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
   public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
     final long timeoutMillis = timeoutUnit.toMillis(timeout);
     final long timeoutMillis = timeoutUnit.toMillis(timeout);
     final TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(timeoutMillis);
     final TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(timeoutMillis);
+    return withInterceptors(timeoutInterceptor);
+  }
+
+  private MilvusClient withInterceptors(ClientInterceptor... interceptors) {
     final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub =
     final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub =
-        this.blockingStub.withInterceptors(timeoutInterceptor);
+        this.blockingStub.withInterceptors(interceptors);
     final MilvusServiceGrpc.MilvusServiceFutureStub futureStub =
     final MilvusServiceGrpc.MilvusServiceFutureStub futureStub =
-        this.futureStub.withInterceptors(timeoutInterceptor);
+        this.futureStub.withInterceptors(interceptors);
 
 
     return new AbstractMilvusGrpcClient() {
     return new AbstractMilvusGrpcClient() {
       @Override
       @Override
@@ -160,6 +175,39 @@ public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
       return next.newCall(method, callOptions.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS));
       return next.newCall(method, callOptions.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS));
     }
     }
   }
   }
+
+  private static class LoggingInterceptor implements ClientInterceptor {
+    private LoggingAdapter loggingAdapter;
+
+    LoggingInterceptor(LoggingAdapter loggingAdapter) {
+      this.loggingAdapter = loggingAdapter;
+    }
+
+    @Override
+    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+      return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
+        private String traceId = loggingAdapter.getTraceId();
+
+        @Override
+        public void sendMessage(ReqT message) {
+          loggingAdapter.logRequest(logger, traceId, method, message);
+          super.sendMessage(message);
+        }
+
+        @Override
+        public void start(Listener<RespT> responseListener, Metadata headers) {
+          super.start(new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {
+            @Override
+            public void onMessage(RespT message) {
+              loggingAdapter.logResponse(logger, traceId, method, message);
+              super.onMessage(message);
+            }
+          }, headers);
+        }
+      };
+    }
+  }
 }
 }
 
 
 abstract class AbstractMilvusGrpcClient implements MilvusClient {
 abstract class AbstractMilvusGrpcClient implements MilvusClient {

+ 3 - 3
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -88,7 +88,7 @@ class ContainerMilvusClientTest extends MilvusClientTest {
         .withTarget(testNameResolverProvider.getDefaultScheme() + ":///test")
         .withTarget(testNameResolverProvider.getDefaultScheme() + ":///test")
         .build();
         .build();
 
 
-    MilvusClient loadBalancingClient = new MilvusGrpcClient(connectParam);
+    MilvusClient loadBalancingClient = new MilvusGrpcClient(connectParam).withLogging();
     assertEquals(50, IntStream.range(0, 100)
     assertEquals(50, IntStream.range(0, 100)
             .filter(i -> loadBalancingClient.hasCollection(randomCollectionName))
             .filter(i -> loadBalancingClient.hasCollection(randomCollectionName))
             .count());
             .count());
@@ -203,7 +203,7 @@ class MilvusClientTest {
   void setUp() throws Exception {
   void setUp() throws Exception {
 
 
     ConnectParam connectParam = connectParamBuilder().build();
     ConnectParam connectParam = connectParamBuilder().build();
-    client = new MilvusGrpcClient(connectParam);
+    client = new MilvusGrpcClient(connectParam).withLogging();
 
 
     randomCollectionName = RandomStringUtils.randomAlphabetic(10);
     randomCollectionName = RandomStringUtils.randomAlphabetic(10);
     size = 100000;
     size = 100000;
@@ -233,7 +233,7 @@ class MilvusClientTest {
     ConnectParam connectParam = connectParamBuilder()
     ConnectParam connectParam = connectParamBuilder()
         .withIdleTimeout(1, TimeUnit.SECONDS)
         .withIdleTimeout(1, TimeUnit.SECONDS)
         .build();
         .build();
-    MilvusClient client = new MilvusGrpcClient(connectParam);
+    MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
     TimeUnit.SECONDS.sleep(2);
     TimeUnit.SECONDS.sleep(2);
     // A new RPC would take the channel out of idle mode
     // A new RPC would take the channel out of idle mode
     client.listCollections();
     client.listCollections();

+ 0 - 0
src/test/resources/log4j2.xml → src/test/resources/log4j2-test.xml