Browse Source

Merge pull request #16356 from nik9000/task_status

Add task status
Nik Everett 9 years ago
parent
commit
9ef7ff4904
27 changed files with 471 additions and 93 deletions
  1. 2 1
      core/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksResponse.java
  2. 29 3
      core/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TaskInfo.java
  3. 14 0
      core/src/main/java/org/elasticsearch/action/support/ChildTaskActionRequest.java
  4. 6 1
      core/src/main/java/org/elasticsearch/action/support/replication/ReplicationRequest.java
  5. 97 0
      core/src/main/java/org/elasticsearch/action/support/replication/ReplicationTask.java
  6. 53 9
      core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java
  7. 8 0
      core/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
  8. 8 0
      core/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java
  9. 30 2
      core/src/main/java/org/elasticsearch/tasks/Task.java
  10. 6 2
      core/src/main/java/org/elasticsearch/transport/TransportService.java
  11. 1 0
      core/src/main/java/org/elasticsearch/transport/local/LocalTransport.java
  12. 1 0
      core/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java
  13. 62 0
      core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java
  14. 1 2
      core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java
  15. 107 39
      core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java
  16. 3 3
      core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java
  17. 2 2
      core/src/test/java/org/elasticsearch/client/transport/TransportClientNodesServiceTests.java
  18. 3 1
      core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java
  19. 1 3
      core/src/test/java/org/elasticsearch/discovery/zen/publish/PublishClusterStateActionTests.java
  20. 2 2
      core/src/test/java/org/elasticsearch/transport/TransportModuleTests.java
  21. 1 1
      core/src/test/java/org/elasticsearch/transport/local/SimpleLocalTransportTests.java
  22. 6 4
      core/src/test/java/org/elasticsearch/transport/netty/NettyScheduledPingTests.java
  23. 1 3
      core/src/test/java/org/elasticsearch/transport/netty/SimpleNettyTransportTests.java
  24. 3 3
      modules/lang-groovy/src/test/java/org/elasticsearch/messy/tests/IndicesRequestTests.java
  25. 2 5
      plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java
  26. 1 5
      plugins/discovery-gce/src/test/java/org/elasticsearch/discovery/gce/GceDiscoveryTests.java
  27. 21 2
      test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java

+ 2 - 1
core/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksResponse.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.action.admin.cluster.node.tasks.list;
 package org.elasticsearch.action.admin.cluster.node.tasks.list;
 
 
 import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
 import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
+
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
@@ -111,7 +112,7 @@ public class ListTasksResponse extends BaseTasksResponse implements ToXContent {
 
 
         if (getNodeFailures() != null && getNodeFailures().size() > 0) {
         if (getNodeFailures() != null && getNodeFailures().size() > 0) {
             builder.startArray("node_failures");
             builder.startArray("node_failures");
-            for (FailedNodeException ex : getNodeFailures()){
+            for (FailedNodeException ex : getNodeFailures()) {
                 builder.value(ex);
                 builder.value(ex);
             }
             }
             builder.endArray();
             builder.endArray();

+ 29 - 3
core/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TaskInfo.java

@@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.tasks.Task;
 
 
 import java.io.IOException;
 import java.io.IOException;
 
 
@@ -48,20 +49,23 @@ public class TaskInfo implements Writeable<TaskInfo>, ToXContent {
 
 
     private final String description;
     private final String description;
 
 
+    private final Task.Status status;
+
     private final String parentNode;
     private final String parentNode;
 
 
     private final long parentId;
     private final long parentId;
 
 
-    public TaskInfo(DiscoveryNode node, long id, String type, String action, String description) {
-        this(node, id, type, action, description, null, -1L);
+    public TaskInfo(DiscoveryNode node, long id, String type, String action, String description, Task.Status status) {
+        this(node, id, type, action, description, status, null, -1L);
     }
     }
 
 
-    public TaskInfo(DiscoveryNode node, long id, String type, String action, String description, String parentNode, long parentId) {
+    public TaskInfo(DiscoveryNode node, long id, String type, String action, String description, Task.Status status, String parentNode, long parentId) {
         this.node = node;
         this.node = node;
         this.id = id;
         this.id = id;
         this.type = type;
         this.type = type;
         this.action = action;
         this.action = action;
         this.description = description;
         this.description = description;
+        this.status = status;
         this.parentNode = parentNode;
         this.parentNode = parentNode;
         this.parentId = parentId;
         this.parentId = parentId;
     }
     }
@@ -72,6 +76,11 @@ public class TaskInfo implements Writeable<TaskInfo>, ToXContent {
         type = in.readString();
         type = in.readString();
         action = in.readString();
         action = in.readString();
         description = in.readOptionalString();
         description = in.readOptionalString();
+        if (in.readBoolean()) {
+            status = in.readTaskStatus();
+        } else {
+            status = null;
+        }
         parentNode = in.readOptionalString();
         parentNode = in.readOptionalString();
         parentId = in.readLong();
         parentId = in.readLong();
     }
     }
@@ -96,6 +105,14 @@ public class TaskInfo implements Writeable<TaskInfo>, ToXContent {
         return description;
         return description;
     }
     }
 
 
+    /**
+     * The status of the running task. Only available if TaskInfos were build
+     * with the detailed flag.
+     */
+    public Task.Status getStatus() {
+        return status;
+    }
+
     public String getParentNode() {
     public String getParentNode() {
         return parentNode;
         return parentNode;
     }
     }
@@ -116,6 +133,12 @@ public class TaskInfo implements Writeable<TaskInfo>, ToXContent {
         out.writeString(type);
         out.writeString(type);
         out.writeString(action);
         out.writeString(action);
         out.writeOptionalString(description);
         out.writeOptionalString(description);
+        if (status != null) {
+            out.writeBoolean(true);
+            out.writeTaskStatus(status);
+        } else {
+            out.writeBoolean(false);
+        }
         out.writeOptionalString(parentNode);
         out.writeOptionalString(parentNode);
         out.writeLong(parentId);
         out.writeLong(parentId);
     }
     }
@@ -127,6 +150,9 @@ public class TaskInfo implements Writeable<TaskInfo>, ToXContent {
         builder.field("id", id);
         builder.field("id", id);
         builder.field("type", type);
         builder.field("type", type);
         builder.field("action", action);
         builder.field("action", action);
+        if (status != null) {
+            builder.field("status", status, params);
+        }
         if (description != null) {
         if (description != null) {
             builder.field("description", description);
             builder.field("description", description);
         }
         }

+ 14 - 0
core/src/main/java/org/elasticsearch/action/support/ChildTaskActionRequest.java

@@ -44,6 +44,20 @@ public abstract class ChildTaskActionRequest<Request extends ActionRequest<Reque
         this.parentTaskId = parentTaskId;
         this.parentTaskId = parentTaskId;
     }
     }
 
 
+    /**
+     * The node that owns the parent task.
+     */
+    public String getParentTaskNode() {
+        return parentTaskNode;
+    }
+
+    /**
+     * The task id of the parent task on the parent node.
+     */
+    public long getParentTaskId() {
+        return parentTaskId;
+    }
+
     @Override
     @Override
     public void readFrom(StreamInput in) throws IOException {
     public void readFrom(StreamInput in) throws IOException {
         super.readFrom(in);
         super.readFrom(in);

+ 6 - 1
core/src/main/java/org/elasticsearch/action/support/replication/ReplicationRequest.java

@@ -19,7 +19,6 @@
 
 
 package org.elasticsearch.action.support.replication;
 package org.elasticsearch.action.support.replication;
 
 
-import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.IndicesRequest;
 import org.elasticsearch.action.IndicesRequest;
 import org.elasticsearch.action.WriteConsistencyLevel;
 import org.elasticsearch.action.WriteConsistencyLevel;
@@ -30,6 +29,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.tasks.Task;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
@@ -195,6 +195,11 @@ public abstract class ReplicationRequest<Request extends ReplicationRequest<Requ
         out.writeVLong(routedBasedOnClusterVersion);
         out.writeVLong(routedBasedOnClusterVersion);
     }
     }
 
 
+    @Override
+    public Task createTask(long id, String type, String action) {
+        return new ReplicationTask(id, type, action, this::getDescription, getParentTaskNode(), getParentTaskId());
+    }
+
     /**
     /**
      * Sets the target shard id for the request. The shard id is set when a
      * Sets the target shard id for the request. The shard id is set when a
      * index/delete request is resolved by the transport action
      * index/delete request is resolved by the transport action

+ 97 - 0
core/src/main/java/org/elasticsearch/action/support/replication/ReplicationTask.java

@@ -0,0 +1,97 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.action.support.replication;
+
+import org.elasticsearch.common.inject.Provider;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.tasks.Task;
+
+import java.io.IOException;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Task that tracks replication actions.
+ */
+public class ReplicationTask extends Task {
+    private volatile String phase = "starting";
+
+    public ReplicationTask(long id, String type, String action, Provider<String> description, String parentNode, long parentId) {
+        super(id, type, action, description, parentNode, parentId);
+    }
+
+    /**
+     * Set the current phase of the task.
+     */
+    public void setPhase(String phase) {
+        this.phase = phase;
+    }
+
+    /**
+     * Get the current phase of the task.
+     */
+    public String getPhase() {
+        return phase;
+    }
+
+    @Override
+    public Status getStatus() {
+        return new Status(phase);
+    }
+
+    public static class Status implements Task.Status {
+        public static final Status PROTOTYPE = new Status("prototype");
+
+        private final String phase;
+
+        public Status(String phase) {
+            this.phase = requireNonNull(phase, "Phase cannot be null");
+        }
+
+        public Status(StreamInput in) throws IOException {
+            phase = in.readString();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return "replication";
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field("phase", phase);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(phase);
+        }
+
+        @Override
+        public Status readFrom(StreamInput in) throws IOException {
+            return new Status(in);
+        }
+    }
+}

+ 53 - 9
core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java

@@ -142,7 +142,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
     @Override
     @Override
     protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
     protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
-        new ReroutePhase(task, request, listener).run();
+        new ReroutePhase((ReplicationTask) task, request, listener).run();
     }
     }
 
 
     protected abstract Response newResponseInstance();
     protected abstract Response newResponseInstance();
@@ -283,14 +283,24 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
     class PrimaryOperationTransportHandler implements TransportRequestHandler<Request> {
     class PrimaryOperationTransportHandler implements TransportRequestHandler<Request> {
         @Override
         @Override
         public void messageReceived(final Request request, final TransportChannel channel) throws Exception {
         public void messageReceived(final Request request, final TransportChannel channel) throws Exception {
-            new PrimaryPhase(request, channel).run();
+            throw new UnsupportedOperationException("the task parameter is required for this operation");
+        }
+
+        @Override
+        public void messageReceived(Request request, TransportChannel channel, Task task) throws Exception {
+            new PrimaryPhase((ReplicationTask) task, request, channel).run();
         }
         }
     }
     }
 
 
     class ReplicaOperationTransportHandler implements TransportRequestHandler<ReplicaRequest> {
     class ReplicaOperationTransportHandler implements TransportRequestHandler<ReplicaRequest> {
         @Override
         @Override
         public void messageReceived(final ReplicaRequest request, final TransportChannel channel) throws Exception {
         public void messageReceived(final ReplicaRequest request, final TransportChannel channel) throws Exception {
-            new AsyncReplicaAction(request, channel).run();
+            throw new UnsupportedOperationException("the task parameter is required for this operation");
+        }
+
+        @Override
+        public void messageReceived(ReplicaRequest request, TransportChannel channel, Task task) throws Exception {
+            new AsyncReplicaAction(request, channel, (ReplicationTask) task).run();
         }
         }
     }
     }
 
 
@@ -309,13 +319,18 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
     private final class AsyncReplicaAction extends AbstractRunnable {
     private final class AsyncReplicaAction extends AbstractRunnable {
         private final ReplicaRequest request;
         private final ReplicaRequest request;
         private final TransportChannel channel;
         private final TransportChannel channel;
+        /**
+         * The task on the node with the replica shard.
+         */
+        private final ReplicationTask task;
         // important: we pass null as a timeout as failing a replica is
         // important: we pass null as a timeout as failing a replica is
         // something we want to avoid at all costs
         // something we want to avoid at all costs
         private final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext());
         private final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext());
 
 
-        AsyncReplicaAction(ReplicaRequest request, TransportChannel channel) {
+        AsyncReplicaAction(ReplicaRequest request, TransportChannel channel, ReplicationTask task) {
             this.request = request;
             this.request = request;
             this.channel = channel;
             this.channel = channel;
+            this.task = task;
         }
         }
 
 
         @Override
         @Override
@@ -385,6 +400,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
         @Override
         @Override
         protected void doRun() throws Exception {
         protected void doRun() throws Exception {
+            setPhase(task, "replica");
             assert request.shardId() != null : "request shardId must be set";
             assert request.shardId() != null : "request shardId must be set";
             try (Releasable ignored = getIndexShardReferenceOnReplica(request.shardId())) {
             try (Releasable ignored = getIndexShardReferenceOnReplica(request.shardId())) {
                 shardOperationOnReplica(request);
                 shardOperationOnReplica(request);
@@ -392,6 +408,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                     logger.trace("action [{}] completed on shard [{}] for request [{}]", transportReplicaAction, request.shardId(), request);
                     logger.trace("action [{}] completed on shard [{}] for request [{}]", transportReplicaAction, request.shardId(), request);
                 }
                 }
             }
             }
+            setPhase(task, "finished");
             channel.sendResponse(TransportResponse.Empty.INSTANCE);
             channel.sendResponse(TransportResponse.Empty.INSTANCE);
         }
         }
     }
     }
@@ -417,15 +434,17 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
     final class ReroutePhase extends AbstractRunnable {
     final class ReroutePhase extends AbstractRunnable {
         private final ActionListener<Response> listener;
         private final ActionListener<Response> listener;
         private final Request request;
         private final Request request;
+        private final ReplicationTask task;
         private final ClusterStateObserver observer;
         private final ClusterStateObserver observer;
         private final AtomicBoolean finished = new AtomicBoolean();
         private final AtomicBoolean finished = new AtomicBoolean();
 
 
-        ReroutePhase(Task task, Request request, ActionListener<Response> listener) {
+        ReroutePhase(ReplicationTask task, Request request, ActionListener<Response> listener) {
             this.request = request;
             this.request = request;
             if (task != null) {
             if (task != null) {
                 this.request.setParentTask(clusterService.localNode().getId(), task.getId());
                 this.request.setParentTask(clusterService.localNode().getId(), task.getId());
             }
             }
             this.listener = listener;
             this.listener = listener;
+            this.task = task;
             this.observer = new ClusterStateObserver(clusterService, request.timeout(), logger, threadPool.getThreadContext());
             this.observer = new ClusterStateObserver(clusterService, request.timeout(), logger, threadPool.getThreadContext());
         }
         }
 
 
@@ -436,6 +455,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
         @Override
         @Override
         protected void doRun() {
         protected void doRun() {
+            setPhase(task, "routing");
             final ClusterState state = observer.observedState();
             final ClusterState state = observer.observedState();
             ClusterBlockException blockException = state.blocks().globalBlockedException(globalBlockLevel());
             ClusterBlockException blockException = state.blocks().globalBlockedException(globalBlockLevel());
             if (blockException != null) {
             if (blockException != null) {
@@ -467,6 +487,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
             }
             }
             final DiscoveryNode node = state.nodes().get(primary.currentNodeId());
             final DiscoveryNode node = state.nodes().get(primary.currentNodeId());
             if (primary.currentNodeId().equals(state.nodes().localNodeId())) {
             if (primary.currentNodeId().equals(state.nodes().localNodeId())) {
+                setPhase(task, "waiting_on_primary");
                 if (logger.isTraceEnabled()) {
                 if (logger.isTraceEnabled()) {
                     logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}] ", transportPrimaryAction, request.shardId(), request, state.version(), primary.currentNodeId());
                     logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}] ", transportPrimaryAction, request.shardId(), request, state.version(), primary.currentNodeId());
                 }
                 }
@@ -484,6 +505,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                 if (logger.isTraceEnabled()) {
                 if (logger.isTraceEnabled()) {
                     logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}]", actionName, request.shardId(), request, state.version(), primary.currentNodeId());
                     logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}]", actionName, request.shardId(), request, state.version(), primary.currentNodeId());
                 }
                 }
+                setPhase(task, "rerouted");
                 performAction(node, actionName, false);
                 performAction(node, actionName, false);
             }
             }
         }
         }
@@ -540,6 +562,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                 finishAsFailed(failure);
                 finishAsFailed(failure);
                 return;
                 return;
             }
             }
+            setPhase(task, "waiting_for_retry");
             final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
             final ThreadContext.StoredContext context = threadPool.getThreadContext().newStoredContext();
             observer.waitForNextChange(new ClusterStateObserver.Listener() {
             observer.waitForNextChange(new ClusterStateObserver.Listener() {
                 @Override
                 @Override
@@ -564,6 +587,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
         void finishAsFailed(Throwable failure) {
         void finishAsFailed(Throwable failure) {
             if (finished.compareAndSet(false, true)) {
             if (finished.compareAndSet(false, true)) {
+                setPhase(task, "failed");
                 logger.trace("operation failed. action [{}], request [{}]", failure, actionName, request);
                 logger.trace("operation failed. action [{}], request [{}]", failure, actionName, request);
                 listener.onFailure(failure);
                 listener.onFailure(failure);
             } else {
             } else {
@@ -574,6 +598,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         void finishWithUnexpectedFailure(Throwable failure) {
         void finishWithUnexpectedFailure(Throwable failure) {
             logger.warn("unexpected error during the primary phase for action [{}], request [{}]", failure, actionName, request);
             logger.warn("unexpected error during the primary phase for action [{}], request [{}]", failure, actionName, request);
             if (finished.compareAndSet(false, true)) {
             if (finished.compareAndSet(false, true)) {
+                setPhase(task, "failed");
                 listener.onFailure(failure);
                 listener.onFailure(failure);
             } else {
             } else {
                 assert false : "finishWithUnexpectedFailure called but operation is already finished";
                 assert false : "finishWithUnexpectedFailure called but operation is already finished";
@@ -582,6 +607,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
         void finishOnSuccess(Response response) {
         void finishOnSuccess(Response response) {
             if (finished.compareAndSet(false, true)) {
             if (finished.compareAndSet(false, true)) {
+                setPhase(task, "finished");
                 if (logger.isTraceEnabled()) {
                 if (logger.isTraceEnabled()) {
                     logger.trace("operation succeeded. action [{}],request [{}]", actionName, request);
                     logger.trace("operation succeeded. action [{}],request [{}]", actionName, request);
                 }
                 }
@@ -603,6 +629,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
      * Note that as soon as we move to replication action, state responsibility is transferred to {@link ReplicationPhase}.
      * Note that as soon as we move to replication action, state responsibility is transferred to {@link ReplicationPhase}.
      */
      */
     class PrimaryPhase extends AbstractRunnable {
     class PrimaryPhase extends AbstractRunnable {
+        private final ReplicationTask task;
         private final Request request;
         private final Request request;
         private final ShardId shardId;
         private final ShardId shardId;
         private final TransportChannel channel;
         private final TransportChannel channel;
@@ -610,8 +637,9 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         private final AtomicBoolean finished = new AtomicBoolean();
         private final AtomicBoolean finished = new AtomicBoolean();
         private IndexShardReference indexShardReference;
         private IndexShardReference indexShardReference;
 
 
-        PrimaryPhase(Request request, TransportChannel channel) {
+        PrimaryPhase(ReplicationTask task, Request request, TransportChannel channel) {
             this.state = clusterService.state();
             this.state = clusterService.state();
+            this.task = task;
             this.request = request;
             this.request = request;
             assert request.shardId() != null : "request shardId must be set prior to primary phase";
             assert request.shardId() != null : "request shardId must be set prior to primary phase";
             this.shardId = request.shardId();
             this.shardId = request.shardId();
@@ -634,6 +662,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
         @Override
         @Override
         protected void doRun() throws Exception {
         protected void doRun() throws Exception {
+            setPhase(task, "primary");
             // request shardID was set in ReroutePhase
             // request shardID was set in ReroutePhase
             final String writeConsistencyFailure = checkWriteConsistency(shardId);
             final String writeConsistencyFailure = checkWriteConsistency(shardId);
             if (writeConsistencyFailure != null) {
             if (writeConsistencyFailure != null) {
@@ -648,7 +677,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                 if (logger.isTraceEnabled()) {
                 if (logger.isTraceEnabled()) {
                     logger.trace("action [{}] completed on shard [{}] for request [{}] with cluster state version [{}]", transportPrimaryAction, shardId, request, state.version());
                     logger.trace("action [{}] completed on shard [{}] for request [{}] with cluster state version [{}]", transportPrimaryAction, shardId, request, state.version());
                 }
                 }
-                ReplicationPhase replicationPhase = new ReplicationPhase(primaryResponse.v2(), primaryResponse.v1(), shardId, channel, indexShardReference);
+                ReplicationPhase replicationPhase = new ReplicationPhase(task, primaryResponse.v2(), primaryResponse.v1(), shardId, channel, indexShardReference);
                 finishAndMoveToReplication(replicationPhase);
                 finishAndMoveToReplication(replicationPhase);
             } else {
             } else {
                 // delegate primary phase to relocation target
                 // delegate primary phase to relocation target
@@ -728,6 +757,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
          */
          */
         void finishAsFailed(Throwable failure) {
         void finishAsFailed(Throwable failure) {
             if (finished.compareAndSet(false, true)) {
             if (finished.compareAndSet(false, true)) {
+                setPhase(task, "failed");
                 Releasables.close(indexShardReference);
                 Releasables.close(indexShardReference);
                 logger.trace("operation failed", failure);
                 logger.trace("operation failed", failure);
                 try {
                 try {
@@ -770,7 +800,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
      * relocating copies
      * relocating copies
      */
      */
     final class ReplicationPhase extends AbstractRunnable {
     final class ReplicationPhase extends AbstractRunnable {
-
+        private final ReplicationTask task;
         private final ReplicaRequest replicaRequest;
         private final ReplicaRequest replicaRequest;
         private final Response finalResponse;
         private final Response finalResponse;
         private final TransportChannel channel;
         private final TransportChannel channel;
@@ -785,8 +815,9 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         private final int totalShards;
         private final int totalShards;
         private final IndexShardReference indexShardReference;
         private final IndexShardReference indexShardReference;
 
 
-        public ReplicationPhase(ReplicaRequest replicaRequest, Response finalResponse, ShardId shardId,
+        public ReplicationPhase(ReplicationTask task, ReplicaRequest replicaRequest, Response finalResponse, ShardId shardId,
                                 TransportChannel channel, IndexShardReference indexShardReference) {
                                 TransportChannel channel, IndexShardReference indexShardReference) {
+            this.task = task;
             this.replicaRequest = replicaRequest;
             this.replicaRequest = replicaRequest;
             this.channel = channel;
             this.channel = channel;
             this.finalResponse = finalResponse;
             this.finalResponse = finalResponse;
@@ -870,6 +901,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
          */
          */
         @Override
         @Override
         protected void doRun() {
         protected void doRun() {
+            setPhase(task, "replicating");
             if (pending.get() == 0) {
             if (pending.get() == 0) {
                 doFinish();
                 doFinish();
                 return;
                 return;
@@ -981,6 +1013,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         }
         }
 
 
         private void forceFinishAsFailed(Throwable t) {
         private void forceFinishAsFailed(Throwable t) {
+            setPhase(task, "failed");
             if (finished.compareAndSet(false, true)) {
             if (finished.compareAndSet(false, true)) {
                 Releasables.close(indexShardReference);
                 Releasables.close(indexShardReference);
                 try {
                 try {
@@ -994,6 +1027,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
 
         private void doFinish() {
         private void doFinish() {
             if (finished.compareAndSet(false, true)) {
             if (finished.compareAndSet(false, true)) {
+                setPhase(task, "finished");
                 Releasables.close(indexShardReference);
                 Releasables.close(indexShardReference);
                 final ReplicationResponse.ShardInfo.Failure[] failuresArray;
                 final ReplicationResponse.ShardInfo.Failure[] failuresArray;
                 if (!shardReplicaFailures.isEmpty()) {
                 if (!shardReplicaFailures.isEmpty()) {
@@ -1082,4 +1116,14 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         }
         }
         indexShard.maybeFlush();
         indexShard.maybeFlush();
     }
     }
+
+    /**
+     * Sets the current phase on the task if it isn't null. Pulled into its own
+     * method because its more convenient that way.
+     */
+    static void setPhase(ReplicationTask task, String phase) {
+        if (task != null) {
+            task.setPhase(phase);
+        }
+    }
 }
 }

+ 8 - 0
core/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -38,6 +38,7 @@ import org.elasticsearch.common.text.Text;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
 import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
 import org.elasticsearch.search.rescore.RescoreBuilder;
 import org.elasticsearch.search.rescore.RescoreBuilder;
+import org.elasticsearch.tasks.Task;
 import org.joda.time.DateTime;
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
 import org.joda.time.DateTimeZone;
 
 
@@ -690,6 +691,13 @@ public abstract class StreamInput extends InputStream {
         return readNamedWriteable(ScoreFunctionBuilder.class);
         return readNamedWriteable(ScoreFunctionBuilder.class);
     }
     }
 
 
+    /**
+     * Reads a {@link Task.Status} from the current stream.
+     */
+    public Task.Status readTaskStatus() throws IOException {
+        return readNamedWriteable(Task.Status.class);
+    }
+
     /**
     /**
      * Reads a list of objects
      * Reads a list of objects
      */
      */

+ 8 - 0
core/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java

@@ -37,6 +37,7 @@ import org.elasticsearch.common.text.Text;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
 import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
 import org.elasticsearch.search.rescore.RescoreBuilder;
 import org.elasticsearch.search.rescore.RescoreBuilder;
+import org.elasticsearch.tasks.Task;
 import org.joda.time.ReadableInstant;
 import org.joda.time.ReadableInstant;
 
 
 import java.io.EOFException;
 import java.io.EOFException;
@@ -660,6 +661,13 @@ public abstract class StreamOutput extends OutputStream {
         writeNamedWriteable(scoreFunctionBuilder);
         writeNamedWriteable(scoreFunctionBuilder);
     }
     }
 
 
+    /**
+     * Writes a {@link Task.Status} to the current stream.
+     */
+    public void writeTaskStatus(Task.Status status) throws IOException {
+        writeNamedWriteable(status);
+    }
+
     /**
     /**
      * Writes the given {@link GeoPoint} to the stream
      * Writes the given {@link GeoPoint} to the stream
      */
      */

+ 30 - 2
core/src/main/java/org/elasticsearch/tasks/Task.java

@@ -23,6 +23,8 @@ package org.elasticsearch.tasks;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.TaskInfo;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.TaskInfo;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.inject.Provider;
 import org.elasticsearch.common.inject.Provider;
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContent;
 
 
 /**
 /**
  * Current task information
  * Current task information
@@ -57,9 +59,24 @@ public class Task {
         this.parentId = parentId;
         this.parentId = parentId;
     }
     }
 
 
-
+    /**
+     * Build a version of the task status you can throw over the wire and back
+     * to the user.
+     *
+     * @param node
+     *            the node this task is running on
+     * @param detailed
+     *            should the information include detailed, potentially slow to
+     *            generate data?
+     */
     public TaskInfo taskInfo(DiscoveryNode node, boolean detailed) {
     public TaskInfo taskInfo(DiscoveryNode node, boolean detailed) {
-        return new TaskInfo(node, getId(), getType(), getAction(), detailed ? getDescription() : null, parentNode, parentId);
+        String description = null;
+        Task.Status status = null;
+        if (detailed) {
+            description = getDescription();
+            status = getStatus();
+        }
+        return new TaskInfo(node, getId(), getType(), getAction(), description, status, parentNode, parentId);
     }
     }
 
 
     /**
     /**
@@ -104,4 +121,15 @@ public class Task {
         return parentId;
         return parentId;
     }
     }
 
 
+    /**
+     * Build a status for this task or null if this task doesn't have status.
+     * Since most tasks don't have status this defaults to returning null. While
+     * this can never perform IO it might be a costly operation, requiring
+     * collating lists of results, etc. So only use it if you need the value.
+     */
+    public Status getStatus() {
+        return null;
+    }
+
+    public interface Status extends ToXContent, NamedWriteable<Status> {}
 }
 }

+ 6 - 2
core/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -20,11 +20,13 @@
 package org.elasticsearch.transport;
 package org.elasticsearch.transport;
 
 
 import org.elasticsearch.action.admin.cluster.node.liveness.TransportLivenessAction;
 import org.elasticsearch.action.admin.cluster.node.liveness.TransportLivenessAction;
+import org.elasticsearch.action.support.replication.ReplicationTask;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.logging.ESLogger;
 import org.elasticsearch.common.logging.ESLogger;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.metrics.MeanMetric;
 import org.elasticsearch.common.metrics.MeanMetric;
@@ -41,6 +43,7 @@ import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 
 
@@ -109,11 +112,11 @@ public class TransportService extends AbstractLifecycleComponent<TransportServic
     volatile DiscoveryNode localNode = null;
     volatile DiscoveryNode localNode = null;
 
 
     public TransportService(Transport transport, ThreadPool threadPool) {
     public TransportService(Transport transport, ThreadPool threadPool) {
-        this(EMPTY_SETTINGS, transport, threadPool);
+        this(EMPTY_SETTINGS, transport, threadPool, new NamedWriteableRegistry());
     }
     }
 
 
     @Inject
     @Inject
-    public TransportService(Settings settings, Transport transport, ThreadPool threadPool) {
+    public TransportService(Settings settings, Transport transport, ThreadPool threadPool, NamedWriteableRegistry namedWriteableRegistry) {
         super(settings);
         super(settings);
         this.transport = transport;
         this.transport = transport;
         this.threadPool = threadPool;
         this.threadPool = threadPool;
@@ -122,6 +125,7 @@ public class TransportService extends AbstractLifecycleComponent<TransportServic
         tracerLog = Loggers.getLogger(logger, ".tracer");
         tracerLog = Loggers.getLogger(logger, ".tracer");
         adapter = createAdapter();
         adapter = createAdapter();
         taskManager = createTaskManager();
         taskManager = createTaskManager();
+        namedWriteableRegistry.registerPrototype(Task.Status.class, ReplicationTask.Status.PROTOTYPE);
     }
     }
 
 
     /**
     /**

+ 1 - 0
core/src/main/java/org/elasticsearch/transport/local/LocalTransport.java

@@ -333,6 +333,7 @@ public class LocalTransport extends AbstractLifecycleComponent<Transport> implem
     }
     }
 
 
     protected void handleResponse(StreamInput buffer, LocalTransport sourceTransport, final TransportResponseHandler handler) {
     protected void handleResponse(StreamInput buffer, LocalTransport sourceTransport, final TransportResponseHandler handler) {
+        buffer = new NamedWriteableAwareStreamInput(buffer, namedWriteableRegistry);
         final TransportResponse response = handler.newInstance();
         final TransportResponse response = handler.newInstance();
         response.remoteAddress(sourceTransport.boundAddress.publishAddress());
         response.remoteAddress(sourceTransport.boundAddress.publishAddress());
         try {
         try {

+ 1 - 0
core/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java

@@ -192,6 +192,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
     }
     }
 
 
     protected void handleResponse(Channel channel, StreamInput buffer, final TransportResponseHandler handler) {
     protected void handleResponse(Channel channel, StreamInput buffer, final TransportResponseHandler handler) {
+        buffer = new NamedWriteableAwareStreamInput(buffer, transport.namedWriteableRegistry);
         final TransportResponse response = handler.newInstance();
         final TransportResponse response = handler.newInstance();
         response.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress()));
         response.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress()));
         response.remoteAddress();
         response.remoteAddress();

+ 62 - 0
core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java

@@ -18,6 +18,7 @@
  */
  */
 package org.elasticsearch.action.admin.cluster.node.tasks;
 package org.elasticsearch.action.admin.cluster.node.tasks;
 
 
+import org.elasticsearch.action.ListenableActionFuture;
 import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
 import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
@@ -25,6 +26,7 @@ import org.elasticsearch.action.admin.cluster.node.tasks.list.TaskInfo;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeAction;
 import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeAction;
 import org.elasticsearch.action.admin.indices.validate.query.ValidateQueryAction;
 import org.elasticsearch.action.admin.indices.validate.query.ValidateQueryAction;
+import org.elasticsearch.action.index.IndexAction;
 import org.elasticsearch.action.percolate.PercolateAction;
 import org.elasticsearch.action.percolate.PercolateAction;
 import org.elasticsearch.cluster.ClusterService;
 import org.elasticsearch.cluster.ClusterService;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -32,20 +34,27 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.tasks.MockTaskManager;
 import org.elasticsearch.test.tasks.MockTaskManager;
+import org.elasticsearch.test.tasks.MockTaskManagerListener;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.test.transport.MockTransportService;
 
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.function.Function;
 import java.util.function.Function;
 
 
+import static org.hamcrest.Matchers.emptyCollectionOf;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.not;
 
 
 /**
 /**
  * Integration tests for task management API
  * Integration tests for task management API
@@ -218,6 +227,59 @@ public class TasksIT extends ESIntegTestCase {
         }
         }
     }
     }
 
 
+    /**
+     * Very basic "is it plugged in" style test that indexes a document and
+     * makes sure that you can fetch the status of the process. The goal here is
+     * to verify that the large moving parts that make fetching task status work
+     * fit together rather than to verify any particular status results from
+     * indexing. For that, look at
+     * {@link org.elasticsearch.action.support.replication.TransportReplicationActionTests}
+     * . We intentionally don't use the task recording mechanism used in other
+     * places in this test so we can make sure that the status fetching works
+     * properly over the wire.
+     */
+    public void testCanFetchIndexStatus() throws InterruptedException, ExecutionException, IOException {
+        /*
+         * We prevent any tasks from unregistering until the test is done so we
+         * can fetch them. This will gum up the server if we leave it enabled
+         * but we'll be quick so it'll be OK (TM).
+         */
+        ReentrantLock taskFinishLock = new ReentrantLock();
+        taskFinishLock.lock();
+        for (ClusterService clusterService : internalCluster().getInstances(ClusterService.class)) {
+            ((MockTaskManager)clusterService.getTaskManager()).addListener(new MockTaskManagerListener() {
+                @Override
+                public void onTaskRegistered(Task task) {
+                    // Intentional noop
+                }
+
+                @Override
+                public void onTaskUnregistered(Task task) {
+                    /*
+                     * We can't block all tasks here or the task listing task
+                     * would never return.
+                     */
+                    if (false == task.getAction().startsWith(IndexAction.NAME)) {
+                        return;
+                    }
+                    logger.debug("Blocking {} from being unregistered", task);
+                    taskFinishLock.lock();
+                    taskFinishLock.unlock();
+                }
+            });
+        }
+        ListenableActionFuture<?> indexFuture = client().prepareIndex("test", "test").setSource("test", "test").execute();
+        ListTasksResponse tasks = client().admin().cluster().prepareListTasks().setActions("indices:data/write/index*").setDetailed(true)
+                .get();
+        taskFinishLock.unlock();
+        indexFuture.get();
+        assertThat(tasks.getTasks(), not(emptyCollectionOf(TaskInfo.class)));
+        for (TaskInfo task : tasks.getTasks()) {
+            assertNotNull(task.getStatus());
+        }
+    }
+
+
     @Override
     @Override
     public void tearDown() throws Exception {
     public void tearDown() throws Exception {
         for (Map.Entry<Tuple<String, String>, RecordingTaskManagerListener> entry : listeners.entrySet()) {
         for (Map.Entry<Tuple<String, String>, RecordingTaskManagerListener> entry : listeners.entrySet()) {

+ 1 - 2
core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java

@@ -58,7 +58,6 @@ import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.local.LocalTransport;
 import org.elasticsearch.transport.local.LocalTransport;
 import org.junit.After;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.AfterClass;
-import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.BeforeClass;
 
 
 import java.io.IOException;
 import java.io.IOException;
@@ -115,7 +114,7 @@ public class TransportTasksActionTests extends ESTestCase {
         public TestNode(String name, ThreadPool threadPool, Settings settings) {
         public TestNode(String name, ThreadPool threadPool, Settings settings) {
             transportService = new TransportService(settings,
             transportService = new TransportService(settings,
                 new LocalTransport(settings, threadPool, Version.CURRENT, new NamedWriteableRegistry()),
                 new LocalTransport(settings, threadPool, Version.CURRENT, new NamedWriteableRegistry()),
-                threadPool){
+                threadPool, new NamedWriteableRegistry()) {
                 @Override
                 @Override
                 protected TaskManager createTaskManager() {
                 protected TaskManager createTaskManager() {
                     if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) {
                     if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) {

+ 107 - 39
core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

@@ -18,6 +18,8 @@
  */
  */
 package org.elasticsearch.action.support.replication;
 package org.elasticsearch.action.support.replication;
 
 
+import com.carrotsearch.randomizedtesting.annotations.Repeat;
+
 import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.CorruptIndexException;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ReplicationResponse;
 import org.elasticsearch.action.ReplicationResponse;
@@ -44,10 +46,10 @@ import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
 import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.shard.IndexShardNotStartedException;
 import org.elasticsearch.index.shard.IndexShardNotStartedException;
@@ -64,6 +66,7 @@ import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportResponseOptions;
 import org.elasticsearch.transport.TransportResponseOptions;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.TransportService;
+import org.hamcrest.Matcher;
 import org.junit.AfterClass;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.BeforeClass;
@@ -86,6 +89,7 @@ import static org.elasticsearch.action.support.replication.ClusterStateCreationU
 import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithActivePrimary;
 import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithActivePrimary;
 import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.arrayWithSize;
+import static org.hamcrest.Matchers.either;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.hasItem;
@@ -142,27 +146,30 @@ public class TransportReplicationActionTests extends ESTestCase {
     public void testBlocks() throws ExecutionException, InterruptedException {
     public void testBlocks() throws ExecutionException, InterruptedException {
         Request request = new Request();
         Request request = new Request();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ReplicationTask task = maybeTask();
 
 
         ClusterBlocks.Builder block = ClusterBlocks.builder()
         ClusterBlocks.Builder block = ClusterBlocks.builder()
                 .addGlobalBlock(new ClusterBlock(1, "non retryable", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
                 .addGlobalBlock(new ClusterBlock(1, "non retryable", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
         clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block));
         clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block));
-        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener);
+        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
         reroutePhase.run();
         reroutePhase.run();
         assertListenerThrows("primary phase should fail operation", listener, ClusterBlockException.class);
         assertListenerThrows("primary phase should fail operation", listener, ClusterBlockException.class);
+        assertPhase(task, "failed");
 
 
         block = ClusterBlocks.builder()
         block = ClusterBlocks.builder()
                 .addGlobalBlock(new ClusterBlock(1, "retryable", true, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
                 .addGlobalBlock(new ClusterBlock(1, "retryable", true, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
         clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block));
         clusterService.setState(ClusterState.builder(clusterService.state()).blocks(block));
         listener = new PlainActionFuture<>();
         listener = new PlainActionFuture<>();
-        reroutePhase = action.new ReroutePhase(null, new Request().timeout("5ms"), listener);
+        reroutePhase = action.new ReroutePhase(task, new Request().timeout("5ms"), listener);
         reroutePhase.run();
         reroutePhase.run();
         assertListenerThrows("failed to timeout on retryable block", listener, ClusterBlockException.class);
         assertListenerThrows("failed to timeout on retryable block", listener, ClusterBlockException.class);
-
+        assertPhase(task, "failed");
 
 
         listener = new PlainActionFuture<>();
         listener = new PlainActionFuture<>();
-        reroutePhase = action.new ReroutePhase(null, new Request(), listener);
+        reroutePhase = action.new ReroutePhase(task, new Request(), listener);
         reroutePhase.run();
         reroutePhase.run();
         assertFalse("primary phase should wait on retryable block", listener.isDone());
         assertFalse("primary phase should wait on retryable block", listener.isDone());
+        assertPhase(task, "waiting_for_retry");
 
 
         block = ClusterBlocks.builder()
         block = ClusterBlocks.builder()
                 .addGlobalBlock(new ClusterBlock(1, "non retryable", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
                 .addGlobalBlock(new ClusterBlock(1, "non retryable", false, true, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
@@ -181,20 +188,23 @@ public class TransportReplicationActionTests extends ESTestCase {
         // no replicas in oder to skip the replication part
         // no replicas in oder to skip the replication part
         clusterService.setState(state(index, true,
         clusterService.setState(state(index, true,
                 randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.UNASSIGNED));
                 randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.UNASSIGNED));
+        ReplicationTask task = maybeTask();
 
 
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
 
 
         Request request = new Request(shardId).timeout("1ms");
         Request request = new Request(shardId).timeout("1ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener);
+        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
         reroutePhase.run();
         reroutePhase.run();
         assertListenerThrows("unassigned primary didn't cause a timeout", listener, UnavailableShardsException.class);
         assertListenerThrows("unassigned primary didn't cause a timeout", listener, UnavailableShardsException.class);
+        assertPhase(task, "failed");
 
 
         request = new Request(shardId);
         request = new Request(shardId);
         listener = new PlainActionFuture<>();
         listener = new PlainActionFuture<>();
-        reroutePhase = action.new ReroutePhase(null, request, listener);
+        reroutePhase = action.new ReroutePhase(task, request, listener);
         reroutePhase.run();
         reroutePhase.run();
         assertFalse("unassigned primary didn't cause a retry", listener.isDone());
         assertFalse("unassigned primary didn't cause a retry", listener.isDone());
+        assertPhase(task, "waiting_for_retry");
 
 
         clusterService.setState(state(index, true, ShardRoutingState.STARTED));
         clusterService.setState(state(index, true, ShardRoutingState.STARTED));
         logger.debug("--> primary assigned state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> primary assigned state:\n{}", clusterService.state().prettyPrint());
@@ -267,9 +277,12 @@ public class TransportReplicationActionTests extends ESTestCase {
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         Request request = new Request(new ShardId("unknown_index", "_na_", 0)).timeout("1ms");
         Request request = new Request(new ShardId("unknown_index", "_na_", 0)).timeout("1ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener);
+        ReplicationTask task = maybeTask();
+
+        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
         reroutePhase.run();
         reroutePhase.run();
         assertListenerThrows("must throw index not found exception", listener, IndexNotFoundException.class);
         assertListenerThrows("must throw index not found exception", listener, IndexNotFoundException.class);
+        assertPhase(task, "failed");
         request = new Request(new ShardId(index, "_na_", 10)).timeout("1ms");
         request = new Request(new ShardId(index, "_na_", 10)).timeout("1ms");
         listener = new PlainActionFuture<>();
         listener = new PlainActionFuture<>();
         reroutePhase = action.new ReroutePhase(null, request, listener);
         reroutePhase = action.new ReroutePhase(null, request, listener);
@@ -280,9 +293,9 @@ public class TransportReplicationActionTests extends ESTestCase {
     public void testRoutePhaseExecutesRequest() {
     public void testRoutePhaseExecutesRequest() {
         final String index = "test";
         final String index = "test";
         final ShardId shardId = new ShardId(index, "_na_", 0);
         final ShardId shardId = new ShardId(index, "_na_", 0);
+        ReplicationTask task = maybeTask();
 
 
         clusterService.setState(stateWithActivePrimary(index, randomBoolean(), 3));
         clusterService.setState(stateWithActivePrimary(index, randomBoolean(), 3));
-
         logger.debug("using state: \n{}", clusterService.state().prettyPrint());
         logger.debug("using state: \n{}", clusterService.state().prettyPrint());
 
 
         final IndexShardRoutingTable shardRoutingTable = clusterService.state().routingTable().index(index).shard(shardId.id());
         final IndexShardRoutingTable shardRoutingTable = clusterService.state().routingTable().index(index).shard(shardId.id());
@@ -290,7 +303,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         Request request = new Request(shardId);
         Request request = new Request(shardId);
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
 
 
-        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(null, request, listener);
+        TransportReplicationAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
         reroutePhase.run();
         reroutePhase.run();
         assertThat(request.shardId(), equalTo(shardId));
         assertThat(request.shardId(), equalTo(shardId));
         logger.info("--> primary is assigned to [{}], checking request forwarded", primaryNodeId);
         logger.info("--> primary is assigned to [{}], checking request forwarded", primaryNodeId);
@@ -299,8 +312,10 @@ public class TransportReplicationActionTests extends ESTestCase {
         assertThat(capturedRequests.size(), equalTo(1));
         assertThat(capturedRequests.size(), equalTo(1));
         if (clusterService.state().nodes().localNodeId().equals(primaryNodeId)) {
         if (clusterService.state().nodes().localNodeId().equals(primaryNodeId)) {
             assertThat(capturedRequests.get(0).action, equalTo("testAction[p]"));
             assertThat(capturedRequests.get(0).action, equalTo("testAction[p]"));
+            assertPhase(task, "waiting_on_primary");
         } else {
         } else {
             assertThat(capturedRequests.get(0).action, equalTo("testAction"));
             assertThat(capturedRequests.get(0).action, equalTo("testAction"));
+            assertPhase(task, "rerouted");
         }
         }
         assertIndexShardUninitialized();
         assertIndexShardUninitialized();
     }
     }
@@ -312,8 +327,9 @@ public class TransportReplicationActionTests extends ESTestCase {
         clusterService.setState(state);
         clusterService.setState(state);
         Request request = new Request(shardId).timeout("1ms");
         Request request = new Request(shardId).timeout("1ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ReplicationTask task = maybeTask();
         AtomicBoolean movedToReplication = new AtomicBoolean();
         AtomicBoolean movedToReplication = new AtomicBoolean();
-        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener)) {
+        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener)) {
             @Override
             @Override
             void finishAndMoveToReplication(TransportReplicationAction.ReplicationPhase replicationPhase) {
             void finishAndMoveToReplication(TransportReplicationAction.ReplicationPhase replicationPhase) {
                 super.finishAndMoveToReplication(replicationPhase);
                 super.finishAndMoveToReplication(replicationPhase);
@@ -335,6 +351,9 @@ public class TransportReplicationActionTests extends ESTestCase {
             assertThat(requests, notNullValue());
             assertThat(requests, notNullValue());
             assertThat(requests.size(), equalTo(1));
             assertThat(requests.size(), equalTo(1));
             assertThat("primary request was not delegated to relocation target", requests.get(0).action, equalTo("testAction[p]"));
             assertThat("primary request was not delegated to relocation target", requests.get(0).action, equalTo("testAction[p]"));
+            assertPhase(task, "primary");
+        } else {
+            assertPhase(task, either(equalTo("finished")).or(equalTo("replicating")));
         }
         }
     }
     }
 
 
@@ -348,8 +367,9 @@ public class TransportReplicationActionTests extends ESTestCase {
         clusterService.setState(state);
         clusterService.setState(state);
         Request request = new Request(shardId).timeout("1ms");
         Request request = new Request(shardId).timeout("1ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ReplicationTask task = maybeTask();
         AtomicBoolean movedToReplication = new AtomicBoolean();
         AtomicBoolean movedToReplication = new AtomicBoolean();
-        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener)) {
+        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener)) {
             @Override
             @Override
             void finishAndMoveToReplication(TransportReplicationAction.ReplicationPhase replicationPhase) {
             void finishAndMoveToReplication(TransportReplicationAction.ReplicationPhase replicationPhase) {
                 super.finishAndMoveToReplication(replicationPhase);
                 super.finishAndMoveToReplication(replicationPhase);
@@ -359,6 +379,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         primaryPhase.run();
         primaryPhase.run();
         assertThat("request was not processed on primary relocation target", request.processedOnPrimary.get(), equalTo(true));
         assertThat("request was not processed on primary relocation target", request.processedOnPrimary.get(), equalTo(true));
         assertThat(movedToReplication.get(), equalTo(true));
         assertThat(movedToReplication.get(), equalTo(true));
+        assertPhase(task, "replicating");
     }
     }
 
 
     public void testAddedReplicaAfterPrimaryOperation() {
     public void testAddedReplicaAfterPrimaryOperation() {
@@ -368,6 +389,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         clusterService.setState(stateWithActivePrimary(index, true, 0));
         clusterService.setState(stateWithActivePrimary(index, true, 0));
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         final ClusterState stateWithAddedReplicas = state(index, true, ShardRoutingState.STARTED, randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.STARTED);
         final ClusterState stateWithAddedReplicas = state(index, true, ShardRoutingState.STARTED, randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.STARTED);
+        ReplicationTask task = maybeTask();
 
 
         final Action actionWithAddedReplicaAfterPrimaryOp = new Action(Settings.EMPTY, "testAction", transportService, clusterService, threadPool) {
         final Action actionWithAddedReplicaAfterPrimaryOp = new Action(Settings.EMPTY, "testAction", transportService, clusterService, threadPool) {
             @Override
             @Override
@@ -382,9 +404,10 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
         Request request = new Request(shardId);
         Request request = new Request(shardId);
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction<Request, Request, Response>.PrimaryPhase primaryPhase = actionWithAddedReplicaAfterPrimaryOp.new PrimaryPhase(request, createTransportChannel(listener));
+        TransportReplicationAction<Request, Request, Response>.PrimaryPhase primaryPhase = actionWithAddedReplicaAfterPrimaryOp.new PrimaryPhase(task, request, createTransportChannel(listener));
         primaryPhase.run();
         primaryPhase.run();
         assertThat("request was not processed on primary", request.processedOnPrimary.get(), equalTo(true));
         assertThat("request was not processed on primary", request.processedOnPrimary.get(), equalTo(true));
+        assertPhase(task, "replicating");
         Map<String, List<CapturingTransport.CapturedRequest>> capturedRequestsByTargetNode = transport.getCapturedRequestsByTargetNodeAndClear();
         Map<String, List<CapturingTransport.CapturedRequest>> capturedRequestsByTargetNode = transport.getCapturedRequestsByTargetNodeAndClear();
         for (ShardRouting replica : stateWithAddedReplicas.getRoutingTable().shardRoutingTable(index, shardId.id()).replicaShards()) {
         for (ShardRouting replica : stateWithAddedReplicas.getRoutingTable().shardRoutingTable(index, shardId.id()).replicaShards()) {
             List<CapturingTransport.CapturedRequest> requests = capturedRequestsByTargetNode.get(replica.currentNodeId());
             List<CapturingTransport.CapturedRequest> requests = capturedRequestsByTargetNode.get(replica.currentNodeId());
@@ -415,11 +438,14 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
         Request request = new Request(shardId);
         Request request = new Request(shardId);
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction<Request, Request, Response>.PrimaryPhase primaryPhase = actionWithRelocatingReplicasAfterPrimaryOp.new PrimaryPhase(request, createTransportChannel(listener));
+        ReplicationTask task = maybeTask();
+        TransportReplicationAction<Request, Request, Response>.PrimaryPhase primaryPhase = actionWithRelocatingReplicasAfterPrimaryOp.new PrimaryPhase(
+                task, request, createTransportChannel(listener));
         primaryPhase.run();
         primaryPhase.run();
         assertThat("request was not processed on primary", request.processedOnPrimary.get(), equalTo(true));
         assertThat("request was not processed on primary", request.processedOnPrimary.get(), equalTo(true));
         ShardRouting relocatingReplicaShard = stateWithRelocatingReplica.getRoutingTable().shardRoutingTable(index, shardId.id()).replicaShards().get(0);
         ShardRouting relocatingReplicaShard = stateWithRelocatingReplica.getRoutingTable().shardRoutingTable(index, shardId.id()).replicaShards().get(0);
         Map<String, List<CapturingTransport.CapturedRequest>> capturedRequestsByTargetNode = transport.getCapturedRequestsByTargetNodeAndClear();
         Map<String, List<CapturingTransport.CapturedRequest>> capturedRequestsByTargetNode = transport.getCapturedRequestsByTargetNodeAndClear();
+        assertPhase(task, "replicating");
         for (String node : new String[] {relocatingReplicaShard.currentNodeId(), relocatingReplicaShard.relocatingNodeId()}) {
         for (String node : new String[] {relocatingReplicaShard.currentNodeId(), relocatingReplicaShard.relocatingNodeId()}) {
             List<CapturingTransport.CapturedRequest> requests = capturedRequestsByTargetNode.get(node);
             List<CapturingTransport.CapturedRequest> requests = capturedRequestsByTargetNode.get(node);
             assertThat(requests, notNullValue());
             assertThat(requests, notNullValue());
@@ -448,10 +474,13 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
         Request request = new Request(shardId);
         Request request = new Request(shardId);
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction<Request, Request, Response>.PrimaryPhase primaryPhase = actionWithDeletedIndexAfterPrimaryOp.new PrimaryPhase(request, createTransportChannel(listener));
+        ReplicationTask task = maybeTask();
+        TransportReplicationAction<Request, Request, Response>.PrimaryPhase primaryPhase = actionWithDeletedIndexAfterPrimaryOp.new PrimaryPhase(
+                task, request, createTransportChannel(listener));
         primaryPhase.run();
         primaryPhase.run();
         assertThat("request was not processed on primary", request.processedOnPrimary.get(), equalTo(true));
         assertThat("request was not processed on primary", request.processedOnPrimary.get(), equalTo(true));
         assertThat("replication phase should be skipped if index gets deleted after primary operation", transport.capturedRequestsByTargetNode().size(), equalTo(0));
         assertThat("replication phase should be skipped if index gets deleted after primary operation", transport.capturedRequestsByTargetNode().size(), equalTo(0));
+        assertPhase(task, "finished");
     }
     }
 
 
     public void testWriteConsistency() throws ExecutionException, InterruptedException {
     public void testWriteConsistency() throws ExecutionException, InterruptedException {
@@ -496,16 +525,18 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
         final IndexShardRoutingTable shardRoutingTable = clusterService.state().routingTable().index(index).shard(shardId.id());
         final IndexShardRoutingTable shardRoutingTable = clusterService.state().routingTable().index(index).shard(shardId.id());
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener));
+        ReplicationTask task = maybeTask();
+        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener));
         if (passesWriteConsistency) {
         if (passesWriteConsistency) {
             assertThat(primaryPhase.checkWriteConsistency(shardRoutingTable.primaryShard().shardId()), nullValue());
             assertThat(primaryPhase.checkWriteConsistency(shardRoutingTable.primaryShard().shardId()), nullValue());
             primaryPhase.run();
             primaryPhase.run();
-            assertTrue("operations should have been perform, consistency level is met", request.processedOnPrimary.get());
+            assertTrue("operations should have been performed, consistency level is met", request.processedOnPrimary.get());
             if (assignedReplicas > 0) {
             if (assignedReplicas > 0) {
                 assertIndexShardCounter(2);
                 assertIndexShardCounter(2);
             } else {
             } else {
                 assertIndexShardCounter(1);
                 assertIndexShardCounter(1);
             }
             }
+            assertPhase(task, either(equalTo("finished")).or(equalTo("replicating")));
         } else {
         } else {
             assertThat(primaryPhase.checkWriteConsistency(shardRoutingTable.primaryShard().shardId()), notNullValue());
             assertThat(primaryPhase.checkWriteConsistency(shardRoutingTable.primaryShard().shardId()), notNullValue());
             primaryPhase.run();
             primaryPhase.run();
@@ -517,10 +548,11 @@ public class TransportReplicationActionTests extends ESTestCase {
             }
             }
             clusterService.setState(state(index, true, ShardRoutingState.STARTED, replicaStates));
             clusterService.setState(state(index, true, ShardRoutingState.STARTED, replicaStates));
             listener = new PlainActionFuture<>();
             listener = new PlainActionFuture<>();
-            primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener));
+            primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener));
             primaryPhase.run();
             primaryPhase.run();
             assertTrue("once the consistency level met, operation should continue", request.processedOnPrimary.get());
             assertTrue("once the consistency level met, operation should continue", request.processedOnPrimary.get());
             assertIndexShardCounter(2);
             assertIndexShardCounter(2);
+            assertPhase(task, "replicating");
         }
         }
     }
     }
 
 
@@ -590,6 +622,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         final ShardId shardId = shardIt.shardId();
         final ShardId shardId = shardIt.shardId();
         final Request request = new Request(shardId);
         final Request request = new Request(shardId);
         final PlainActionFuture<Response> listener = new PlainActionFuture<>();
         final PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ReplicationTask task = maybeTask();
         logger.debug("expecting [{}] assigned replicas, [{}] total shards. using state: \n{}", assignedReplicas, totalShards, clusterService.state().prettyPrint());
         logger.debug("expecting [{}] assigned replicas, [{}] total shards. using state: \n{}", assignedReplicas, totalShards, clusterService.state().prettyPrint());
 
 
         TransportReplicationAction.IndexShardReference reference = getOrCreateIndexShardOperationsCounter();
         TransportReplicationAction.IndexShardReference reference = getOrCreateIndexShardOperationsCounter();
@@ -599,15 +632,14 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
         assertIndexShardCounter(2);
         assertIndexShardCounter(2);
         // TODO: set a default timeout
         // TODO: set a default timeout
-        TransportReplicationAction<Request, Request, Response>.ReplicationPhase replicationPhase =
-                action.new ReplicationPhase(request,
-                        new Response(),
-                        request.shardId(), createTransportChannel(listener), reference);
+        TransportReplicationAction<Request, Request, Response>.ReplicationPhase replicationPhase = action.new ReplicationPhase(task,
+                request, new Response(), request.shardId(), createTransportChannel(listener), reference);
 
 
         assertThat(replicationPhase.totalShards(), equalTo(totalShards));
         assertThat(replicationPhase.totalShards(), equalTo(totalShards));
         assertThat(replicationPhase.pending(), equalTo(assignedReplicas));
         assertThat(replicationPhase.pending(), equalTo(assignedReplicas));
         replicationPhase.run();
         replicationPhase.run();
         final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
         final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
+        assertPhase(task, either(equalTo("finished")).or(equalTo("replicating")));
 
 
         HashMap<String, Request> nodesSentTo = new HashMap<>();
         HashMap<String, Request> nodesSentTo = new HashMap<>();
         boolean executeOnReplica =
         boolean executeOnReplica =
@@ -718,11 +750,11 @@ public class TransportReplicationActionTests extends ESTestCase {
         final String index = "test";
         final String index = "test";
         final ShardId shardId = new ShardId(index, "_na_", 0);
         final ShardId shardId = new ShardId(index, "_na_", 0);
         // no replica, we only want to test on primary
         // no replica, we only want to test on primary
-        clusterService.setState(state(index, true,
-                ShardRoutingState.STARTED));
+        clusterService.setState(state(index, true, ShardRoutingState.STARTED));
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         Request request = new Request(shardId).timeout("100ms");
         Request request = new Request(shardId).timeout("100ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ReplicationTask task = maybeTask();
 
 
         /**
         /**
          * Execute an action that is stuck in shard operation until a latch is counted down.
          * Execute an action that is stuck in shard operation until a latch is counted down.
@@ -732,7 +764,7 @@ public class TransportReplicationActionTests extends ESTestCase {
          * However, this failure would only become apparent once listener.get is called. Seems a little implicit.
          * However, this failure would only become apparent once listener.get is called. Seems a little implicit.
          * */
          * */
         action = new ActionWithDelay(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool);
         action = new ActionWithDelay(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool);
-        final TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener));
+        final TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener));
         Thread t = new Thread() {
         Thread t = new Thread() {
             @Override
             @Override
             public void run() {
             public void run() {
@@ -751,6 +783,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         // operation finished, counter back to 0
         // operation finished, counter back to 0
         assertIndexShardCounter(1);
         assertIndexShardCounter(1);
         assertThat(transport.capturedRequests().length, equalTo(0));
         assertThat(transport.capturedRequests().length, equalTo(0));
+        assertPhase(task, "finished");
     }
     }
 
 
     public void testCounterIncrementedWhileReplicationOngoing() throws InterruptedException, ExecutionException, IOException {
     public void testCounterIncrementedWhileReplicationOngoing() throws InterruptedException, ExecutionException, IOException {
@@ -764,7 +797,9 @@ public class TransportReplicationActionTests extends ESTestCase {
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         Request request = new Request(shardId).timeout("100ms");
         Request request = new Request(shardId).timeout("100ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener));
+        ReplicationTask task = maybeTask();
+
+        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener));
         primaryPhase.run();
         primaryPhase.run();
         assertIndexShardCounter(2);
         assertIndexShardCounter(2);
         assertThat(transport.capturedRequests().length, equalTo(1));
         assertThat(transport.capturedRequests().length, equalTo(1));
@@ -772,10 +807,14 @@ public class TransportReplicationActionTests extends ESTestCase {
         transport.handleResponse(transport.capturedRequests()[0].requestId, TransportResponse.Empty.INSTANCE);
         transport.handleResponse(transport.capturedRequests()[0].requestId, TransportResponse.Empty.INSTANCE);
         transport.clear();
         transport.clear();
         assertIndexShardCounter(1);
         assertIndexShardCounter(1);
+        assertPhase(task, "finished");
+
         request = new Request(shardId).timeout("100ms");
         request = new Request(shardId).timeout("100ms");
-        primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener));
+        task = maybeTask();
+        primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener));
         primaryPhase.run();
         primaryPhase.run();
         assertIndexShardCounter(2);
         assertIndexShardCounter(2);
+        assertPhase(task, "replicating");
         CapturingTransport.CapturedRequest[] replicationRequests = transport.getCapturedRequestsAndClear();
         CapturingTransport.CapturedRequest[] replicationRequests = transport.getCapturedRequestsAndClear();
         assertThat(replicationRequests.length, equalTo(1));
         assertThat(replicationRequests.length, equalTo(1));
         // try with failure response
         // try with failure response
@@ -792,12 +831,14 @@ public class TransportReplicationActionTests extends ESTestCase {
                 ShardRoutingState.STARTED, ShardRoutingState.STARTED));
                 ShardRoutingState.STARTED, ShardRoutingState.STARTED));
         action = new ActionWithDelay(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool);
         action = new ActionWithDelay(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool);
         final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler();
         final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler();
+        final ReplicationTask task = maybeTask();
         Thread t = new Thread() {
         Thread t = new Thread() {
             @Override
             @Override
             public void run() {
             public void run() {
                 try {
                 try {
-                    replicaOperationTransportHandler.messageReceived(new Request().setShardId(shardId), createTransportChannel(new PlainActionFuture<>()));
+                    replicaOperationTransportHandler.messageReceived(new Request().setShardId(shardId), createTransportChannel(new PlainActionFuture<>()), task);
                 } catch (Exception e) {
                 } catch (Exception e) {
+                    logger.error("Failed", e);
                 }
                 }
             }
             }
         };
         };
@@ -807,13 +848,14 @@ public class TransportReplicationActionTests extends ESTestCase {
         assertBusy(() -> assertIndexShardCounter(2));
         assertBusy(() -> assertIndexShardCounter(2));
         ((ActionWithDelay) action).countDownLatch.countDown();
         ((ActionWithDelay) action).countDownLatch.countDown();
         t.join();
         t.join();
+        assertPhase(task, "finished");
         // operation should have finished and counter decreased because no outstanding replica requests
         // operation should have finished and counter decreased because no outstanding replica requests
         assertIndexShardCounter(1);
         assertIndexShardCounter(1);
         // now check if this also works if operation throws exception
         // now check if this also works if operation throws exception
         action = new ActionWithExceptions(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool);
         action = new ActionWithExceptions(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool);
         final Action.ReplicaOperationTransportHandler replicaOperationTransportHandlerForException = action.new ReplicaOperationTransportHandler();
         final Action.ReplicaOperationTransportHandler replicaOperationTransportHandlerForException = action.new ReplicaOperationTransportHandler();
         try {
         try {
-            replicaOperationTransportHandlerForException.messageReceived(new Request(shardId), createTransportChannel(new PlainActionFuture<>()));
+            replicaOperationTransportHandlerForException.messageReceived(new Request(shardId), createTransportChannel(new PlainActionFuture<>()), task);
             fail();
             fail();
         } catch (Throwable t2) {
         } catch (Throwable t2) {
         }
         }
@@ -829,12 +871,15 @@ public class TransportReplicationActionTests extends ESTestCase {
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
         Request request = new Request(shardId).timeout("100ms");
         Request request = new Request(shardId).timeout("100ms");
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(request, createTransportChannel(listener));
+        ReplicationTask task = maybeTask();
+
+        TransportReplicationAction.PrimaryPhase primaryPhase = action.new PrimaryPhase(task, request, createTransportChannel(listener));
         primaryPhase.run();
         primaryPhase.run();
         // no replica request should have been sent yet
         // no replica request should have been sent yet
         assertThat(transport.capturedRequests().length, equalTo(0));
         assertThat(transport.capturedRequests().length, equalTo(0));
         // no matter if the operation is retried or not, counter must be be back to 1
         // no matter if the operation is retried or not, counter must be be back to 1
         assertIndexShardCounter(1);
         assertIndexShardCounter(1);
+        assertPhase(task, "failed");
     }
     }
 
 
     private void assertIndexShardCounter(int expected) {
     private void assertIndexShardCounter(int expected) {
@@ -847,9 +892,9 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
     private final AtomicReference<ShardRouting> indexShardRouting = new AtomicReference<>();
     private final AtomicReference<ShardRouting> indexShardRouting = new AtomicReference<>();
 
 
-    /*
-    * Returns testIndexShardOperationsCounter or initializes it if it was already created in this test run.
-    * */
+    /**
+     * Returns testIndexShardOperationsCounter or initializes it if it was already created in this test run.
+     */
     private synchronized TransportReplicationAction.IndexShardReference getOrCreateIndexShardOperationsCounter() {
     private synchronized TransportReplicationAction.IndexShardReference getOrCreateIndexShardOperationsCounter() {
         count.incrementAndGet();
         count.incrementAndGet();
         return new TransportReplicationAction.IndexShardReference() {
         return new TransportReplicationAction.IndexShardReference() {
@@ -872,6 +917,29 @@ public class TransportReplicationActionTests extends ESTestCase {
         };
         };
     }
     }
 
 
+    /**
+     * Sometimes build a ReplicationTask for tracking the phase of the
+     * TransportReplicationAction. Since TransportReplicationAction has to work
+     * if the task as null just as well as if it is supplied this returns null
+     * half the time.
+     */
+    private ReplicationTask maybeTask() {
+        return random().nextBoolean() ? new ReplicationTask(0, null, null, null, null, 0) : null;
+    }
+
+    /**
+     * If the task is non-null this asserts that the phrase matches.
+     */
+    private void assertPhase(@Nullable ReplicationTask task, String phase) {
+        assertPhase(task, equalTo(phase));
+    }
+
+    private void assertPhase(@Nullable ReplicationTask task, Matcher<String> phaseMatcher) {
+        if (task != null) {
+            assertThat(task.getPhase(), phaseMatcher);
+        }
+    }
+
     public static class Request extends ReplicationRequest<Request> {
     public static class Request extends ReplicationRequest<Request> {
         public AtomicBoolean processedOnPrimary = new AtomicBoolean();
         public AtomicBoolean processedOnPrimary = new AtomicBoolean();
         public AtomicInteger processedOnReplicas = new AtomicInteger();
         public AtomicInteger processedOnReplicas = new AtomicInteger();
@@ -959,9 +1027,9 @@ public class TransportReplicationActionTests extends ESTestCase {
         }
         }
     }
     }
 
 
-    /*
-    * Throws exceptions when executed. Used for testing if the counter is correctly decremented in case an operation fails.
-    * */
+    /**
+     * Throws exceptions when executed. Used for testing if the counter is correctly decremented in case an operation fails.
+     */
     class ActionWithExceptions extends Action {
     class ActionWithExceptions extends Action {
 
 
         ActionWithExceptions(Settings settings, String actionName, TransportService transportService, ClusterService clusterService, ThreadPool threadPool) throws IOException {
         ActionWithExceptions(Settings settings, String actionName, TransportService transportService, ClusterService clusterService, ThreadPool threadPool) throws IOException {
@@ -1027,9 +1095,9 @@ public class TransportReplicationActionTests extends ESTestCase {
 
 
     }
     }
 
 
-    /*
-    * Transport channel that is needed for replica operation testing.
-    * */
+    /**
+     * Transport channel that is needed for replica operation testing.
+     */
     public TransportChannel createTransportChannel(final PlainActionFuture<Response> listener) {
     public TransportChannel createTransportChannel(final PlainActionFuture<Response> listener) {
         return new TransportChannel() {
         return new TransportChannel() {
 
 

+ 3 - 3
core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java

@@ -32,13 +32,13 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.LocalTransportAddress;
 import org.elasticsearch.common.transport.LocalTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.Transport;
@@ -128,8 +128,8 @@ public class TransportClientHeadersTests extends AbstractClientHeadersTestCase {
         CountDownLatch clusterStateLatch = new CountDownLatch(1);
         CountDownLatch clusterStateLatch = new CountDownLatch(1);
 
 
         @Inject
         @Inject
-        public InternalTransportService(Settings settings, Transport transport, ThreadPool threadPool) {
-            super(settings, transport, threadPool);
+        public InternalTransportService(Settings settings, Transport transport, ThreadPool threadPool, NamedWriteableRegistry namedWriteableRegistry) {
+            super(settings, transport, threadPool, namedWriteableRegistry);
         }
         }
 
 
         @Override @SuppressWarnings("unchecked")
         @Override @SuppressWarnings("unchecked")

+ 2 - 2
core/src/test/java/org/elasticsearch/client/transport/TransportClientNodesServiceTests.java

@@ -23,9 +23,9 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.LocalTransportAddress;
 import org.elasticsearch.common.transport.LocalTransportAddress;
-import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.BaseTransportResponseHandler;
 import org.elasticsearch.transport.BaseTransportResponseHandler;
@@ -71,7 +71,7 @@ public class TransportClientNodesServiceTests extends ESTestCase {
                     return  new TestResponse();
                     return  new TestResponse();
                 }
                 }
             };
             };
-            transportService = new TransportService(Settings.EMPTY, transport, threadPool);
+            transportService = new TransportService(Settings.EMPTY, transport, threadPool, new NamedWriteableRegistry());
             transportService.start();
             transportService.start();
             transportClientNodesService = new TransportClientNodesService(Settings.EMPTY, ClusterName.DEFAULT, transportService, threadPool, Version.CURRENT);
             transportClientNodesService = new TransportClientNodesService(Settings.EMPTY, ClusterName.DEFAULT, transportService, threadPool, Version.CURRENT);
 
 

+ 3 - 1
core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java

@@ -104,7 +104,9 @@ public class ZenFaultDetectionTests extends ESTestCase {
     }
     }
 
 
     protected MockTransportService build(Settings settings, Version version) {
     protected MockTransportService build(Settings settings, Version version) {
-        MockTransportService transportService = new MockTransportService(Settings.EMPTY, new LocalTransport(settings, threadPool, version, new NamedWriteableRegistry()), threadPool);
+        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry();
+        MockTransportService transportService = new MockTransportService(Settings.EMPTY,
+                new LocalTransport(settings, threadPool, version, namedWriteableRegistry), threadPool, namedWriteableRegistry);
         transportService.start();
         transportService.start();
         return transportService;
         return transportService;
     }
     }

+ 1 - 3
core/src/test/java/org/elasticsearch/discovery/zen/publish/PublishClusterStateActionTests.java

@@ -35,7 +35,6 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.collect.Tuple;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.logging.ESLogger;
 import org.elasticsearch.common.logging.ESLogger;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -55,7 +54,6 @@ import org.elasticsearch.transport.TransportConnectionListener;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportResponseOptions;
 import org.elasticsearch.transport.TransportResponseOptions;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.TransportService;
-import org.elasticsearch.transport.local.LocalTransport;
 import org.junit.After;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
 
 
@@ -232,7 +230,7 @@ public class PublishClusterStateActionTests extends ESTestCase {
     }
     }
 
 
     protected MockTransportService buildTransportService(Settings settings, Version version) {
     protected MockTransportService buildTransportService(Settings settings, Version version) {
-        MockTransportService transportService = new MockTransportService(settings, new LocalTransport(settings, threadPool, version, new NamedWriteableRegistry()), threadPool);
+        MockTransportService transportService = MockTransportService.local(Settings.EMPTY, version, threadPool);
         transportService.start();
         transportService.start();
         return transportService;
         return transportService;
     }
     }

+ 2 - 2
core/src/test/java/org/elasticsearch/transport/TransportModuleTests.java

@@ -41,8 +41,8 @@ public class TransportModuleTests extends ModuleTestCase {
 
 
     static class FakeTransportService extends TransportService {
     static class FakeTransportService extends TransportService {
         @Inject
         @Inject
-        public FakeTransportService(Settings settings, Transport transport, ThreadPool threadPool) {
-            super(settings, transport, threadPool);
+        public FakeTransportService(Settings settings, Transport transport, ThreadPool threadPool, NamedWriteableRegistry namedWriteableRegistry) {
+            super(settings, transport, threadPool, namedWriteableRegistry);
         }
         }
     }
     }
 }
 }

+ 1 - 1
core/src/test/java/org/elasticsearch/transport/local/SimpleLocalTransportTests.java

@@ -29,7 +29,7 @@ public class SimpleLocalTransportTests extends AbstractSimpleTransportTestCase {
 
 
     @Override
     @Override
     protected MockTransportService build(Settings settings, Version version, NamedWriteableRegistry namedWriteableRegistry) {
     protected MockTransportService build(Settings settings, Version version, NamedWriteableRegistry namedWriteableRegistry) {
-        MockTransportService transportService = new MockTransportService(Settings.EMPTY, new LocalTransport(settings, threadPool, version, namedWriteableRegistry), threadPool);
+        MockTransportService transportService = MockTransportService.local(settings, version, threadPool);
         transportService.start();
         transportService.start();
         return transportService;
         return transportService;
     }
     }

+ 6 - 4
core/src/test/java/org/elasticsearch/transport/netty/NettyScheduledPingTests.java

@@ -52,12 +52,14 @@ public class NettyScheduledPingTests extends ESTestCase {
 
 
         Settings settings = Settings.builder().put(NettyTransport.PING_SCHEDULE.getKey(), "5ms").put(TransportSettings.PORT.getKey(), 0).build();
         Settings settings = Settings.builder().put(NettyTransport.PING_SCHEDULE.getKey(), "5ms").put(TransportSettings.PORT.getKey(), 0).build();
 
 
-        final NettyTransport nettyA = new NettyTransport(settings, threadPool, new NetworkService(settings), BigArrays.NON_RECYCLING_INSTANCE, Version.CURRENT, new NamedWriteableRegistry());
-        MockTransportService serviceA = new MockTransportService(settings, nettyA, threadPool);
+        NamedWriteableRegistry registryA = new NamedWriteableRegistry();
+        final NettyTransport nettyA = new NettyTransport(settings, threadPool, new NetworkService(settings), BigArrays.NON_RECYCLING_INSTANCE, Version.CURRENT, registryA);
+        MockTransportService serviceA = new MockTransportService(settings, nettyA, threadPool, registryA);
         serviceA.start();
         serviceA.start();
 
 
-        final NettyTransport nettyB = new NettyTransport(settings, threadPool, new NetworkService(settings), BigArrays.NON_RECYCLING_INSTANCE, Version.CURRENT, new NamedWriteableRegistry());
-        MockTransportService serviceB = new MockTransportService(settings, nettyB, threadPool);
+        NamedWriteableRegistry registryB = new NamedWriteableRegistry();
+        final NettyTransport nettyB = new NettyTransport(settings, threadPool, new NetworkService(settings), BigArrays.NON_RECYCLING_INSTANCE, Version.CURRENT, registryB);
+        MockTransportService serviceB = new MockTransportService(settings, nettyB, threadPool, registryB);
         serviceB.start();
         serviceB.start();
 
 
         DiscoveryNode nodeA = new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), Version.CURRENT);
         DiscoveryNode nodeA = new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), Version.CURRENT);

+ 1 - 3
core/src/test/java/org/elasticsearch/transport/netty/SimpleNettyTransportTests.java

@@ -22,10 +22,8 @@ package org.elasticsearch.transport.netty;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.InetSocketTransportAddress;
 import org.elasticsearch.common.transport.InetSocketTransportAddress;
-import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
 import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.ConnectTransportException;
@@ -41,7 +39,7 @@ public class SimpleNettyTransportTests extends AbstractSimpleTransportTestCase {
     @Override
     @Override
     protected MockTransportService build(Settings settings, Version version, NamedWriteableRegistry namedWriteableRegistry) {
     protected MockTransportService build(Settings settings, Version version, NamedWriteableRegistry namedWriteableRegistry) {
         settings = Settings.builder().put(settings).put(TransportSettings.PORT.getKey(), "0").build();
         settings = Settings.builder().put(settings).put(TransportSettings.PORT.getKey(), "0").build();
-        MockTransportService transportService = new MockTransportService(settings, new NettyTransport(settings, threadPool, new NetworkService(settings), BigArrays.NON_RECYCLING_INSTANCE, version, namedWriteableRegistry), threadPool);
+        MockTransportService transportService = MockTransportService.nettyFromThreadPool(settings, version, threadPool);
         transportService.start();
         transportService.start();
         return transportService;
         return transportService;
     }
     }

+ 3 - 3
modules/lang-groovy/src/test/java/org/elasticsearch/messy/tests/IndicesRequestTests.java

@@ -85,6 +85,7 @@ import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.action.update.UpdateResponse;
 import org.elasticsearch.action.update.UpdateResponse;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -93,7 +94,6 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.script.groovy.GroovyPlugin;
 import org.elasticsearch.script.groovy.GroovyPlugin;
 import org.elasticsearch.search.action.SearchServiceTransportAction;
 import org.elasticsearch.search.action.SearchServiceTransportAction;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.Task;
-import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
 import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
 import org.elasticsearch.test.ESIntegTestCase.Scope;
 import org.elasticsearch.test.ESIntegTestCase.Scope;
@@ -785,8 +785,8 @@ public class IndicesRequestTests extends ESIntegTestCase {
         private final Map<String, List<TransportRequest>> requests = new HashMap<>();
         private final Map<String, List<TransportRequest>> requests = new HashMap<>();
 
 
         @Inject
         @Inject
-        public InterceptingTransportService(Settings settings, Transport transport, ThreadPool threadPool) {
-            super(settings, transport, threadPool);
+        public InterceptingTransportService(Settings settings, Transport transport, ThreadPool threadPool, NamedWriteableRegistry namedWriteableRegistry) {
+            super(settings, transport, threadPool, namedWriteableRegistry);
         }
         }
 
 
         synchronized List<TransportRequest> consumeRequests(String action) {
         synchronized List<TransportRequest> consumeRequests(String action) {

+ 2 - 5
plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java

@@ -20,11 +20,11 @@
 package org.elasticsearch.discovery.ec2;
 package org.elasticsearch.discovery.ec2;
 
 
 import com.amazonaws.services.ec2.model.Tag;
 import com.amazonaws.services.ec2.model.Tag;
+
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.cloud.aws.AwsEc2Service;
 import org.elasticsearch.cloud.aws.AwsEc2Service;
 import org.elasticsearch.cloud.aws.AwsEc2Service.DISCOVERY_EC2;
 import org.elasticsearch.cloud.aws.AwsEc2Service.DISCOVERY_EC2;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.LocalTransportAddress;
 import org.elasticsearch.common.transport.LocalTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
@@ -32,7 +32,6 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.TransportService;
-import org.elasticsearch.transport.local.LocalTransport;
 import org.junit.AfterClass;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.BeforeClass;
@@ -67,9 +66,7 @@ public class Ec2DiscoveryTests extends ESTestCase {
 
 
     @Before
     @Before
     public void createTransportService() {
     public void createTransportService() {
-        transportService = new MockTransportService(
-                Settings.EMPTY,
-                new LocalTransport(Settings.EMPTY, threadPool, Version.CURRENT, new NamedWriteableRegistry()), threadPool);
+        transportService = MockTransportService.local(Settings.EMPTY, Version.CURRENT, threadPool);
     }
     }
 
 
     protected List<DiscoveryNode> buildDynamicNodes(Settings nodeSettings, int nodes) {
     protected List<DiscoveryNode> buildDynamicNodes(Settings nodeSettings, int nodes) {

+ 1 - 5
plugins/discovery-gce/src/test/java/org/elasticsearch/discovery/gce/GceDiscoveryTests.java

@@ -22,13 +22,11 @@ package org.elasticsearch.discovery.gce;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.cloud.gce.GceComputeService;
 import org.elasticsearch.cloud.gce.GceComputeService;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.local.LocalTransport;
 import org.junit.After;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Before;
@@ -94,9 +92,7 @@ public class GceDiscoveryTests extends ESTestCase {
 
 
     @Before
     @Before
     public void createTransportService() {
     public void createTransportService() {
-        transportService = new MockTransportService(
-                Settings.EMPTY,
-                new LocalTransport(Settings.EMPTY, threadPool, Version.CURRENT, new NamedWriteableRegistry()), threadPool);
+        transportService = MockTransportService.local(Settings.EMPTY, Version.CURRENT, threadPool);
     }
     }
 
 
     @Before
     @Before

+ 21 - 2
test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java

@@ -19,11 +19,13 @@
 
 
 package org.elasticsearch.test.transport;
 package org.elasticsearch.test.transport;
 
 
+import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.component.Lifecycle;
 import org.elasticsearch.common.component.Lifecycle;
 import org.elasticsearch.common.component.LifecycleListener;
 import org.elasticsearch.common.component.LifecycleListener;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.network.NetworkService;
@@ -32,6 +34,7 @@ import org.elasticsearch.common.settings.SettingsModule;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.Plugin;
@@ -46,6 +49,8 @@ import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.TransportServiceAdapter;
 import org.elasticsearch.transport.TransportServiceAdapter;
+import org.elasticsearch.transport.local.LocalTransport;
+import org.elasticsearch.transport.netty.NettyTransport;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Arrays;
@@ -91,11 +96,25 @@ public class MockTransportService extends TransportService {
         }
         }
     }
     }
 
 
+    public static MockTransportService local(Settings settings, Version version, ThreadPool threadPool) {
+        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry();
+        Transport transport = new LocalTransport(settings, threadPool, version, namedWriteableRegistry);
+        return new MockTransportService(settings, transport, threadPool, namedWriteableRegistry);
+    }
+
+    public static MockTransportService nettyFromThreadPool(Settings settings, Version version, ThreadPool threadPool) {
+        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry();
+        Transport transport = new NettyTransport(settings, threadPool, new NetworkService(settings), BigArrays.NON_RECYCLING_INSTANCE,
+                version, namedWriteableRegistry);
+        return new MockTransportService(Settings.EMPTY, transport, threadPool, namedWriteableRegistry);
+    }
+
+
     private final Transport original;
     private final Transport original;
 
 
     @Inject
     @Inject
-    public MockTransportService(Settings settings, Transport transport, ThreadPool threadPool) {
-        super(settings, new LookupTestTransport(transport), threadPool);
+    public MockTransportService(Settings settings, Transport transport, ThreadPool threadPool, NamedWriteableRegistry namedWriteableRegistry) {
+        super(settings, new LookupTestTransport(transport), threadPool, namedWriteableRegistry);
         this.original = transport;
         this.original = transport;
     }
     }