|  | @@ -9,20 +9,27 @@
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  package org.elasticsearch.action.admin.cluster.allocation;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import org.elasticsearch.action.ActionListener;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.support.ActionFilters;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.support.PlainActionFuture;
 | 
	
		
			
				|  |  |  import org.elasticsearch.cluster.ClusterState;
 | 
	
		
			
				|  |  |  import org.elasticsearch.cluster.routing.allocation.AllocationStatsService;
 | 
	
		
			
				|  |  | +import org.elasticsearch.cluster.routing.allocation.NodeAllocationStats;
 | 
	
		
			
				|  |  |  import org.elasticsearch.cluster.routing.allocation.NodeAllocationStatsTests;
 | 
	
		
			
				|  |  |  import org.elasticsearch.cluster.service.ClusterService;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.settings.ClusterSettings;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.settings.Settings;
 | 
	
		
			
				|  |  | +import org.elasticsearch.core.CheckedConsumer;
 | 
	
		
			
				|  |  |  import org.elasticsearch.core.TimeValue;
 | 
	
		
			
				|  |  | +import org.elasticsearch.node.Node;
 | 
	
		
			
				|  |  |  import org.elasticsearch.tasks.Task;
 | 
	
		
			
				|  |  |  import org.elasticsearch.tasks.TaskId;
 | 
	
		
			
				|  |  | +import org.elasticsearch.telemetry.metric.MeterRegistry;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ClusterServiceUtils;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ESTestCase;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.transport.CapturingTransport;
 | 
	
		
			
				|  |  | -import org.elasticsearch.threadpool.TestThreadPool;
 | 
	
		
			
				|  |  | +import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
 | 
	
		
			
				|  |  |  import org.elasticsearch.threadpool.ThreadPool;
 | 
	
		
			
				|  |  |  import org.elasticsearch.transport.TransportService;
 | 
	
		
			
				|  |  |  import org.junit.After;
 | 
	
	
		
			
				|  | @@ -35,6 +42,7 @@ import java.util.Set;
 | 
	
		
			
				|  |  |  import java.util.concurrent.CyclicBarrier;
 | 
	
		
			
				|  |  |  import java.util.concurrent.atomic.AtomicBoolean;
 | 
	
		
			
				|  |  |  import java.util.concurrent.atomic.AtomicInteger;
 | 
	
		
			
				|  |  | +import java.util.concurrent.atomic.AtomicReference;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.hamcrest.Matchers.anEmptyMap;
 | 
	
		
			
				|  |  |  import static org.hamcrest.Matchers.containsString;
 | 
	
	
		
			
				|  | @@ -47,7 +55,9 @@ import static org.mockito.Mockito.when;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  public class TransportGetAllocationStatsActionTests extends ESTestCase {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private ThreadPool threadPool;
 | 
	
		
			
				|  |  | +    private long startTimeMillis;
 | 
	
		
			
				|  |  | +    private TimeValue allocationStatsCacheTTL;
 | 
	
		
			
				|  |  | +    private ControlledRelativeTimeThreadPool threadPool;
 | 
	
		
			
				|  |  |      private ClusterService clusterService;
 | 
	
		
			
				|  |  |      private TransportService transportService;
 | 
	
		
			
				|  |  |      private AllocationStatsService allocationStatsService;
 | 
	
	
		
			
				|  | @@ -58,8 +68,16 @@ public class TransportGetAllocationStatsActionTests extends ESTestCase {
 | 
	
		
			
				|  |  |      @Before
 | 
	
		
			
				|  |  |      public void setUp() throws Exception {
 | 
	
		
			
				|  |  |          super.setUp();
 | 
	
		
			
				|  |  | -        threadPool = new TestThreadPool(TransportClusterAllocationExplainActionTests.class.getName());
 | 
	
		
			
				|  |  | -        clusterService = ClusterServiceUtils.createClusterService(threadPool);
 | 
	
		
			
				|  |  | +        startTimeMillis = 0L;
 | 
	
		
			
				|  |  | +        allocationStatsCacheTTL = TimeValue.timeValueMinutes(1);
 | 
	
		
			
				|  |  | +        threadPool = new ControlledRelativeTimeThreadPool(TransportClusterAllocationExplainActionTests.class.getName(), startTimeMillis);
 | 
	
		
			
				|  |  | +        clusterService = ClusterServiceUtils.createClusterService(
 | 
	
		
			
				|  |  | +            threadPool,
 | 
	
		
			
				|  |  | +            new ClusterSettings(
 | 
	
		
			
				|  |  | +                Settings.builder().put(TransportGetAllocationStatsAction.CACHE_TTL_SETTING.getKey(), allocationStatsCacheTTL).build(),
 | 
	
		
			
				|  |  | +                ClusterSettings.BUILT_IN_CLUSTER_SETTINGS
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  |          transportService = new CapturingTransport().createTransportService(
 | 
	
		
			
				|  |  |              clusterService.getSettings(),
 | 
	
		
			
				|  |  |              threadPool,
 | 
	
	
		
			
				|  | @@ -87,7 +105,17 @@ public class TransportGetAllocationStatsActionTests extends ESTestCase {
 | 
	
		
			
				|  |  |          transportService.close();
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    private void disableAllocationStatsCache() {
 | 
	
		
			
				|  |  | +        setAllocationStatsCacheTTL(TimeValue.ZERO);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private void setAllocationStatsCacheTTL(TimeValue ttl) {
 | 
	
		
			
				|  |  | +        clusterService.getClusterSettings()
 | 
	
		
			
				|  |  | +            .applySettings(Settings.builder().put(TransportGetAllocationStatsAction.CACHE_TTL_SETTING.getKey(), ttl).build());
 | 
	
		
			
				|  |  | +    };
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      public void testReturnsOnlyRequestedStats() throws Exception {
 | 
	
		
			
				|  |  | +        disableAllocationStatsCache();
 | 
	
		
			
				|  |  |          int expectedNumberOfStatsServiceCalls = 0;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          for (final var metrics : List.of(
 | 
	
	
		
			
				|  | @@ -129,6 +157,7 @@ public class TransportGetAllocationStatsActionTests extends ESTestCase {
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testDeduplicatesStatsComputations() throws InterruptedException {
 | 
	
		
			
				|  |  | +        disableAllocationStatsCache();
 | 
	
		
			
				|  |  |          final var requestCounter = new AtomicInteger();
 | 
	
		
			
				|  |  |          final var isExecuting = new AtomicBoolean();
 | 
	
		
			
				|  |  |          when(allocationStatsService.stats()).thenAnswer(invocation -> {
 | 
	
	
		
			
				|  | @@ -173,4 +202,84 @@ public class TransportGetAllocationStatsActionTests extends ESTestCase {
 | 
	
		
			
				|  |  |              thread.join();
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testGetStatsWithCachingEnabled() throws Exception {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        final AtomicReference<Map<String, NodeAllocationStats>> allocationStats = new AtomicReference<>();
 | 
	
		
			
				|  |  | +        int numExpectedAllocationStatsServiceCalls = 0;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        final Runnable resetExpectedAllocationStats = () -> {
 | 
	
		
			
				|  |  | +            final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats());
 | 
	
		
			
				|  |  | +            allocationStats.set(stats);
 | 
	
		
			
				|  |  | +            when(allocationStatsService.stats()).thenReturn(stats);
 | 
	
		
			
				|  |  | +        };
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        final CheckedConsumer<ActionListener<Void>, Exception> threadTask = l -> {
 | 
	
		
			
				|  |  | +            final var request = new TransportGetAllocationStatsAction.Request(
 | 
	
		
			
				|  |  | +                TEST_REQUEST_TIMEOUT,
 | 
	
		
			
				|  |  | +                new TaskId(randomIdentifier(), randomNonNegativeLong()),
 | 
	
		
			
				|  |  | +                EnumSet.of(Metric.ALLOCATIONS)
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, l.map(response -> {
 | 
	
		
			
				|  |  | +                assertSame("Expected the cached allocation stats to be returned", response.getNodeAllocationStats(), allocationStats.get());
 | 
	
		
			
				|  |  | +                return null;
 | 
	
		
			
				|  |  | +            }));
 | 
	
		
			
				|  |  | +        };
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Initial cache miss, all threads should get the same value.
 | 
	
		
			
				|  |  | +        resetExpectedAllocationStats.run();
 | 
	
		
			
				|  |  | +        ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask));
 | 
	
		
			
				|  |  | +        verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Advance the clock to a time less than or equal to the TTL and verify we still get the cached stats.
 | 
	
		
			
				|  |  | +        threadPool.setCurrentTimeInMillis(startTimeMillis + between(0, (int) allocationStatsCacheTTL.millis()));
 | 
	
		
			
				|  |  | +        ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask));
 | 
	
		
			
				|  |  | +        verify(allocationStatsService, times(numExpectedAllocationStatsServiceCalls)).stats();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Force the cached stats to expire.
 | 
	
		
			
				|  |  | +        threadPool.setCurrentTimeInMillis(startTimeMillis + allocationStatsCacheTTL.getMillis() + 1);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Expect a single call to the stats service on the cache miss.
 | 
	
		
			
				|  |  | +        resetExpectedAllocationStats.run();
 | 
	
		
			
				|  |  | +        ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask));
 | 
	
		
			
				|  |  | +        verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Update the TTL setting to disable the cache, we expect a service call each time.
 | 
	
		
			
				|  |  | +        setAllocationStatsCacheTTL(TimeValue.ZERO);
 | 
	
		
			
				|  |  | +        safeAwait(threadTask);
 | 
	
		
			
				|  |  | +        safeAwait(threadTask);
 | 
	
		
			
				|  |  | +        numExpectedAllocationStatsServiceCalls += 2;
 | 
	
		
			
				|  |  | +        verify(allocationStatsService, times(numExpectedAllocationStatsServiceCalls)).stats();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Re-enable the cache, only one thread should call the stats service.
 | 
	
		
			
				|  |  | +        setAllocationStatsCacheTTL(TimeValue.timeValueMinutes(5));
 | 
	
		
			
				|  |  | +        resetExpectedAllocationStats.run();
 | 
	
		
			
				|  |  | +        ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask));
 | 
	
		
			
				|  |  | +        verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static class ControlledRelativeTimeThreadPool extends ThreadPool {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        private long currentTimeInMillis;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ControlledRelativeTimeThreadPool(String name, long startTimeMillis) {
 | 
	
		
			
				|  |  | +            super(
 | 
	
		
			
				|  |  | +                Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).build(),
 | 
	
		
			
				|  |  | +                MeterRegistry.NOOP,
 | 
	
		
			
				|  |  | +                new DefaultBuiltInExecutorBuilders()
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +            this.currentTimeInMillis = startTimeMillis;
 | 
	
		
			
				|  |  | +            stopCachedTimeThread();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        @Override
 | 
	
		
			
				|  |  | +        public long relativeTimeInMillis() {
 | 
	
		
			
				|  |  | +            return currentTimeInMillis;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        void setCurrentTimeInMillis(long currentTimeInMillis) {
 | 
	
		
			
				|  |  | +            this.currentTimeInMillis = currentTimeInMillis;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  |  }
 |