|
@@ -13,7 +13,6 @@ import org.elasticsearch.action.TaskOperationFailure;
|
|
|
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
|
|
|
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
|
|
|
import org.elasticsearch.cluster.node.DiscoveryNode;
|
|
|
-import org.elasticsearch.cluster.node.DiscoveryNodes;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.io.stream.Writeable;
|
|
@@ -32,14 +31,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.time.Instant;
|
|
|
-import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
-import java.util.Comparator;
|
|
|
-import java.util.HashSet;
|
|
|
import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
import java.util.Objects;
|
|
|
-import java.util.Set;
|
|
|
|
|
|
public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsAction.Response> {
|
|
|
|
|
@@ -129,13 +123,15 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
private final Instant lastAccess;
|
|
|
private final Integer pendingCount;
|
|
|
private final RoutingStateAndReason routingState;
|
|
|
+ private final Instant startTime;
|
|
|
|
|
|
public static NodeStats forStartedState(
|
|
|
DiscoveryNode node,
|
|
|
long inferenceCount,
|
|
|
Double avgInferenceTime,
|
|
|
int pendingCount,
|
|
|
- Instant lastAccess
|
|
|
+ Instant lastAccess,
|
|
|
+ Instant startTime
|
|
|
) {
|
|
|
return new NodeStats(
|
|
|
node,
|
|
@@ -143,12 +139,13 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
avgInferenceTime,
|
|
|
lastAccess,
|
|
|
pendingCount,
|
|
|
- new RoutingStateAndReason(RoutingState.STARTED, null)
|
|
|
+ new RoutingStateAndReason(RoutingState.STARTED, null),
|
|
|
+ Objects.requireNonNull(startTime)
|
|
|
);
|
|
|
}
|
|
|
|
|
|
public static NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
|
|
|
- return new NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason));
|
|
|
+ return new NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null);
|
|
|
}
|
|
|
|
|
|
private NodeStats(
|
|
@@ -157,7 +154,8 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
Double avgInferenceTime,
|
|
|
Instant lastAccess,
|
|
|
Integer pendingCount,
|
|
|
- RoutingStateAndReason routingState
|
|
|
+ RoutingStateAndReason routingState,
|
|
|
+ @Nullable Instant startTime
|
|
|
) {
|
|
|
this.node = node;
|
|
|
this.inferenceCount = inferenceCount;
|
|
@@ -165,6 +163,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
this.lastAccess = lastAccess;
|
|
|
this.pendingCount = pendingCount;
|
|
|
this.routingState = routingState;
|
|
|
+ this.startTime = startTime;
|
|
|
|
|
|
// if lastAccess time is null there have been no inferences
|
|
|
assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
|
|
@@ -177,6 +176,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
this.lastAccess = in.readOptionalInstant();
|
|
|
this.pendingCount = in.readOptionalVInt();
|
|
|
this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
|
|
|
+ this.startTime = in.readOptionalInstant();
|
|
|
}
|
|
|
|
|
|
public DiscoveryNode getNode() {
|
|
@@ -208,6 +208,9 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
if (pendingCount != null) {
|
|
|
builder.field("number_of_pending_requests", pendingCount);
|
|
|
}
|
|
|
+ if (startTime != null) {
|
|
|
+ builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
|
|
|
+ }
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
|
}
|
|
@@ -220,6 +223,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
out.writeOptionalInstant(lastAccess);
|
|
|
out.writeOptionalVInt(pendingCount);
|
|
|
out.writeOptionalWriteable(routingState);
|
|
|
+ out.writeOptionalInstant(startTime);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -232,12 +236,13 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
&& Objects.equals(node, that.node)
|
|
|
&& Objects.equals(lastAccess, that.lastAccess)
|
|
|
&& Objects.equals(pendingCount, that.pendingCount)
|
|
|
- && Objects.equals(routingState, that.routingState);
|
|
|
+ && Objects.equals(routingState, that.routingState)
|
|
|
+ && Objects.equals(startTime, that.startTime);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState);
|
|
|
+ return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState, startTime);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -253,6 +258,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
private final Integer modelThreads;
|
|
|
@Nullable
|
|
|
private final Integer queueCapacity;
|
|
|
+ private final Instant startTime;
|
|
|
private final List<NodeStats> nodeStats;
|
|
|
|
|
|
public AllocationStats(
|
|
@@ -261,6 +267,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
@Nullable Integer inferenceThreads,
|
|
|
@Nullable Integer modelThreads,
|
|
|
@Nullable Integer queueCapacity,
|
|
|
+ Instant startTime,
|
|
|
List<NodeStats> nodeStats
|
|
|
) {
|
|
|
this.modelId = modelId;
|
|
@@ -268,6 +275,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
this.inferenceThreads = inferenceThreads;
|
|
|
this.modelThreads = modelThreads;
|
|
|
this.queueCapacity = queueCapacity;
|
|
|
+ this.startTime = Objects.requireNonNull(startTime);
|
|
|
this.nodeStats = nodeStats;
|
|
|
this.state = null;
|
|
|
this.reason = null;
|
|
@@ -279,6 +287,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
inferenceThreads = in.readOptionalVInt();
|
|
|
modelThreads = in.readOptionalVInt();
|
|
|
queueCapacity = in.readOptionalVInt();
|
|
|
+ startTime = in.readInstant();
|
|
|
nodeStats = in.readList(NodeStats::new);
|
|
|
state = in.readOptionalEnum(AllocationState.class);
|
|
|
reason = in.readOptionalString();
|
|
@@ -308,6 +317,10 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
return queueCapacity;
|
|
|
}
|
|
|
|
|
|
+ public Instant getStartTime() {
|
|
|
+ return startTime;
|
|
|
+ }
|
|
|
+
|
|
|
public List<NodeStats> getNodeStats() {
|
|
|
return nodeStats;
|
|
|
}
|
|
@@ -360,6 +373,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
if (allocationStatus != null) {
|
|
|
builder.field("allocation_status", allocationStatus);
|
|
|
}
|
|
|
+ builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
|
|
|
builder.startArray("nodes");
|
|
|
for (NodeStats nodeStat : nodeStats) {
|
|
|
nodeStat.toXContent(builder, params);
|
|
@@ -376,6 +390,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
out.writeOptionalVInt(inferenceThreads);
|
|
|
out.writeOptionalVInt(modelThreads);
|
|
|
out.writeOptionalVInt(queueCapacity);
|
|
|
+ out.writeInstant(startTime);
|
|
|
out.writeList(nodeStats);
|
|
|
out.writeOptionalEnum(state);
|
|
|
out.writeOptionalString(reason);
|
|
@@ -392,6 +407,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
&& Objects.equals(inferenceThreads, that.inferenceThreads)
|
|
|
&& Objects.equals(modelThreads, that.modelThreads)
|
|
|
&& Objects.equals(queueCapacity, that.queueCapacity)
|
|
|
+ && Objects.equals(startTime, that.startTime)
|
|
|
&& Objects.equals(state, that.state)
|
|
|
&& Objects.equals(reason, that.reason)
|
|
|
&& Objects.equals(allocationStatus, that.allocationStatus)
|
|
@@ -406,6 +422,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
inferenceThreads,
|
|
|
modelThreads,
|
|
|
queueCapacity,
|
|
|
+ startTime,
|
|
|
nodeStats,
|
|
|
state,
|
|
|
reason,
|
|
@@ -464,118 +481,5 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
return Objects.hash(super.hashCode(), stats);
|
|
|
}
|
|
|
|
|
|
- /**
|
|
|
- * Update the collected task responses with the non-started
|
|
|
- * allocation information. The result is the task responses
|
|
|
- * merged with the non-started model allocations.
|
|
|
- *
|
|
|
- * Where there is a merge collision for the pair {@code <model_id, node_id>}
|
|
|
- * the non-started allocations are used.
|
|
|
- *
|
|
|
- * @param tasksResponse All the responses from the tasks
|
|
|
- * @param nonStartedModelRoutes Non-started routes
|
|
|
- * @param nodes current cluster nodes
|
|
|
- * @return The result of merging tasksResponse and the non-started routes
|
|
|
- */
|
|
|
- public static GetDeploymentStatsAction.Response addFailedRoutes(
|
|
|
- GetDeploymentStatsAction.Response tasksResponse,
|
|
|
- Map<String, Map<String, RoutingStateAndReason>> nonStartedModelRoutes,
|
|
|
- DiscoveryNodes nodes
|
|
|
- ) {
|
|
|
-
|
|
|
- List<GetDeploymentStatsAction.Response.AllocationStats> updatedAllocationStats = new ArrayList<>();
|
|
|
-
|
|
|
- for (GetDeploymentStatsAction.Response.AllocationStats stat : tasksResponse.getStats().results()) {
|
|
|
- if (nonStartedModelRoutes.containsKey(stat.getModelId())) {
|
|
|
- // there is merging to be done
|
|
|
- Map<String, RoutingStateAndReason> nodeToRoutingStates = nonStartedModelRoutes.get(stat.getModelId());
|
|
|
- List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> updatedNodeStats = new ArrayList<>();
|
|
|
-
|
|
|
- Set<String> visitedNodes = new HashSet<>();
|
|
|
- for (var nodeStat : stat.getNodeStats()) {
|
|
|
- if (nodeToRoutingStates.containsKey(nodeStat.getNode().getId())) {
|
|
|
- // conflict as there is both a task response for the model/node pair
|
|
|
- // and we have a non-started routing entry.
|
|
|
- // Prefer the entry from nonStartedModelRoutes as we cannot be sure
|
|
|
- // of the state of the task - it may be starting, started, stopping, or stopped.
|
|
|
- RoutingStateAndReason stateAndReason = nodeToRoutingStates.get(nodeStat.getNode().getId());
|
|
|
- updatedNodeStats.add(
|
|
|
- GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
|
|
|
- nodeStat.getNode(),
|
|
|
- stateAndReason.getState(),
|
|
|
- stateAndReason.getReason()
|
|
|
- )
|
|
|
- );
|
|
|
- } else {
|
|
|
- updatedNodeStats.add(nodeStat);
|
|
|
- }
|
|
|
-
|
|
|
- visitedNodes.add(nodeStat.node.getId());
|
|
|
- }
|
|
|
-
|
|
|
- // add nodes from the failures that were not in the task responses
|
|
|
- for (var nodeRoutingState : nodeToRoutingStates.entrySet()) {
|
|
|
- if (visitedNodes.contains(nodeRoutingState.getKey()) == false) {
|
|
|
- updatedNodeStats.add(
|
|
|
- GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
|
|
|
- nodes.get(nodeRoutingState.getKey()),
|
|
|
- nodeRoutingState.getValue().getState(),
|
|
|
- nodeRoutingState.getValue().getReason()
|
|
|
- )
|
|
|
- );
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- updatedNodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
|
|
|
- updatedAllocationStats.add(
|
|
|
- new GetDeploymentStatsAction.Response.AllocationStats(
|
|
|
- stat.getModelId(),
|
|
|
- stat.getModelSize(),
|
|
|
- stat.getInferenceThreads(),
|
|
|
- stat.getModelThreads(),
|
|
|
- stat.getQueueCapacity(),
|
|
|
- updatedNodeStats
|
|
|
- )
|
|
|
- );
|
|
|
- } else {
|
|
|
- updatedAllocationStats.add(stat);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Merge any models in the non-started that were not in the task responses
|
|
|
- for (var nonStartedEntries : nonStartedModelRoutes.entrySet()) {
|
|
|
- String modelId = nonStartedEntries.getKey();
|
|
|
- if (tasksResponse.getStats().results().stream().anyMatch(e -> modelId.equals(e.getModelId())) == false) {
|
|
|
-
|
|
|
- // no tasks for this model so build the allocation stats from the non-started states
|
|
|
- List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStats = new ArrayList<>();
|
|
|
-
|
|
|
- for (var routingEntry : nonStartedEntries.getValue().entrySet()) {
|
|
|
- nodeStats.add(
|
|
|
- AllocationStats.NodeStats.forNotStartedState(
|
|
|
- nodes.get(routingEntry.getKey()),
|
|
|
- routingEntry.getValue().getState(),
|
|
|
- routingEntry.getValue().getReason()
|
|
|
- )
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
- nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
|
|
|
-
|
|
|
- updatedAllocationStats.add(
|
|
|
- new GetDeploymentStatsAction.Response.AllocationStats(modelId, null, null, null, null, nodeStats)
|
|
|
- );
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- updatedAllocationStats.sort(Comparator.comparing(GetDeploymentStatsAction.Response.AllocationStats::getModelId));
|
|
|
-
|
|
|
- return new GetDeploymentStatsAction.Response(
|
|
|
- tasksResponse.getTaskFailures(),
|
|
|
- tasksResponse.getNodeFailures(),
|
|
|
- updatedAllocationStats,
|
|
|
- updatedAllocationStats.size()
|
|
|
- );
|
|
|
- }
|
|
|
}
|
|
|
}
|