Browse Source

Respect ML upgrade mode in TrainedModelStatsService (#61143)

When in upgrade mode the ml stats service should not write to the stats index.
David Kyle 5 years ago
parent
commit
2a1e8e3068

+ 15 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.client.OriginSettingClient;
+import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
@@ -29,6 +30,7 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@@ -102,7 +104,12 @@ public class TrainedModelStatsService {
                 stop();
             }
         });
-        clusterService.addListener((event) -> this.clusterState = event.state());
+        clusterService.addListener(this::setClusterState);
+    }
+
+    // visible for testing
+    void setClusterState(ClusterChangedEvent event) {
+        clusterState = event.state();
     }
 
     /**
@@ -146,6 +153,13 @@ public class TrainedModelStatsService {
         if (clusterState == null || statsQueue.isEmpty() || stopped) {
             return;
         }
+
+        boolean isInUpgradeMode = MlMetadata.getMlMetadata(clusterState).isUpgradeMode();
+        if (isInUpgradeMode) {
+            logger.debug("Model stats not persisted as ml upgrade mode is enabled");
+            return;
+        }
+
         if (verifyIndicesExistAndPrimaryShardsAreActive(clusterState, indexNameExpressionResolver) == false) {
             try {
                 logger.debug("About to create the stats index as it does not exist yet");
@@ -251,5 +265,4 @@ public class TrainedModelStatsService {
         }
         return null;
     }
-
 }

+ 118 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsServiceTests.java

@@ -7,6 +7,9 @@
 package org.elasticsearch.xpack.ml.inference;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.client.OriginSettingClient;
+import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.AliasMetadata;
@@ -19,13 +22,24 @@ import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
+
+import java.time.Instant;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
 public class TrainedModelStatsServiceTests extends ESTestCase {
 
@@ -119,6 +133,110 @@ public class TrainedModelStatsServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateStatsUpgradeMode() {
+        String aliasName = MlStatsIndex.writeAlias();
+        String concreteIndex = ".ml-stats-000001";
+        IndexNameExpressionResolver resolver = new IndexNameExpressionResolver();
+
+        // create a valid index routing so persistence will occur
+        RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
+        addToRoutingTable(concreteIndex, routingTableBuilder);
+        RoutingTable routingTable = routingTableBuilder.build();
+
+        // cannot mock OriginSettingClient as it is final so mock the client
+        Client client = mock(Client.class);
+        OriginSettingClient originSettingClient = new OriginSettingClient(client, "modelstatsservicetests");
+        ClusterService clusterService = mock(ClusterService.class);
+        ThreadPool threadPool = mock(ThreadPool.class);
+        ResultsPersisterService persisterService = mock(ResultsPersisterService.class);
+
+        TrainedModelStatsService service = new TrainedModelStatsService(persisterService,
+            originSettingClient, resolver, clusterService, threadPool);
+
+        InferenceStats.Accumulator accumulator = new InferenceStats.Accumulator("testUpdateStatsUpgradeMode", "test-node", 1L);
+
+        {
+            IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
+                .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
+                .settings(Settings.builder()
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+                );
+            Metadata.Builder metadata = Metadata.builder().put(indexMetadata);
+
+            ClusterState clusterState = ClusterState.builder(new ClusterName("upgrade-mode-test-initial-state"))
+                .routingTable(routingTable)
+                .metadata(metadata)
+                .build();
+            ClusterChangedEvent change = new ClusterChangedEvent("created-from-test", clusterState, clusterState);
+
+            service.setClusterState(change);
+
+            // queue some stats to be persisted
+            service.queueStats(accumulator.currentStats(Instant.now()), false);
+
+            service.updateStats();
+            verify(persisterService, times(1)).bulkIndexWithRetry(any(), any(), any(), any());
+        }
+        {
+            // test with upgrade mode turned on
+
+            IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
+                .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
+                .settings(Settings.builder()
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+                );
+
+            // now set the upgrade mode
+            Metadata.Builder metadata = Metadata.builder()
+                .put(indexMetadata)
+                .putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build());
+
+            ClusterState clusterState = ClusterState.builder(new ClusterName("upgrade-mode-test-upgrade-enabled"))
+                .routingTable(routingTable)
+                .metadata(metadata)
+                .build();
+            ClusterChangedEvent change = new ClusterChangedEvent("created-from-test", clusterState, clusterState);
+
+            service.setClusterState(change);
+
+            // queue some stats to be persisted
+            service.queueStats(accumulator.currentStats(Instant.now()), false);
+
+            service.updateStats();
+            verify(persisterService, times(1)).bulkIndexWithRetry(any(), any(), any(), any());
+        }
+        {
+            // This time turn off upgrade mode
+
+            IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
+                .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
+                .settings(Settings.builder()
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+                );
+
+            Metadata.Builder metadata = Metadata.builder()
+                .put(indexMetadata)
+                .putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(false).build());
+
+            ClusterState clusterState = ClusterState.builder(new ClusterName("upgrade-mode-test-upgrade-disabled"))
+                .routingTable(routingTable)
+                .metadata(metadata)
+                .build();
+
+            ClusterChangedEvent change = new ClusterChangedEvent("created-from-test", clusterState, clusterState);
+
+            service.setClusterState(change);
+            service.updateStats();
+            verify(persisterService, times(2)).bulkIndexWithRetry(any(), any(), any(), any());
+        }
+    }
+
     private static void addToRoutingTable(String concreteIndex, RoutingTable.Builder routingTable) {
         Index index = new Index(concreteIndex, "_uuid");
         ShardId shardId = new ShardId(index, 0);