
Snapshot/Restore: Batching of snapshot state updates

Similar to the batching of "shards-started" actions, this commit implements batching of snapshot status updates. This is useful when backing up many indices, as the cluster state does not need to be republished for each individual shard status update.

Closes #10295
Yannick Welsch, 10 years ago
commit 14c1743f30
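
For orientation, here is a minimal standalone sketch of the drain-and-batch pattern the diff applies, written in plain Java rather than against the Elasticsearch cluster-state task API; the StatusUpdate and BatchedStatusUpdater names and the publishStateChange callback are illustrative assumptions, not part of the commit. Each incoming shard status update is queued and still submits its own task, but whichever task runs first drains the entire queue and publishes once; later tasks find their request already marked processed and become no-ops.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Illustrative stand-in for UpdateIndexShard*StatusRequest: a shard id, a state,
// and the "processed" flag that lets later tasks skip work an earlier batch already did.
final class StatusUpdate {
    final String shardId;
    final String state;
    volatile boolean processed;

    StatusUpdate(String shardId, String state) {
        this.shardId = shardId;
        this.state = state;
    }
}

final class BatchedStatusUpdater {
    private final BlockingQueue<StatusUpdate> queue = new LinkedBlockingQueue<>();

    // Called once per incoming shard status update, mirroring innerUpdateSnapshotState /
    // updateRestoreStateOnMaster: enqueue first, then submit a task for it.
    void onStatusUpdate(StatusUpdate update, Runnable publishStateChange) {
        queue.add(update);
        runTask(update, publishStateChange); // stands in for clusterService.submitStateUpdateTask(...)
    }

    private void runTask(StatusUpdate trigger, Runnable publishStateChange) {
        if (trigger.processed) {
            return; // an earlier task already consumed this update as part of a batch
        }
        List<StatusUpdate> drained = new ArrayList<>();
        queue.drainTo(drained); // grab every update queued so far in one go
        if (drained.isEmpty()) {
            return; // another task drained the queue in the meantime
        }
        for (StatusUpdate update : drained) {
            update.processed = true; // apply the update to the in-memory snapshot state here
        }
        publishStateChange.run(); // one publication covers the whole batch
    }
}

In the commit itself the task body runs on the master's single-threaded cluster state update executor, so many queued shard updates collapse into one republished cluster state; the new batchingShardUpdateTaskTest asserts this by counting how often the "update snapshot state" task actually changes the cluster state.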

+ 111 - 57
src/main/java/org/elasticsearch/snapshots/RestoreService.java

@@ -25,6 +25,7 @@ import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.IndicesOptions;
@@ -39,6 +40,7 @@ import org.elasticsearch.cluster.settings.ClusterDynamicSettings;
 import org.elasticsearch.cluster.settings.DynamicSettings;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.component.AbstractComponent;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -47,6 +49,7 @@ import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.common.settings.ImmutableSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.repositories.Repository;
@@ -55,6 +58,8 @@ import org.elasticsearch.transport.*;
 
 import java.io.IOException;
 import java.util.*;
+import java.util.Map.Entry;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CopyOnWriteArrayList;
 
 import static com.google.common.collect.Lists.newArrayList;
@@ -120,6 +125,8 @@ public class RestoreService extends AbstractComponent implements ClusterStateLis
 
     private final CopyOnWriteArrayList<ActionListener<RestoreCompletionResponse>> listeners = new CopyOnWriteArrayList<>();
 
+    private final BlockingQueue<UpdateIndexShardRestoreStatusRequest> updatedSnapshotStateQueue = ConcurrentCollections.newBlockingQueue();
+
     @Inject
     public RestoreService(Settings settings, ClusterService clusterService, RepositoriesService repositoriesService, TransportService transportService,
                           AllocationService allocationService, MetaDataCreateIndexService createIndexService, @ClusterDynamicSettings DynamicSettings dynamicSettings,
@@ -469,42 +476,75 @@ public class RestoreService extends AbstractComponent implements ClusterStateLis
      * @param request update shard status request
      */
     private void updateRestoreStateOnMaster(final UpdateIndexShardRestoreStatusRequest request) {
-        clusterService.submitStateUpdateTask("update snapshot state", new ProcessedClusterStateUpdateTask() {
+        logger.trace("received updated snapshot restore state [{}]", request);
+        updatedSnapshotStateQueue.add(request);
 
-            private RestoreInfo restoreInfo = null;
-            private Map<ShardId, ShardRestoreStatus> shards = null;
+        clusterService.submitStateUpdateTask("update snapshot state", new ProcessedClusterStateUpdateTask() {
+            private final List<UpdateIndexShardRestoreStatusRequest> drainedRequests = new ArrayList<>();
+            private Map<SnapshotId, Tuple<RestoreInfo, Map<ShardId, ShardRestoreStatus>>> batchedRestoreInfo = null;
 
             @Override
             public ClusterState execute(ClusterState currentState) {
-                MetaData metaData = currentState.metaData();
-                MetaData.Builder mdBuilder = MetaData.builder(currentState.metaData());
-                RestoreMetaData restore = metaData.custom(RestoreMetaData.TYPE);
+
+                if (request.processed) {
+                    return currentState;
+                }
+
+                updatedSnapshotStateQueue.drainTo(drainedRequests);
+
+                final int batchSize = drainedRequests.size();
+
+                // nothing to process (a previous event has processed it already)
+                if (batchSize == 0) {
+                    return currentState;
+                }
+
+                final MetaData metaData = currentState.metaData();
+                final RestoreMetaData restore = metaData.custom(RestoreMetaData.TYPE);
                 if (restore != null) {
-                    boolean changed = false;
-                    boolean found = false;
-                    ArrayList<RestoreMetaData.Entry> entries = newArrayList();
+                    int changedCount = 0;
+                    final List<RestoreMetaData.Entry> entries = newArrayList();
                     for (RestoreMetaData.Entry entry : restore.entries()) {
-                        if (entry.snapshotId().equals(request.snapshotId())) {
-                            assert !found;
-                            found = true;
-                            Map<ShardId, ShardRestoreStatus> shards = newHashMap(entry.shards());
-                            logger.trace("[{}] Updating shard [{}] with status [{}]", request.snapshotId(), request.shardId(), request.status().state());
-                            shards.put(request.shardId(), request.status());
+                        Map<ShardId, ShardRestoreStatus> shards = null;
+
+                        for (int i = 0; i < batchSize; i++) {
+                            final UpdateIndexShardRestoreStatusRequest updateSnapshotState = drainedRequests.get(i);
+                            updateSnapshotState.processed = true;
+
+                            if (entry.snapshotId().equals(updateSnapshotState.snapshotId())) {
+                                logger.trace("[{}] Updating shard [{}] with status [{}]", updateSnapshotState.snapshotId(), updateSnapshotState.shardId(), updateSnapshotState.status().state());
+                                if (shards == null) {
+                                    shards = newHashMap(entry.shards());
+                                }
+                                shards.put(updateSnapshotState.shardId(), updateSnapshotState.status());
+                                changedCount++;
+                            }
+                        }
+
+                        if (shards != null) {
                             if (!completed(shards)) {
                                 entries.add(new RestoreMetaData.Entry(entry.snapshotId(), RestoreMetaData.State.STARTED, entry.indices(), ImmutableMap.copyOf(shards)));
                             } else {
-                                logger.info("restore [{}] is done", request.snapshotId());
-                                restoreInfo = new RestoreInfo(entry.snapshotId().getSnapshot(), entry.indices(), shards.size(), shards.size() - failedShards(shards));
-                                this.shards = shards;
+                                logger.info("restore [{}] is done", entry.snapshotId());
+                                if (batchedRestoreInfo == null) {
+                                    batchedRestoreInfo = newHashMap();
+                                }
+                                assert !batchedRestoreInfo.containsKey(entry.snapshotId());
+                                batchedRestoreInfo.put(entry.snapshotId(),
+                                    new Tuple<>(
+                                        new RestoreInfo(entry.snapshotId().getSnapshot(), entry.indices(), shards.size(), shards.size() - failedShards(shards)),
+                                        shards));
                             }
-                            changed = true;
                         } else {
                             entries.add(entry);
                         }
                     }
-                    if (changed) {
-                        restore = new RestoreMetaData(entries.toArray(new RestoreMetaData.Entry[entries.size()]));
-                        mdBuilder.putCustom(RestoreMetaData.TYPE, restore);
+
+                    if (changedCount > 0) {
+                        logger.trace("changed cluster state triggered by {} snapshot restore state updates", changedCount);
+
+                        final RestoreMetaData updatedRestore = new RestoreMetaData(entries.toArray(new RestoreMetaData.Entry[entries.size()]));
+                        final MetaData.Builder mdBuilder = MetaData.builder(currentState.metaData()).putCustom(RestoreMetaData.TYPE, updatedRestore);
                         return ClusterState.builder(currentState).metaData(mdBuilder).build();
                     }
                 }
@@ -513,48 +553,55 @@ public class RestoreService extends AbstractComponent implements ClusterStateLis
 
             @Override
             public void onFailure(String source, @Nullable Throwable t) {
-                logger.warn("[{}][{}] failed to update snapshot status to [{}]", t, request.snapshotId(), request.shardId(), request.status());
+                for (UpdateIndexShardRestoreStatusRequest request : drainedRequests) {
+                    logger.warn("[{}][{}] failed to update snapshot status to [{}]", t, request.snapshotId(), request.shardId(), request.status());
+                }
             }
 
             @Override
             public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
-                if (restoreInfo != null) {
-                    RoutingTable routingTable = newState.getRoutingTable();
-                    final List<ShardId> waitForStarted = newArrayList();
-                    for (Map.Entry<ShardId, ShardRestoreStatus> shard : shards.entrySet()) {
-                        if (shard.getValue().state() == RestoreMetaData.State.SUCCESS ) {
-                            ShardId shardId = shard.getKey();
-                            ShardRouting shardRouting = findPrimaryShard(routingTable, shardId);
-                            if (shardRouting != null && !shardRouting.active()) {
-                                logger.trace("[{}][{}] waiting for the shard to start", request.snapshotId(), shardId);
-                                waitForStarted.add(shardId);
+                if (batchedRestoreInfo != null) {
+                    for (final Entry<SnapshotId, Tuple<RestoreInfo, Map<ShardId, ShardRestoreStatus>>> entry : batchedRestoreInfo.entrySet()) {
+                        final SnapshotId snapshotId = entry.getKey();
+                        final RestoreInfo restoreInfo = entry.getValue().v1();
+                        final Map<ShardId, ShardRestoreStatus> shards = entry.getValue().v2();
+                        RoutingTable routingTable = newState.getRoutingTable();
+                        final List<ShardId> waitForStarted = newArrayList();
+                        for (Map.Entry<ShardId, ShardRestoreStatus> shard : shards.entrySet()) {
+                            if (shard.getValue().state() == RestoreMetaData.State.SUCCESS ) {
+                                ShardId shardId = shard.getKey();
+                                ShardRouting shardRouting = findPrimaryShard(routingTable, shardId);
+                                if (shardRouting != null && !shardRouting.active()) {
+                                    logger.trace("[{}][{}] waiting for the shard to start", snapshotId, shardId);
+                                    waitForStarted.add(shardId);
+                                }
                             }
                         }
-                    }
-                    if (waitForStarted.isEmpty()) {
-                        notifyListeners();
-                    } else {
-                        clusterService.addLast(new ClusterStateListener() {
-                            @Override
-                            public void clusterChanged(ClusterChangedEvent event) {
-                                if (event.routingTableChanged()) {
-                                    RoutingTable routingTable = event.state().getRoutingTable();
-                                    for (Iterator<ShardId> iterator = waitForStarted.iterator(); iterator.hasNext();) {
-                                        ShardId shardId = iterator.next();
-                                        ShardRouting shardRouting = findPrimaryShard(routingTable, shardId);
-                                        // Shard disappeared (index deleted) or became active
-                                        if (shardRouting == null || shardRouting.active()) {
-                                            iterator.remove();
-                                            logger.trace("[{}][{}] shard disappeared or started - removing", request.snapshotId(), shardId);
+                        if (waitForStarted.isEmpty()) {
+                            notifyListeners(snapshotId, restoreInfo);
+                        } else {
+                            clusterService.addLast(new ClusterStateListener() {
+                                @Override
+                                public void clusterChanged(ClusterChangedEvent event) {
+                                    if (event.routingTableChanged()) {
+                                        RoutingTable routingTable = event.state().getRoutingTable();
+                                        for (Iterator<ShardId> iterator = waitForStarted.iterator(); iterator.hasNext();) {
+                                            ShardId shardId = iterator.next();
+                                            ShardRouting shardRouting = findPrimaryShard(routingTable, shardId);
+                                            // Shard disappeared (index deleted) or became active
+                                            if (shardRouting == null || shardRouting.active()) {
+                                                iterator.remove();
+                                                logger.trace("[{}][{}] shard disappeared or started - removing", snapshotId, shardId);
+                                            }
                                         }
                                     }
+                                    if (waitForStarted.isEmpty()) {
+                                        notifyListeners(snapshotId, restoreInfo);
+                                        clusterService.remove(this);
+                                    }
                                 }
-                                if (waitForStarted.isEmpty()) {
-                                    notifyListeners();
-                                    clusterService.remove(this);
-                                }
-                            }
-                        });
+                            });
+                        }
                     }
                 }
             }
@@ -570,10 +617,10 @@ public class RestoreService extends AbstractComponent implements ClusterStateLis
                 return null;
             }
 
-            private void notifyListeners() {
+            private void notifyListeners(SnapshotId snapshotId, RestoreInfo restoreInfo) {
                 for (ActionListener<RestoreCompletionResponse> listener : listeners) {
                     try {
-                        listener.onResponse(new RestoreCompletionResponse(request.snapshotId, restoreInfo));
+                        listener.onResponse(new RestoreCompletionResponse(snapshotId, restoreInfo));
                     } catch (Throwable e) {
                         logger.warn("failed to update snapshot status for [{}]", e, listener);
                     }
@@ -952,6 +999,8 @@ public class RestoreService extends AbstractComponent implements ClusterStateLis
         private ShardId shardId;
         private ShardRestoreStatus status;
 
+        volatile boolean processed; // state field, no need to serialize
+
         private UpdateIndexShardRestoreStatusRequest() {
 
         }
@@ -989,6 +1038,11 @@ public class RestoreService extends AbstractComponent implements ClusterStateLis
         public ShardRestoreStatus status() {
             return status;
         }
+
+        @Override
+        public String toString() {
+            return "" + snapshotId + ", shardId [" + shardId + "], status [" + status.state() + "]";
+        }
     }
 
     /**

+ 60 - 14
src/main/java/org/elasticsearch/snapshots/SnapshotsService.java

@@ -21,6 +21,7 @@ package org.elasticsearch.snapshots;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+
 import org.apache.lucene.util.CollectionUtil;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
@@ -44,6 +45,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.snapshots.IndexShardRepository;
 import org.elasticsearch.index.snapshots.IndexShardSnapshotAndRestoreService;
@@ -58,6 +60,7 @@ import org.elasticsearch.transport.*;
 
 import java.io.IOException;
 import java.util.*;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.locks.Condition;
@@ -107,6 +110,7 @@ public class SnapshotsService extends AbstractLifecycleComponent<SnapshotsServic
 
     private final CopyOnWriteArrayList<SnapshotCompletionListener> snapshotCompletionListeners = new CopyOnWriteArrayList<>();
 
+    private final BlockingQueue<UpdateIndexShardSnapshotStatusRequest> updatedSnapshotStateQueue = ConcurrentCollections.newBlockingQueue();
 
     @Inject
     public SnapshotsService(Settings settings, ClusterService clusterService, RepositoriesService repositoriesService, ThreadPool threadPool,
@@ -935,20 +939,51 @@ public class SnapshotsService extends AbstractLifecycleComponent<SnapshotsServic
      * @param request update shard status request
      */
     private void innerUpdateSnapshotState(final UpdateIndexShardSnapshotStatusRequest request) {
+        logger.trace("received updated snapshot state [{}]", request);
+        updatedSnapshotStateQueue.add(request);
+
         clusterService.submitStateUpdateTask("update snapshot state", new ClusterStateUpdateTask() {
+            private final List<UpdateIndexShardSnapshotStatusRequest> drainedRequests = new ArrayList<>();
+
             @Override
             public ClusterState execute(ClusterState currentState) {
-                MetaData metaData = currentState.metaData();
-                MetaData.Builder mdBuilder = MetaData.builder(currentState.metaData());
-                SnapshotMetaData snapshots = metaData.custom(SnapshotMetaData.TYPE);
+
+                if (request.processed) {
+                    return currentState;
+                }
+
+                updatedSnapshotStateQueue.drainTo(drainedRequests);
+
+                final int batchSize = drainedRequests.size();
+
+                // nothing to process (a previous event has processed it already)
+                if (batchSize == 0) {
+                    return currentState;
+                }
+
+                final MetaData metaData = currentState.metaData();
+                final SnapshotMetaData snapshots = metaData.custom(SnapshotMetaData.TYPE);
                 if (snapshots != null) {
-                    boolean changed = false;
-                    ArrayList<SnapshotMetaData.Entry> entries = newArrayList();
+                    int changedCount = 0;
+                    final List<SnapshotMetaData.Entry> entries = newArrayList();
                     for (SnapshotMetaData.Entry entry : snapshots.entries()) {
-                        if (entry.snapshotId().equals(request.snapshotId())) {
-                            HashMap<ShardId, ShardSnapshotStatus> shards = newHashMap(entry.shards());
-                            logger.trace("[{}] Updating shard [{}] with status [{}]", request.snapshotId(), request.shardId(), request.status().state());
-                            shards.put(request.shardId(), request.status());
+                        HashMap<ShardId, ShardSnapshotStatus> shards = null;
+
+                        for (int i = 0; i < batchSize; i++) {
+                            final UpdateIndexShardSnapshotStatusRequest updateSnapshotState = drainedRequests.get(i);
+                            updateSnapshotState.processed = true;
+
+                            if (entry.snapshotId().equals(updateSnapshotState.snapshotId())) {
+                                logger.trace("[{}] Updating shard [{}] with status [{}]", updateSnapshotState.snapshotId(), updateSnapshotState.shardId(), updateSnapshotState.status().state());
+                                if (shards == null) {
+                                    shards = newHashMap(entry.shards());
+                                }
+                                shards.put(updateSnapshotState.shardId(), updateSnapshotState.status());
+                                changedCount++;
+                            }
+                        }
+
+                        if (shards != null) {
                             if (!completed(shards.values())) {
                                 entries.add(new SnapshotMetaData.Entry(entry, ImmutableMap.copyOf(shards)));
                             } else {
@@ -960,14 +995,15 @@ public class SnapshotsService extends AbstractLifecycleComponent<SnapshotsServic
                                 endSnapshot(updatedEntry);
                                 logger.info("snapshot [{}] is done", updatedEntry.snapshotId());
                             }
-                            changed = true;
                         } else {
                             entries.add(entry);
                         }
                     }
-                    if (changed) {
-                        snapshots = new SnapshotMetaData(entries.toArray(new SnapshotMetaData.Entry[entries.size()]));
-                        mdBuilder.putCustom(SnapshotMetaData.TYPE, snapshots);
+                    if (changedCount > 0) {
+                        logger.trace("changed cluster state triggered by {} snapshot state updates", changedCount);
+
+                        final SnapshotMetaData updatedSnapshots = new SnapshotMetaData(entries.toArray(new SnapshotMetaData.Entry[entries.size()]));
+                        final MetaData.Builder mdBuilder = MetaData.builder(currentState.metaData()).putCustom(SnapshotMetaData.TYPE, updatedSnapshots);
                         return ClusterState.builder(currentState).metaData(mdBuilder).build();
                     }
                 }
@@ -976,7 +1012,9 @@ public class SnapshotsService extends AbstractLifecycleComponent<SnapshotsServic
 
             @Override
             public void onFailure(String source, Throwable t) {
-                logger.warn("[{}][{}] failed to update snapshot status to [{}]", t, request.snapshotId(), request.shardId(), request.status());
+                for (UpdateIndexShardSnapshotStatusRequest request : drainedRequests) {
+                    logger.warn("[{}][{}] failed to update snapshot status to [{}]", t, request.snapshotId(), request.shardId(), request.status());
+                }
             }
         });
     }
@@ -1562,6 +1600,8 @@ public class SnapshotsService extends AbstractLifecycleComponent<SnapshotsServic
         private ShardId shardId;
         private SnapshotMetaData.ShardSnapshotStatus status;
 
+        volatile boolean processed; // state field, no need to serialize
+
         private UpdateIndexShardSnapshotStatusRequest() {
 
         }
@@ -1599,6 +1639,12 @@ public class SnapshotsService extends AbstractLifecycleComponent<SnapshotsServic
         public SnapshotMetaData.ShardSnapshotStatus status() {
             return status;
         }
+
+        @Override
+        public String toString() {
+            return "" + snapshotId + ", shardId [" + shardId + "], status [" + status.state() + "]";
+        }
     }
 
     /**

+ 218 - 3
src/test/java/org/elasticsearch/snapshots/SharedClusterSnapshotRestoreTests.java

@@ -33,20 +33,21 @@ import org.elasticsearch.action.admin.cluster.snapshots.get.GetSnapshotsResponse
 import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
 import org.elasticsearch.action.admin.cluster.snapshots.status.*;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
+import org.elasticsearch.action.admin.cluster.tasks.PendingClusterTasksResponse;
 import org.elasticsearch.action.admin.indices.flush.FlushResponse;
 import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
 import org.elasticsearch.action.admin.indices.template.get.GetIndexTemplatesResponse;
 import org.elasticsearch.action.count.CountResponse;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.cluster.ClusterService;
-import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.cluster.ProcessedClusterStateUpdateTask;
+import org.elasticsearch.cluster.*;
 import org.elasticsearch.cluster.metadata.*;
 import org.elasticsearch.cluster.metadata.SnapshotMetaData.Entry;
 import org.elasticsearch.cluster.metadata.SnapshotMetaData.ShardSnapshotStatus;
 import org.elasticsearch.cluster.metadata.SnapshotMetaData.State;
 import org.elasticsearch.cluster.routing.allocation.decider.FilterAllocationDecider;
+import org.elasticsearch.cluster.service.PendingClusterTask;
+import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.ImmutableSettings;
 import org.elasticsearch.common.settings.Settings;
@@ -56,6 +57,7 @@ import org.elasticsearch.index.store.IndexStore;
 import org.elasticsearch.indices.InvalidIndexNameException;
 import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.snapshots.mockstore.MockRepositoryModule;
+import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.junit.Test;
 
 import java.nio.channels.SeekableByteChannel;
@@ -72,6 +74,7 @@ import java.util.concurrent.TimeUnit;
 import static com.google.common.collect.Lists.newArrayList;
 import static org.elasticsearch.cluster.metadata.IndexMetaData.SETTING_NUMBER_OF_REPLICAS;
 import static org.elasticsearch.cluster.metadata.IndexMetaData.SETTING_NUMBER_OF_SHARDS;
+import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
 import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
 import static org.elasticsearch.index.shard.IndexShard.INDEX_REFRESH_INTERVAL;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.*;
@@ -1820,4 +1823,216 @@ public class SharedClusterSnapshotRestoreTests extends AbstractSnapshotTests {
             }
         }, timeout.millis(), TimeUnit.MILLISECONDS);
     }
+
+    @Test
+    @TestLogging("cluster:DEBUG")
+    public void batchingShardUpdateTaskTest() throws Exception {
+
+        final Client client = client();
+
+        logger.info("-->  creating repository");
+        assertAcked(client.admin().cluster().preparePutRepository("test-repo")
+                .setType("fs").setSettings(ImmutableSettings.settingsBuilder()
+                        .put("location", createTempDir())
+                        .put("compress", randomBoolean())
+                        .put("chunk_size", randomIntBetween(100, 1000))));
+
+        assertAcked(prepareCreate("test-idx", 0, settingsBuilder().put("number_of_shards", between(1, 20))
+                .put("number_of_replicas", 0)));
+        ensureGreen();
+
+        logger.info("--> indexing some data");
+        final int numdocs = randomIntBetween(10, 100);
+        IndexRequestBuilder[] builders = new IndexRequestBuilder[numdocs];
+        for (int i = 0; i < builders.length; i++) {
+            builders[i] = client().prepareIndex("test-idx", "type1", Integer.toString(i)).setSource("field1", "bar " + i);
+        }
+        indexRandom(true, builders);
+        flushAndRefresh();
+
+        final int numberOfShards = getNumShards("test-idx").numPrimaries;
+        logger.info("number of shards: {}", numberOfShards);
+
+        final ClusterService clusterService = internalCluster().clusterService(internalCluster().getMasterName());
+        BlockingClusterStateListener snapshotListener = new BlockingClusterStateListener(clusterService, "update_snapshot [", "update snapshot state", Priority.HIGH);
+        try {
+            clusterService.addFirst(snapshotListener);
+            logger.info("--> snapshot");
+            ListenableActionFuture<CreateSnapshotResponse> snapshotFuture = client.admin().cluster().prepareCreateSnapshot("test-repo", "test-snap").setWaitForCompletion(true).setIndices("test-idx").execute();
+
+            // Await until shard updates are in pending state.
+            assertTrue(waitForPendingTasks("update snapshot state", numberOfShards));
+            snapshotListener.unblock();
+
+            // Check that the snapshot was successful
+            CreateSnapshotResponse createSnapshotResponse = snapshotFuture.actionGet();
+            assertEquals(SnapshotState.SUCCESS, createSnapshotResponse.getSnapshotInfo().state());
+            assertEquals(numberOfShards, createSnapshotResponse.getSnapshotInfo().totalShards());
+            assertEquals(numberOfShards, createSnapshotResponse.getSnapshotInfo().successfulShards());
+
+        } finally {
+            clusterService.remove(snapshotListener);
+        }
+
+        // Check that we didn't timeout
+        assertFalse(snapshotListener.timedOut());
+        // Check that cluster state update task was called only once
+        assertEquals(1, snapshotListener.count());
+
+        logger.info("--> close indices");
+        client.admin().indices().prepareClose("test-idx").get();
+
+        BlockingClusterStateListener restoreListener = new BlockingClusterStateListener(clusterService, "restore_snapshot[", "update snapshot state", Priority.HIGH);
+
+        try {
+            clusterService.addFirst(restoreListener);
+            logger.info("--> restore snapshot");
+            ListenableActionFuture<RestoreSnapshotResponse> futureRestore = client.admin().cluster().prepareRestoreSnapshot("test-repo", "test-snap").setWaitForCompletion(true).execute();
+
+            // Await until shard updates are in pending state.
+            assertTrue(waitForPendingTasks("update snapshot state", numberOfShards));
+            restoreListener.unblock();
+
+            RestoreSnapshotResponse restoreSnapshotResponse = futureRestore.actionGet();
+            assertThat(restoreSnapshotResponse.getRestoreInfo().totalShards(), equalTo(numberOfShards));
+
+        } finally {
+            clusterService.remove(restoreListener);
+        }
+
+        // Check that we didn't timeout
+        assertFalse(restoreListener.timedOut());
+        // Check that cluster state update task was called only once
+        assertEquals(1, restoreListener.count());
+    }
+
+    private boolean waitForPendingTasks(final String taskPrefix, final int expectedCount) throws InterruptedException {
+        return awaitBusy(new Predicate<Object>() {
+            @Override
+            public boolean apply(Object o) {
+                PendingClusterTasksResponse tasks = client().admin().cluster().preparePendingClusterTasks().get();
+                int count = 0;
+                for(PendingClusterTask task : tasks) {
+                    if (task.getSource().toString().startsWith(taskPrefix)) {
+                        count++;
+                    }
+                }
+                return expectedCount == count;
+            }
+        });
+    }
+
+    /**
+     * Cluster state listener that waits for the blockOn task to show up and then blocks execution, not letting
+     * any cluster state update task be performed unless it has a priority higher than passThroughPriority.
+     *
+     * This class is useful for testing cluster state update task batching for lower priority tasks.
+     */
+    public class BlockingClusterStateListener implements ClusterStateListener {
+
+        private final Predicate<ClusterChangedEvent> blockOn;
+
+        private final Predicate<ClusterChangedEvent> countOn;
+
+        private final ClusterService clusterService;
+
+        private final CountDownLatch latch;
+
+        private final Priority passThroughPriority;
+
+        private int count;
+
+        private boolean timedOut;
+
+        private final TimeValue timeout;
+
+        private long stopWaitingAt = -1;
+
+        public BlockingClusterStateListener(ClusterService clusterService, String blockOn, String countOn, Priority passThroughPriority) {
+            this(clusterService, blockOn, countOn, passThroughPriority, TimeValue.timeValueSeconds(10));
+        }
+
+        public BlockingClusterStateListener(ClusterService clusterService, final String blockOn, final String countOn, Priority passThroughPriority, TimeValue timeout) {
+            this.clusterService = clusterService;
+            this.blockOn = new Predicate<ClusterChangedEvent>() {
+                @Override
+                public boolean apply(ClusterChangedEvent clusterChangedEvent) {
+                    return clusterChangedEvent.source().startsWith(blockOn);
+                }
+            };
+            this.countOn = new Predicate<ClusterChangedEvent>() {
+                @Override
+                public boolean apply(ClusterChangedEvent clusterChangedEvent) {
+                    return clusterChangedEvent.source().startsWith(countOn);
+                }
+            };
+            this.latch = new CountDownLatch(1);
+            this.passThroughPriority = passThroughPriority;
+            this.timeout = timeout;
+
+        }
+
+        public void unblock() {
+            latch.countDown();
+        }
+
+        @Override
+        public void clusterChanged(ClusterChangedEvent event) {
+            if (blockOn.apply(event)) {
+                logger.info("blocking cluster state tasks on [{}]", event.source());
+                assert stopWaitingAt < 0; // Make sure this is the first time we block here
+                stopWaitingAt = System.currentTimeMillis() + timeout.getMillis();
+                addBlock();
+            }
+            if (countOn.apply(event)) {
+                count++;
+            }
+        }
+
+        private void addBlock() {
+            // We should block after this task - add blocking cluster state update task
+            clusterService.submitStateUpdateTask("test_block", passThroughPriority, new ClusterStateUpdateTask() {
+                @Override
+                public ClusterState execute(ClusterState currentState) throws Exception {
+                    while(System.currentTimeMillis() < stopWaitingAt) {
+                        for (PendingClusterTask task : clusterService.pendingTasks()) {
+                            if (task.getSource().string().equals("test_block") == false && passThroughPriority.sameOrAfter(task.getPriority())) {
+                                // There are other higher priority tasks in the queue, so let them pass through and then set the block again
+                                logger.info("passing through cluster state task {}", task.getSource());
+                                addBlock();
+                                return currentState;
+                            }
+                        }
+                        try {
+                            logger.info("waiting....");
+                            if (latch.await(Math.min(100, timeout.millis()), TimeUnit.MILLISECONDS)){
+                                // Done waiting - unblock
+                                logger.info("unblocked");
+                                return currentState;
+                            }
+                            logger.info("done waiting....");
+                        } catch (InterruptedException ex) {
+                            Thread.currentThread().interrupt();
+                        }
+                    }
+                    timedOut = true;
+                    return currentState;
+                }
+
+                @Override
+                public void onFailure(String source, Throwable t) {
+                    logger.warn("failed to execute [{}]", t, source);
+                }
+            });
+
+        }
+
+        public int count() {
+            return count;
+        }
+
+        public boolean timedOut() {
+            return timedOut;
+        }
+    }
 }