|  | @@ -66,19 +66,30 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      static ShrinkAction randomInstance() {
 | 
	
		
			
				|  |  |          if (randomBoolean()) {
 | 
	
		
			
				|  |  | -            return new ShrinkAction(randomIntBetween(1, 100), null);
 | 
	
		
			
				|  |  | +            return new ShrinkAction(randomIntBetween(1, 100), null, randomBoolean());
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  | -            return new ShrinkAction(null, ByteSizeValue.ofBytes(randomIntBetween(1, 100)));
 | 
	
		
			
				|  |  | +            return new ShrinkAction(null, ByteSizeValue.ofBytes(randomIntBetween(1, 100)), randomBoolean());
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  |      protected ShrinkAction mutateInstance(ShrinkAction action) {
 | 
	
		
			
				|  |  | -        if (action.getNumberOfShards() != null) {
 | 
	
		
			
				|  |  | -            return new ShrinkAction(action.getNumberOfShards() + randomIntBetween(1, 2), null);
 | 
	
		
			
				|  |  | -        } else {
 | 
	
		
			
				|  |  | -            return new ShrinkAction(null, ByteSizeValue.ofBytes(action.getMaxPrimaryShardSize().getBytes() + 1));
 | 
	
		
			
				|  |  | +        Integer numberOfShards = action.getNumberOfShards();
 | 
	
		
			
				|  |  | +        ByteSizeValue maxPrimaryShardSize = action.getMaxPrimaryShardSize();
 | 
	
		
			
				|  |  | +        boolean allowWriteAfterShrink = action.getAllowWriteAfterShrink();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        switch (randomInt(2)) {
 | 
	
		
			
				|  |  | +            case 0 -> {
 | 
	
		
			
				|  |  | +                numberOfShards = randomValueOtherThan(numberOfShards, () -> randomIntBetween(1, 100));
 | 
	
		
			
				|  |  | +                maxPrimaryShardSize = null;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            case 1 -> {
 | 
	
		
			
				|  |  | +                maxPrimaryShardSize = randomValueOtherThan(maxPrimaryShardSize, () -> ByteSizeValue.ofBytes(randomIntBetween(1, 100)));
 | 
	
		
			
				|  |  | +                numberOfShards = null;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            case 2 -> allowWriteAfterShrink = allowWriteAfterShrink == false;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | +        return new ShrinkAction(numberOfShards, maxPrimaryShardSize, allowWriteAfterShrink);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
	
		
			
				|  | @@ -87,24 +98,27 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testNonPositiveShardNumber() {
 | 
	
		
			
				|  |  | -        Exception e = expectThrows(Exception.class, () -> new ShrinkAction(randomIntBetween(-100, 0), null));
 | 
	
		
			
				|  |  | +        Exception e = expectThrows(Exception.class, () -> new ShrinkAction(randomIntBetween(-100, 0), null, randomBoolean()));
 | 
	
		
			
				|  |  |          assertThat(e.getMessage(), equalTo("[number_of_shards] must be greater than 0"));
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testMaxPrimaryShardSize() {
 | 
	
		
			
				|  |  |          ByteSizeValue maxPrimaryShardSize1 = ByteSizeValue.ofBytes(10);
 | 
	
		
			
				|  |  | -        Exception e1 = expectThrows(Exception.class, () -> new ShrinkAction(randomIntBetween(1, 100), maxPrimaryShardSize1));
 | 
	
		
			
				|  |  | +        Exception e1 = expectThrows(
 | 
	
		
			
				|  |  | +            Exception.class,
 | 
	
		
			
				|  |  | +            () -> new ShrinkAction(randomIntBetween(1, 100), maxPrimaryShardSize1, randomBoolean())
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  |          assertThat(e1.getMessage(), equalTo("Cannot set both [number_of_shards] and [max_primary_shard_size]"));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          ByteSizeValue maxPrimaryShardSize2 = ByteSizeValue.ZERO;
 | 
	
		
			
				|  |  | -        Exception e2 = expectThrows(Exception.class, () -> new ShrinkAction(null, maxPrimaryShardSize2));
 | 
	
		
			
				|  |  | +        Exception e2 = expectThrows(Exception.class, () -> new ShrinkAction(null, maxPrimaryShardSize2, randomBoolean()));
 | 
	
		
			
				|  |  |          assertThat(e2.getMessage(), equalTo("[max_primary_shard_size] must be greater than 0"));
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testPerformActionWithSkipBecauseOfShardNumber() throws InterruptedException {
 | 
	
		
			
				|  |  |          String lifecycleName = randomAlphaOfLengthBetween(4, 10);
 | 
	
		
			
				|  |  |          int numberOfShards = randomIntBetween(1, 10);
 | 
	
		
			
				|  |  | -        ShrinkAction action = new ShrinkAction(numberOfShards, null);
 | 
	
		
			
				|  |  | +        ShrinkAction action = new ShrinkAction(numberOfShards, null, randomBoolean());
 | 
	
		
			
				|  |  |          StepKey nextStepKey = new StepKey(
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
	
		
			
				|  | @@ -121,7 +135,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |      public void testPerformActionWithSkipBecauseOfSearchableSnapshot() throws InterruptedException {
 | 
	
		
			
				|  |  |          String lifecycleName = randomAlphaOfLengthBetween(4, 10);
 | 
	
		
			
				|  |  |          int numberOfShards = randomIntBetween(1, 10);
 | 
	
		
			
				|  |  | -        ShrinkAction action = new ShrinkAction(numberOfShards, null);
 | 
	
		
			
				|  |  | +        ShrinkAction action = new ShrinkAction(numberOfShards, null, randomBoolean());
 | 
	
		
			
				|  |  |          StepKey nextStepKey = new StepKey(
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
	
		
			
				|  | @@ -143,7 +157,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |          int divisor = randomFrom(2, 3, 6);
 | 
	
		
			
				|  |  |          int expectedFinalShards = numShards / divisor;
 | 
	
		
			
				|  |  |          String lifecycleName = randomAlphaOfLengthBetween(4, 10);
 | 
	
		
			
				|  |  | -        ShrinkAction action = new ShrinkAction(expectedFinalShards, null);
 | 
	
		
			
				|  |  | +        ShrinkAction action = new ShrinkAction(expectedFinalShards, null, randomBoolean());
 | 
	
		
			
				|  |  |          StepKey nextStepKey = new StepKey(
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
	
		
			
				|  | @@ -160,7 +174,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |      public void testFailureIsPropagated() throws InterruptedException {
 | 
	
		
			
				|  |  |          String lifecycleName = randomAlphaOfLengthBetween(4, 10);
 | 
	
		
			
				|  |  |          int numberOfShards = randomIntBetween(1, 10);
 | 
	
		
			
				|  |  | -        ShrinkAction action = new ShrinkAction(numberOfShards, null);
 | 
	
		
			
				|  |  | +        ShrinkAction action = new ShrinkAction(numberOfShards, null, randomBoolean());
 | 
	
		
			
				|  |  |          StepKey nextStepKey = new StepKey(
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10),
 | 
	
	
		
			
				|  | @@ -185,7 +199,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |      ) throws InterruptedException {
 | 
	
		
			
				|  |  |          String phase = randomAlphaOfLengthBetween(1, 10);
 | 
	
		
			
				|  |  |          List<Step> steps = action.toSteps(client, phase, nextStepKey);
 | 
	
		
			
				|  |  | -        AsyncBranchingStep step = ((AsyncBranchingStep) steps.get(0));
 | 
	
		
			
				|  |  | +        AsyncBranchingStep branchStep = ((AsyncBranchingStep) steps.get(0));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          LifecyclePolicy policy = new LifecyclePolicy(
 | 
	
		
			
				|  |  |              lifecycleName,
 | 
	
	
		
			
				|  | @@ -211,11 +225,11 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |                          indexMetadataBuilder.putCustom(
 | 
	
		
			
				|  |  |                              LifecycleExecutionState.ILM_CUSTOM_METADATA_KEY,
 | 
	
		
			
				|  |  |                              LifecycleExecutionState.builder()
 | 
	
		
			
				|  |  | -                                .setPhase(step.getKey().phase())
 | 
	
		
			
				|  |  | +                                .setPhase(branchStep.getKey().phase())
 | 
	
		
			
				|  |  |                                  .setPhaseTime(0L)
 | 
	
		
			
				|  |  | -                                .setAction(step.getKey().action())
 | 
	
		
			
				|  |  | +                                .setAction(branchStep.getKey().action())
 | 
	
		
			
				|  |  |                                  .setActionTime(0L)
 | 
	
		
			
				|  |  | -                                .setStep(step.getKey().name())
 | 
	
		
			
				|  |  | +                                .setStep(branchStep.getKey().name())
 | 
	
		
			
				|  |  |                                  .setStepTime(0L)
 | 
	
		
			
				|  |  |                                  .build()
 | 
	
		
			
				|  |  |                                  .asMap()
 | 
	
	
		
			
				|  | @@ -226,7 +240,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |          setUpIndicesStatsRequestMock(indexName, withError);
 | 
	
		
			
				|  |  |          CountDownLatch countDownLatch = new CountDownLatch(1);
 | 
	
		
			
				|  |  |          AtomicBoolean failurePropagated = new AtomicBoolean(false);
 | 
	
		
			
				|  |  | -        step.performAction(state.metadata().index(indexName), state, null, new ActionListener<>() {
 | 
	
		
			
				|  |  | +        branchStep.performAction(state.metadata().index(indexName), state, null, new ActionListener<>() {
 | 
	
		
			
				|  |  |              @Override
 | 
	
		
			
				|  |  |              public void onResponse(Void unused) {
 | 
	
		
			
				|  |  |                  countDownLatch.countDown();
 | 
	
	
		
			
				|  | @@ -244,12 +258,18 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          });
 | 
	
		
			
				|  |  |          assertTrue(countDownLatch.await(5, TimeUnit.SECONDS));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          if (withError) {
 | 
	
		
			
				|  |  |              assertTrue(failurePropagated.get());
 | 
	
		
			
				|  |  |          } else if (shouldSkip) {
 | 
	
		
			
				|  |  | -            assertThat(step.getNextStepKey(), equalTo(nextStepKey));
 | 
	
		
			
				|  |  | +            if (action.getAllowWriteAfterShrink()) {
 | 
	
		
			
				|  |  | +                Step lastStep = steps.get(steps.size() - 1);
 | 
	
		
			
				|  |  | +                assertThat(branchStep.getNextStepKey(), equalTo(lastStep.getKey()));
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                assertThat(branchStep.getNextStepKey(), equalTo(nextStepKey));
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  | -            assertThat(step.getNextStepKey(), equalTo(steps.get(1).getKey()));
 | 
	
		
			
				|  |  | +            assertThat(branchStep.getNextStepKey(), equalTo(steps.get(1).getKey()));
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -262,7 +282,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |              randomAlphaOfLengthBetween(1, 10)
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |          List<Step> steps = action.toSteps(client, phase, nextStepKey);
 | 
	
		
			
				|  |  | -        assertThat(steps.size(), equalTo(18));
 | 
	
		
			
				|  |  | +        assertThat(steps.size(), equalTo(action.getAllowWriteAfterShrink() ? 19 : 18));
 | 
	
		
			
				|  |  |          StepKey expectedFirstKey = new StepKey(phase, ShrinkAction.NAME, ShrinkAction.CONDITIONAL_SKIP_SHRINK_STEP);
 | 
	
		
			
				|  |  |          StepKey expectedSecondKey = new StepKey(phase, ShrinkAction.NAME, CheckNotDataStreamWriteIndexStep.NAME);
 | 
	
		
			
				|  |  |          StepKey expectedThirdKey = new StepKey(phase, ShrinkAction.NAME, WaitForNoFollowersStep.NAME);
 | 
	
	
		
			
				|  | @@ -281,12 +301,16 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |          StepKey expectedSixteenKey = new StepKey(phase, ShrinkAction.NAME, ShrunkenIndexCheckStep.NAME);
 | 
	
		
			
				|  |  |          StepKey expectedSeventeenKey = new StepKey(phase, ShrinkAction.NAME, ReplaceDataStreamBackingIndexStep.NAME);
 | 
	
		
			
				|  |  |          StepKey expectedEighteenKey = new StepKey(phase, ShrinkAction.NAME, DeleteStep.NAME);
 | 
	
		
			
				|  |  | +        StepKey expectedNineteenthKey = new StepKey(phase, ShrinkAction.NAME, UpdateSettingsStep.NAME);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertTrue(steps.get(0) instanceof AsyncBranchingStep);
 | 
	
		
			
				|  |  |          assertThat(steps.get(0).getKey(), equalTo(expectedFirstKey));
 | 
	
		
			
				|  |  |          expectThrows(IllegalStateException.class, () -> steps.get(0).getNextStepKey());
 | 
	
		
			
				|  |  |          assertThat(((AsyncBranchingStep) steps.get(0)).getNextStepKeyOnFalse(), equalTo(expectedSecondKey));
 | 
	
		
			
				|  |  | -        assertThat(((AsyncBranchingStep) steps.get(0)).getNextStepKeyOnTrue(), equalTo(nextStepKey));
 | 
	
		
			
				|  |  | +        assertThat(
 | 
	
		
			
				|  |  | +            ((AsyncBranchingStep) steps.get(0)).getNextStepKeyOnTrue(),
 | 
	
		
			
				|  |  | +            equalTo(action.getAllowWriteAfterShrink() ? expectedNineteenthKey : nextStepKey)
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertTrue(steps.get(1) instanceof CheckNotDataStreamWriteIndexStep);
 | 
	
		
			
				|  |  |          assertThat(steps.get(1).getKey(), equalTo(expectedSecondKey));
 | 
	
	
		
			
				|  | @@ -357,7 +381,7 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertTrue(steps.get(15) instanceof ShrunkenIndexCheckStep);
 | 
	
		
			
				|  |  |          assertThat(steps.get(15).getKey(), equalTo(expectedSixteenKey));
 | 
	
		
			
				|  |  | -        assertThat(steps.get(15).getNextStepKey(), equalTo(nextStepKey));
 | 
	
		
			
				|  |  | +        assertThat(steps.get(15).getNextStepKey(), equalTo(action.getAllowWriteAfterShrink() ? expectedNineteenthKey : nextStepKey));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertTrue(steps.get(16) instanceof ReplaceDataStreamBackingIndexStep);
 | 
	
		
			
				|  |  |          assertThat(steps.get(16).getKey(), equalTo(expectedSeventeenKey));
 | 
	
	
		
			
				|  | @@ -366,6 +390,12 @@ public class ShrinkActionTests extends AbstractActionTestCase<ShrinkAction> {
 | 
	
		
			
				|  |  |          assertTrue(steps.get(17) instanceof DeleteStep);
 | 
	
		
			
				|  |  |          assertThat(steps.get(17).getKey(), equalTo(expectedEighteenKey));
 | 
	
		
			
				|  |  |          assertThat(steps.get(17).getNextStepKey(), equalTo(expectedSixteenKey));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if (action.getAllowWriteAfterShrink()) {
 | 
	
		
			
				|  |  | +            assertTrue(steps.get(18) instanceof UpdateSettingsStep);
 | 
	
		
			
				|  |  | +            assertThat(steps.get(18).getKey(), equalTo(expectedNineteenthKey));
 | 
	
		
			
				|  |  | +            assertThat(steps.get(18).getNextStepKey(), equalTo(nextStepKey));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private void setUpIndicesStatsRequestMock(String index, boolean withError) {
 |