Browse Source

[ML][Data Frame] using transform creation version for node assignment (#43764)

* [ML][Data Frame] using transform creation version for node assignment

* removing unused imports

* Addressing PR comment
Benjamin Trent 6 years ago
parent
commit
d38a48bd21

+ 26 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransform.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.dataframe.transforms;
 
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.AbstractDiffable;
+import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -22,22 +23,35 @@ import java.util.Objects;
 public class DataFrameTransform extends AbstractDiffable<DataFrameTransform> implements XPackPlugin.XPackPersistentTaskParams {
 
     public static final String NAME = DataFrameField.TASK_NAME;
+    public static final ParseField VERSION = new ParseField(DataFrameField.VERSION);
 
     private final String transformId;
+    private final Version version;
 
     public static final ConstructingObjectParser<DataFrameTransform, Void> PARSER = new ConstructingObjectParser<>(NAME,
-            a -> new DataFrameTransform((String) a[0]));
+            a -> new DataFrameTransform((String) a[0], (String) a[1]));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameField.ID);
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), VERSION);
     }
 
-    public DataFrameTransform(String transformId) {
+    private DataFrameTransform(String transformId, String version) {
+        this(transformId, version == null ? null : Version.fromString(version));
+    }
+
+    public DataFrameTransform(String transformId, Version version) {
         this.transformId = transformId;
+        this.version = version == null ? Version.V_7_2_0 : version;
     }
 
     public DataFrameTransform(StreamInput in) throws IOException {
         this.transformId  = in.readString();
+        if (in.getVersion().onOrAfter(Version.V_7_3_0)) {
+            this.version = Version.readVersion(in);
+        } else {
+            this.version = Version.V_7_2_0;
+        }
     }
 
     @Override
@@ -53,12 +67,16 @@ public class DataFrameTransform extends AbstractDiffable<DataFrameTransform> imp
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(transformId);
+        if (out.getVersion().onOrAfter(Version.V_7_3_0)) {
+            Version.writeVersion(version, out);
+        }
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(DataFrameField.ID.getPreferredName(), transformId);
+        builder.field(VERSION.getPreferredName(), version);
         builder.endObject();
         return builder;
     }
@@ -67,6 +85,10 @@ public class DataFrameTransform extends AbstractDiffable<DataFrameTransform> imp
         return transformId;
     }
 
+    public Version getVersion() {
+        return version;
+    }
+
     public static DataFrameTransform fromXContent(XContentParser parser) throws IOException {
         return PARSER.parse(parser, null);
     }
@@ -83,11 +105,11 @@ public class DataFrameTransform extends AbstractDiffable<DataFrameTransform> imp
 
         DataFrameTransform that = (DataFrameTransform) other;
 
-        return Objects.equals(this.transformId, that.transformId);
+        return Objects.equals(this.transformId, that.transformId) && Objects.equals(this.version, that.version);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(transformId);
+        return Objects.hash(transformId, version);
     }
 }

+ 52 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformTests.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.dataframe.transforms;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.Writeable.Reader;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class DataFrameTransformTests extends AbstractSerializingDataFrameTestCase<DataFrameTransform> {
+
+    @Override
+    protected DataFrameTransform doParseInstance(XContentParser parser) throws IOException {
+        return DataFrameTransform.PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected DataFrameTransform createTestInstance() {
+        return new DataFrameTransform(randomAlphaOfLength(10), randomBoolean() ? null : Version.CURRENT);
+    }
+
+    @Override
+    protected Reader<DataFrameTransform> instanceReader() {
+        return DataFrameTransform::new;
+    }
+
+    public void testBackwardsSerialization() throws IOException {
+        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
+            DataFrameTransform transformTask = createTestInstance();
+            try (BytesStreamOutput output = new BytesStreamOutput()) {
+                output.setVersion(Version.V_7_2_0);
+                transformTask.writeTo(output);
+                try (StreamInput in = output.bytes().streamInput()) {
+                    in.setVersion(Version.V_7_2_0);
+                    // Pre-7.3.0 streams do not serialize the version, so a task read from an old stream defaults to 7.2.0
+                    DataFrameTransform streamedTask = new DataFrameTransform(in);
+                    assertThat(streamedTask.getVersion(), equalTo(Version.V_7_2_0));
+                    assertThat(streamedTask.getId(), equalTo(transformTask.getId()));
+                }
+            }
+        }
+    }
+}

+ 12 - 4
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportStartDataFrameTransformAction.java

@@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
 import org.elasticsearch.action.support.ActionFilters;
@@ -50,6 +51,7 @@ import java.io.IOException;
 import java.time.Clock;
 import java.util.Collection;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 
@@ -102,12 +104,14 @@ public class TransportStartDataFrameTransformAction extends
             listener.onFailure(LicenseUtils.newComplianceException(XPackField.DATA_FRAME));
             return;
         }
-        final DataFrameTransform transformTask = createDataFrameTransform(request.getId(), threadPool);
+        final AtomicReference<DataFrameTransform> transformTaskHolder = new AtomicReference<>();
 
-        // <3> Wait for the allocated task's state to STARTED
 +        // <4> Wait for the allocated task's state to reach STARTED
         ActionListener<PersistentTasksCustomMetaData.PersistentTask<DataFrameTransform>> newPersistentTaskActionListener =
             ActionListener.wrap(
                 task -> {
+                    DataFrameTransform transformTask = transformTaskHolder.get();
+                    assert transformTask != null;
                     waitForDataFrameTaskStarted(task.getId(),
                         transformTask,
                         request.timeout(),
@@ -121,6 +125,8 @@ public class TransportStartDataFrameTransformAction extends
         // <3> Create the task in cluster state so that it will start executing on the node
         ActionListener<Void> createOrGetIndexListener = ActionListener.wrap(
             unused -> {
+                DataFrameTransform transformTask = transformTaskHolder.get();
+                assert transformTask != null;
                 PersistentTasksCustomMetaData.PersistentTask<DataFrameTransform> existingTask =
                     getExistingTask(transformTask.getId(), state);
                 if (existingTask == null) {
@@ -179,6 +185,8 @@ public class TransportStartDataFrameTransformAction extends
                     ));
                     return;
                 }
+
+                transformTaskHolder.set(createDataFrameTransform(config.getId(), config.getVersion()));
                 final String destinationIndex = config.getDestination().getIndex();
                 String[] dest = indexNameExpressionResolver.concreteIndexNames(state,
                     IndicesOptions.lenientExpandOpen(),
@@ -247,8 +255,8 @@ public class TransportStartDataFrameTransformAction extends
         return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
     }
 
-    private static DataFrameTransform createDataFrameTransform(String transformId, ThreadPool threadPool) {
-        return new DataFrameTransform(transformId);
+    private static DataFrameTransform createDataFrameTransform(String transformId, Version transformVersion) {
+        return new DataFrameTransform(transformId, transformVersion);
     }
 
     @SuppressWarnings("unchecked")

+ 5 - 1
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutor.java

@@ -15,6 +15,7 @@ import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.persistent.AllocatedPersistentTask;
@@ -84,7 +85,10 @@ public class DataFrameTransformPersistentTasksExecutor extends PersistentTasksEx
             logger.debug(reason);
             return new PersistentTasksCustomMetaData.Assignment(null, reason);
         }
-        return super.getAssignment(params, clusterState);
+        DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, (node) ->
+            node.isDataNode() && node.getVersion().onOrAfter(params.getVersion())
+        );
+        return discoveryNode == null ? NO_NODE_FOUND : new PersistentTasksCustomMetaData.Assignment(discoveryNode.getId(), "");
     }
 
     static List<String> verifyIndicesPrimaryShardsAreActive(ClusterState clusterState) {

+ 2 - 2
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/action/DataFrameNodesTests.java

@@ -31,10 +31,10 @@ public class DataFrameNodesTests extends ESTestCase {
 
         PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
         tasksBuilder.addTask(dataFrameIdFoo,
-                DataFrameField.TASK_NAME, new DataFrameTransform(dataFrameIdFoo),
+                DataFrameField.TASK_NAME, new DataFrameTransform(dataFrameIdFoo, Version.CURRENT),
                 new PersistentTasksCustomMetaData.Assignment("node-1", "test assignment"));
         tasksBuilder.addTask(dataFrameIdBar,
-                DataFrameField.TASK_NAME, new DataFrameTransform(dataFrameIdBar),
+                DataFrameField.TASK_NAME, new DataFrameTransform(dataFrameIdBar, Version.CURRENT),
                 new PersistentTasksCustomMetaData.Assignment("node-2", "test assignment"));
         tasksBuilder.addTask("test-task1", "testTasks", new PersistentTaskParams() {
                 @Override

+ 83 - 0
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutorTests.java

@@ -7,10 +7,14 @@
 package org.elasticsearch.xpack.dataframe.transforms;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MetaData;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.RecoverySource;
@@ -20,14 +24,93 @@ import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform;
+import org.elasticsearch.xpack.core.scheduler.SchedulerEngine;
+import org.elasticsearch.xpack.dataframe.checkpoint.DataFrameTransformsCheckpointService;
+import org.elasticsearch.xpack.dataframe.notifications.DataFrameAuditor;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameInternalIndex;
+import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.mock;
 
 public class DataFrameTransformPersistentTasksExecutorTests extends ESTestCase {
 
+    public void testNodeVersionAssignment() {
+        MetaData.Builder metaData = MetaData.builder();
+        RoutingTable.Builder routingTable = RoutingTable.builder();
+        addIndices(metaData, routingTable);
+        PersistentTasksCustomMetaData.Builder pTasksBuilder = PersistentTasksCustomMetaData.builder()
+            .addTask("data-frame-task-1",
+                DataFrameTransform.NAME,
+                new DataFrameTransform("data-frame-task-1", Version.CURRENT),
+                new PersistentTasksCustomMetaData.Assignment("current-data-node-with-1-tasks", ""))
+            .addTask("data-frame-task-2",
+                DataFrameTransform.NAME,
+                new DataFrameTransform("data-frame-task-2", Version.CURRENT),
+                new PersistentTasksCustomMetaData.Assignment("current-data-node-with-2-tasks", ""))
+            .addTask("data-frame-task-3",
+                DataFrameTransform.NAME,
+                new DataFrameTransform("data-frame-task-3", Version.CURRENT),
+                new PersistentTasksCustomMetaData.Assignment("current-data-node-with-2-tasks", ""));
+
+        PersistentTasksCustomMetaData pTasks = pTasksBuilder.build();
+
+        metaData.putCustom(PersistentTasksCustomMetaData.TYPE, pTasks);
+
+        DiscoveryNodes.Builder nodes = DiscoveryNodes.builder()
+            .add(new DiscoveryNode("past-data-node-1",
+                buildNewFakeTransportAddress(),
+                Collections.emptyMap(),
+                Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE),
+                Version.V_7_2_0))
+            .add(new DiscoveryNode("current-data-node-with-2-tasks",
+                buildNewFakeTransportAddress(),
+                Collections.emptyMap(),
+                Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE),
+                Version.CURRENT))
+            .add(new DiscoveryNode("non-data-node-1",
+                buildNewFakeTransportAddress(),
+                Collections.emptyMap(),
+                Set.of(DiscoveryNodeRole.MASTER_ROLE),
+                Version.CURRENT))
+            .add(new DiscoveryNode("current-data-node-with-1-tasks",
+                buildNewFakeTransportAddress(),
+                Collections.emptyMap(),
+                Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE),
+                Version.CURRENT));
+
+        ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name"))
+            .nodes(nodes);
+        csBuilder.routingTable(routingTable.build());
+        csBuilder.metaData(metaData);
+
+        ClusterState cs = csBuilder.build();
+        Client client = mock(Client.class);
+        DataFrameTransformsConfigManager transformsConfigManager = new DataFrameTransformsConfigManager(client, xContentRegistry());
+        DataFrameTransformsCheckpointService dataFrameTransformsCheckpointService = new DataFrameTransformsCheckpointService(client,
+            transformsConfigManager);
+
+        DataFrameTransformPersistentTasksExecutor executor = new DataFrameTransformPersistentTasksExecutor(client,
+            transformsConfigManager,
+            dataFrameTransformsCheckpointService, mock(SchedulerEngine.class),
+            new DataFrameAuditor(client, ""),
+            mock(ThreadPool.class));
+
+        assertThat(executor.getAssignment(new DataFrameTransform("new-task-id", Version.CURRENT), cs).getExecutorNode(),
+            equalTo("current-data-node-with-1-tasks"));
+        assertThat(executor.getAssignment(new DataFrameTransform("new-old-task-id", Version.V_7_2_0), cs).getExecutorNode(),
+            equalTo("past-data-node-1"));
+    }
+
     public void testVerifyIndicesPrimaryShardsAreActive() {
         MetaData.Builder metaData = MetaData.builder();
         RoutingTable.Builder routingTable = RoutingTable.builder();