|
@@ -7,6 +7,9 @@
|
|
|
package org.elasticsearch.xpack.ml.inference;
|
|
|
|
|
|
import org.elasticsearch.Version;
|
|
|
+import org.elasticsearch.client.Client;
|
|
|
+import org.elasticsearch.client.OriginSettingClient;
|
|
|
+import org.elasticsearch.cluster.ClusterChangedEvent;
|
|
|
import org.elasticsearch.cluster.ClusterName;
|
|
|
import org.elasticsearch.cluster.ClusterState;
|
|
|
import org.elasticsearch.cluster.metadata.AliasMetadata;
|
|
@@ -19,13 +22,24 @@ import org.elasticsearch.cluster.routing.RecoverySource;
|
|
|
import org.elasticsearch.cluster.routing.RoutingTable;
|
|
|
import org.elasticsearch.cluster.routing.ShardRouting;
|
|
|
import org.elasticsearch.cluster.routing.UnassignedInfo;
|
|
|
+import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.index.Index;
|
|
|
import org.elasticsearch.index.shard.ShardId;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
+import org.elasticsearch.xpack.core.ml.MlMetadata;
|
|
|
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
|
|
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
|
|
|
+
|
|
|
+import java.time.Instant;
|
|
|
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.mockito.Mockito.any;
|
|
|
+import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
|
|
|
public class TrainedModelStatsServiceTests extends ESTestCase {
|
|
|
|
|
@@ -119,6 +133,110 @@ public class TrainedModelStatsServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testUpdateStatsUpgradeMode() {
|
|
|
+ String aliasName = MlStatsIndex.writeAlias();
|
|
|
+ String concreteIndex = ".ml-stats-000001";
|
|
|
+ IndexNameExpressionResolver resolver = new IndexNameExpressionResolver();
|
|
|
+
|
|
|
+ // create a valid index routing so persistence will occur
|
|
|
+ RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
|
|
|
+ addToRoutingTable(concreteIndex, routingTableBuilder);
|
|
|
+ RoutingTable routingTable = routingTableBuilder.build();
|
|
|
+
|
|
|
+ // cannot mock OriginSettingClient as it is final so mock the client
|
|
|
+ Client client = mock(Client.class);
|
|
|
+ OriginSettingClient originSettingClient = new OriginSettingClient(client, "modelstatsservicetests");
|
|
|
+ ClusterService clusterService = mock(ClusterService.class);
|
|
|
+ ThreadPool threadPool = mock(ThreadPool.class);
|
|
|
+ ResultsPersisterService persisterService = mock(ResultsPersisterService.class);
|
|
|
+
|
|
|
+ TrainedModelStatsService service = new TrainedModelStatsService(persisterService,
|
|
|
+ originSettingClient, resolver, clusterService, threadPool);
|
|
|
+
|
|
|
+ InferenceStats.Accumulator accumulator = new InferenceStats.Accumulator("testUpdateStatsUpgradeMode", "test-node", 1L);
|
|
|
+
|
|
|
+ {
|
|
|
+ IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
|
|
|
+ .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
|
|
|
+ .settings(Settings.builder()
|
|
|
+ .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
|
|
|
+ .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
|
|
|
+ .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
|
|
|
+ );
|
|
|
+ Metadata.Builder metadata = Metadata.builder().put(indexMetadata);
|
|
|
+
|
|
|
+ ClusterState clusterState = ClusterState.builder(new ClusterName("upgrade-mode-test-initial-state"))
|
|
|
+ .routingTable(routingTable)
|
|
|
+ .metadata(metadata)
|
|
|
+ .build();
|
|
|
+ ClusterChangedEvent change = new ClusterChangedEvent("created-from-test", clusterState, clusterState);
|
|
|
+
|
|
|
+ service.setClusterState(change);
|
|
|
+
|
|
|
+ // queue some stats to be persisted
|
|
|
+ service.queueStats(accumulator.currentStats(Instant.now()), false);
|
|
|
+
|
|
|
+ service.updateStats();
|
|
|
+ verify(persisterService, times(1)).bulkIndexWithRetry(any(), any(), any(), any());
|
|
|
+ }
|
|
|
+ {
|
|
|
+ // test with upgrade mode turned on
|
|
|
+
|
|
|
+ IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
|
|
|
+ .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
|
|
|
+ .settings(Settings.builder()
|
|
|
+ .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
|
|
|
+ .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
|
|
|
+ .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
|
|
|
+ );
|
|
|
+
|
|
|
+ // now set the upgrade mode
|
|
|
+ Metadata.Builder metadata = Metadata.builder()
|
|
|
+ .put(indexMetadata)
|
|
|
+ .putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build());
|
|
|
+
|
|
|
+ ClusterState clusterState = ClusterState.builder(new ClusterName("upgrade-mode-test-upgrade-enabled"))
|
|
|
+ .routingTable(routingTable)
|
|
|
+ .metadata(metadata)
|
|
|
+ .build();
|
|
|
+ ClusterChangedEvent change = new ClusterChangedEvent("created-from-test", clusterState, clusterState);
|
|
|
+
|
|
|
+ service.setClusterState(change);
|
|
|
+
|
|
|
+ // queue some stats to be persisted
|
|
|
+ service.queueStats(accumulator.currentStats(Instant.now()), false);
|
|
|
+
|
|
|
+ service.updateStats();
|
|
|
+ verify(persisterService, times(1)).bulkIndexWithRetry(any(), any(), any(), any());
|
|
|
+ }
|
|
|
+ {
|
|
|
+ // This time turn off upgrade mode
|
|
|
+
|
|
|
+ IndexMetadata.Builder indexMetadata = IndexMetadata.builder(concreteIndex)
|
|
|
+ .putAlias(AliasMetadata.builder(aliasName).isHidden(true).build())
|
|
|
+ .settings(Settings.builder()
|
|
|
+ .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
|
|
|
+ .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
|
|
|
+ .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
|
|
|
+ );
|
|
|
+
|
|
|
+ Metadata.Builder metadata = Metadata.builder()
|
|
|
+ .put(indexMetadata)
|
|
|
+ .putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(false).build());
|
|
|
+
|
|
|
+ ClusterState clusterState = ClusterState.builder(new ClusterName("upgrade-mode-test-upgrade-disabled"))
|
|
|
+ .routingTable(routingTable)
|
|
|
+ .metadata(metadata)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ ClusterChangedEvent change = new ClusterChangedEvent("created-from-test", clusterState, clusterState);
|
|
|
+
|
|
|
+ service.setClusterState(change);
|
|
|
+ service.updateStats();
|
|
|
+ verify(persisterService, times(2)).bulkIndexWithRetry(any(), any(), any(), any());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private static void addToRoutingTable(String concreteIndex, RoutingTable.Builder routingTable) {
|
|
|
Index index = new Index(concreteIndex, "_uuid");
|
|
|
ShardId shardId = new ShardId(index, 0);
|