Browse Source

[TEST] Make AbstractSimpleTransportTestCase#testTimeoutSendExceptionWithDelayedResponse more robust and wait for in-flight request

Simon Willnauer 9 years ago
parent
commit
ec55f9fff7

+ 40 - 22
test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java

@@ -37,13 +37,16 @@ import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.emptyMap;
@@ -553,31 +556,40 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
     }
 
     public void testTimeoutSendExceptionWithDelayedResponse() throws Exception {
-        CountDownLatch doneLatch = new CountDownLatch(1);
-        CountDownLatch allResponded = new CountDownLatch(1);
+        CountDownLatch waitForever = new CountDownLatch(1);
+        CountDownLatch doneWaitingForever = new CountDownLatch(1);
+        AtomicInteger inFlight = new AtomicInteger(0);
         serviceA.registerRequestHandler("sayHelloTimeoutDelayedResponse", StringMessageRequest::new, ThreadPool.Names.GENERIC,
             new TransportRequestHandler<StringMessageRequest>() {
                 @Override
-                public void messageReceived(StringMessageRequest request, TransportChannel channel) {
-                    TimeValue sleep = TimeValue.parseTimeValue(request.message, null, "sleep");
-                    try {
-                        doneLatch.await(sleep.millis(), TimeUnit.MILLISECONDS);
-                    } catch (InterruptedException e) {
-                        // ignore
-                    }
+                public void messageReceived(StringMessageRequest request, TransportChannel channel) throws InterruptedException {
+                    inFlight.incrementAndGet();
                     try {
-                        channel.sendResponse(new StringMessageResponse("hello " + request.message));
-                    } catch (IOException e) {
-                        logger.error("Unexpected failure", e);
-                        fail(e.getMessage());
+                        String message = request.message;
+                        if ("forever".equals(message)) {
+                            waitForever.await();
+                        } else {
+                            TimeValue sleep = TimeValue.parseTimeValue(message, null, "sleep");
+                            Thread.sleep(sleep.millis());
+                        }
+                        try {
+                            channel.sendResponse(new StringMessageResponse("hello " + request.message));
+                        } catch (IOException e) {
+                            logger.error("Unexpected failure", e);
+                            fail(e.getMessage());
+                        } finally {
+                            if ("forever".equals(message)) {
+                                doneWaitingForever.countDown();
+                            }
+                        }
                     } finally {
-                        allResponded.countDown();
+                        inFlight.decrementAndGet();
                     }
                 }
             });
         final CountDownLatch latch = new CountDownLatch(1);
         TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse",
-            new StringMessageRequest("2m"), TransportRequestOptions.builder().withTimeout(100).build(),
+            new StringMessageRequest("forever"), TransportRequestOptions.builder().withTimeout(100).build(),
             new TransportResponseHandler<StringMessageResponse>() {
                 @Override
                 public StringMessageResponse newInstance() {
@@ -603,17 +615,18 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
             });
 
         try {
-            StringMessageResponse message = res.txGet();
+            res.txGet();
             fail("exception should be thrown");
         } catch (Exception e) {
             assertThat(e, instanceOf(ReceiveTimeoutTransportException.class));
         }
         latch.await();
 
+        List<Runnable> assertions = new ArrayList<>();
         for (int i = 0; i < 10; i++) {
             final int counter = i;
             // now, try and send another request, this times, with a short timeout
-            res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse",
+            TransportFuture<StringMessageResponse> result = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse",
                 new StringMessageRequest(counter + "ms"), TransportRequestOptions.builder().withTimeout(3000).build(),
                 new TransportResponseHandler<StringMessageResponse>() {
                     @Override
@@ -638,13 +651,18 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
                     }
                 });
 
-            StringMessageResponse message = res.txGet();
-            assertThat(message.message, equalTo("hello " + counter + "ms"));
+            assertions.add(() -> {
+                StringMessageResponse message = result.txGet();
+                assertThat(message.message, equalTo("hello " + counter + "ms"));
+            });
+        }
+        for (Runnable runnable : assertions) {
+            runnable.run();
         }
-
         serviceA.removeHandler("sayHelloTimeoutDelayedResponse");
-        doneLatch.countDown();
-        allResponded.await();
+        waitForever.countDown();
+        doneWaitingForever.await();
+        assertEquals(0, inFlight.get());
     }
 
     @TestLogging(value = "test. transport.tracer:TRACE")