|
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.action;
|
|
|
import org.apache.logging.log4j.LogManager;
|
|
|
import org.apache.logging.log4j.Logger;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
+import org.elasticsearch.action.ActionRunnable;
|
|
|
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
|
|
|
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
|
|
|
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters;
|
|
@@ -17,7 +18,8 @@ import org.elasticsearch.action.admin.cluster.node.stats.TransportNodesStatsActi
|
|
|
import org.elasticsearch.action.search.SearchRequest;
|
|
|
import org.elasticsearch.action.search.TransportSearchAction;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
-import org.elasticsearch.action.support.HandledTransportAction;
|
|
|
+import org.elasticsearch.action.support.SubscribableListener;
|
|
|
+import org.elasticsearch.action.support.TransportAction;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
import org.elasticsearch.cluster.ClusterState;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
@@ -26,8 +28,6 @@ import org.elasticsearch.common.document.DocumentField;
|
|
|
import org.elasticsearch.common.inject.Inject;
|
|
|
import org.elasticsearch.common.metrics.CounterMetric;
|
|
|
import org.elasticsearch.common.util.Maps;
|
|
|
-import org.elasticsearch.common.util.concurrent.EsExecutors;
|
|
|
-import org.elasticsearch.common.util.concurrent.ListenableFuture;
|
|
|
import org.elasticsearch.common.util.set.Sets;
|
|
|
import org.elasticsearch.core.Tuple;
|
|
|
import org.elasticsearch.index.query.QueryBuilder;
|
|
@@ -37,6 +37,7 @@ import org.elasticsearch.search.SearchHit;
|
|
|
import org.elasticsearch.search.sort.SortOrder;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.tasks.TaskId;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
|
|
@@ -53,6 +54,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
|
|
|
+import org.elasticsearch.xpack.ml.MachineLearning;
|
|
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
|
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
|
|
|
|
@@ -65,6 +67,7 @@ import java.util.LinkedHashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
+import java.util.concurrent.Executor;
|
|
|
import java.util.function.Function;
|
|
|
import java.util.stream.Collectors;
|
|
|
import java.util.stream.Stream;
|
|
@@ -73,7 +76,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
|
|
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByResource;
|
|
|
|
|
|
-public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<
|
|
|
+public class TransportGetTrainedModelsStatsAction extends TransportAction<
|
|
|
GetTrainedModelsStatsAction.Request,
|
|
|
GetTrainedModelsStatsAction.Response> {
|
|
|
|
|
@@ -82,25 +85,22 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
|
|
|
private final Client client;
|
|
|
private final ClusterService clusterService;
|
|
|
private final TrainedModelProvider trainedModelProvider;
|
|
|
+ private final Executor executor;
|
|
|
|
|
|
@Inject
|
|
|
public TransportGetTrainedModelsStatsAction(
|
|
|
TransportService transportService,
|
|
|
ActionFilters actionFilters,
|
|
|
ClusterService clusterService,
|
|
|
+ ThreadPool threadPool,
|
|
|
TrainedModelProvider trainedModelProvider,
|
|
|
Client client
|
|
|
) {
|
|
|
- super(
|
|
|
- GetTrainedModelsStatsAction.NAME,
|
|
|
- transportService,
|
|
|
- actionFilters,
|
|
|
- GetTrainedModelsStatsAction.Request::new,
|
|
|
- EsExecutors.DIRECT_EXECUTOR_SERVICE
|
|
|
- );
|
|
|
+ super(GetTrainedModelsStatsAction.NAME, actionFilters, transportService.getTaskManager()); // no executor passed to super
|
|
|
this.client = client;
|
|
|
this.clusterService = clusterService;
|
|
|
this.trainedModelProvider = trainedModelProvider;
|
|
|
+ this.executor = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME); // executor that doExecuteForked is run on
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -108,6 +108,15 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
|
|
|
Task task,
|
|
|
GetTrainedModelsStatsAction.Request request,
|
|
|
ActionListener<GetTrainedModelsStatsAction.Response> listener
|
|
|
+ ) {
|
|
|
+ // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this forking once that issue is resolved
|
|
|
+ executor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
|
|
|
+ }
|
|
|
+
|
|
|
+ protected void doExecuteForked(
|
|
|
+ Task task,
|
|
|
+ GetTrainedModelsStatsAction.Request request,
|
|
|
+ ActionListener<GetTrainedModelsStatsAction.Response> listener
|
|
|
) {
|
|
|
final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
|
|
|
final ModelAliasMetadata modelAliasMetadata = ModelAliasMetadata.fromState(clusterService.state());
|
|
@@ -116,101 +125,108 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
|
|
|
|
|
|
GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
|
|
|
|
|
|
- ListenableFuture<Map<String, TrainedModelSizeStats>> modelSizeStatsListener = new ListenableFuture<>();
|
|
|
- modelSizeStatsListener.addListener(listener.delegateFailureAndWrap((l, modelSizeStatsByModelId) -> {
|
|
|
- responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
|
|
|
- l.onResponse(
|
|
|
- responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata))
|
|
|
- );
|
|
|
- }));
|
|
|
-
|
|
|
- ListenableFuture<GetDeploymentStatsAction.Response> deploymentStatsListener = new ListenableFuture<>();
|
|
|
- deploymentStatsListener.addListener(listener.delegateFailureAndWrap((delegate, deploymentStats) -> {
|
|
|
- // deployment stats for each matching deployment
|
|
|
- // not necessarily for all models
|
|
|
- responseBuilder.setDeploymentStatsByDeploymentId(
|
|
|
- deploymentStats.getStats()
|
|
|
- .results()
|
|
|
+ SubscribableListener
|
|
|
+
|
|
|
+ .<Tuple<Long, Map<String, Set<String>>>>newForked(l -> {
|
|
|
+ // When the request resource is a deployment, find the model used in that deployment for the model stats
|
|
|
+ final var idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata);
|
|
|
+
|
|
|
+ logger.debug("Expanded models/deployment Ids request [{}]", idExpression);
|
|
|
+
|
|
|
+ // The request id may contain deployment ids.
|
|
|
+ // It is not an error if these don't match a model id, but
|
|
|
+ // they need to be included in case the deployment id is also
|
|
|
+ // a model id. Hence, the `matchedDeploymentIds` parameter.
|
|
|
+ trainedModelProvider.expandIds(
|
|
|
+ idExpression,
|
|
|
+ request.isAllowNoResources(),
|
|
|
+ request.getPageParams(),
|
|
|
+ Collections.emptySet(),
|
|
|
+ modelAliasMetadata,
|
|
|
+ parentTaskId,
|
|
|
+ matchedDeploymentIds,
|
|
|
+ l
|
|
|
+ );
|
|
|
+ })
|
|
|
+ .andThenAccept(tuple -> responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()))
|
|
|
+
|
|
|
+ .<NodesStatsResponse>andThen(
|
|
|
+ (l, ignored) -> executeAsyncWithOrigin(
|
|
|
+ client,
|
|
|
+ ML_ORIGIN,
|
|
|
+ TransportNodesStatsAction.TYPE,
|
|
|
+ nodeStatsRequest(clusterService.state(), parentTaskId),
|
|
|
+ l
|
|
|
+ )
|
|
|
+ )
|
|
|
+ .<List<InferenceStats>>andThen(executor, null, (l, nodesStatsResponse) -> {
|
|
|
+ // find all pipelines whether using the model id, alias or deployment id.
|
|
|
+ Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases()
|
|
|
+ .entrySet()
|
|
|
.stream()
|
|
|
- .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity()))
|
|
|
- );
|
|
|
+ .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey())))
|
|
|
+ .collect(Collectors.toSet());
|
|
|
+ allPossiblePipelineReferences.addAll(matchedDeploymentIds);
|
|
|
|
|
|
- int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
|
|
|
- modelSizeStats(
|
|
|
- responseBuilder.getExpandedModelIdsWithAliases(),
|
|
|
- request.isAllowNoResources(),
|
|
|
- parentTaskId,
|
|
|
- modelSizeStatsListener,
|
|
|
- numberOfAllocations
|
|
|
- );
|
|
|
- }));
|
|
|
-
|
|
|
- ListenableFuture<List<InferenceStats>> inferenceStatsListener = new ListenableFuture<>();
|
|
|
- // inference stats are per model and are only
|
|
|
- // persisted for boosted tree models
|
|
|
- inferenceStatsListener.addListener(listener.delegateFailureAndWrap((l, inferenceStats) -> {
|
|
|
- responseBuilder.setInferenceStatsByModelId(
|
|
|
- inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
|
|
|
- );
|
|
|
- getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, deploymentStatsListener);
|
|
|
- }));
|
|
|
-
|
|
|
- ListenableFuture<NodesStatsResponse> nodesStatsListener = new ListenableFuture<>();
|
|
|
- nodesStatsListener.addListener(listener.delegateFailureAndWrap((delegate, nodesStatsResponse) -> {
|
|
|
- // find all pipelines whether using the model id,
|
|
|
- // alias or deployment id.
|
|
|
- Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases()
|
|
|
- .entrySet()
|
|
|
- .stream()
|
|
|
- .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey())))
|
|
|
- .collect(Collectors.toSet());
|
|
|
- allPossiblePipelineReferences.addAll(matchedDeploymentIds);
|
|
|
-
|
|
|
- Map<String, Set<String>> pipelineIdsByResource = pipelineIdsByResource(clusterService.state(), allPossiblePipelineReferences);
|
|
|
- Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
|
|
|
- nodesStatsResponse,
|
|
|
- modelAliasMetadata,
|
|
|
- pipelineIdsByResource
|
|
|
- );
|
|
|
- responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
|
|
|
- trainedModelProvider.getInferenceStats(
|
|
|
- responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]),
|
|
|
- parentTaskId,
|
|
|
- inferenceStatsListener
|
|
|
- );
|
|
|
- }));
|
|
|
-
|
|
|
- ListenableFuture<Tuple<Long, Map<String, Set<String>>>> idsListener = new ListenableFuture<>();
|
|
|
- idsListener.addListener(listener.delegateFailureAndWrap((delegate, tuple) -> {
|
|
|
- responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1());
|
|
|
- executeAsyncWithOrigin(
|
|
|
- client,
|
|
|
- ML_ORIGIN,
|
|
|
- TransportNodesStatsAction.TYPE,
|
|
|
- nodeStatsRequest(clusterService.state(), parentTaskId),
|
|
|
- nodesStatsListener
|
|
|
- );
|
|
|
- }));
|
|
|
-
|
|
|
- // When the request resource is a deployment find the
|
|
|
- // model used in that deployment for the model stats
|
|
|
- String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata);
|
|
|
- logger.debug("Expanded models/deployment Ids request [{}]", idExpression);
|
|
|
-
|
|
|
- // the request id may contain deployment ids
|
|
|
- // It is not an error if these don't match a model id but
|
|
|
- // they need to be included in case the deployment id is also
|
|
|
- // a model id. Hence, the `matchedDeploymentIds` parameter
|
|
|
- trainedModelProvider.expandIds(
|
|
|
- idExpression,
|
|
|
- request.isAllowNoResources(),
|
|
|
- request.getPageParams(),
|
|
|
- Collections.emptySet(),
|
|
|
- modelAliasMetadata,
|
|
|
- parentTaskId,
|
|
|
- matchedDeploymentIds,
|
|
|
- idsListener
|
|
|
- );
|
|
|
+ Map<String, Set<String>> pipelineIdsByResource = pipelineIdsByResource(
|
|
|
+ clusterService.state(),
|
|
|
+ allPossiblePipelineReferences
|
|
|
+ );
|
|
|
+ Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
|
|
|
+ nodesStatsResponse,
|
|
|
+ modelAliasMetadata,
|
|
|
+ pipelineIdsByResource
|
|
|
+ );
|
|
|
+ responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
|
|
|
+ trainedModelProvider.getInferenceStats(
|
|
|
+ responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]),
|
|
|
+ parentTaskId,
|
|
|
+ l
|
|
|
+ );
|
|
|
+ })
|
|
|
+ .andThenAccept(
|
|
|
+ // inference stats are per model and are only persisted for boosted tree models
|
|
|
+ inferenceStats -> responseBuilder.setInferenceStatsByModelId(
|
|
|
+ inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ .<GetDeploymentStatsAction.Response>andThen(
|
|
|
+ executor,
|
|
|
+ null,
|
|
|
+ (l, ignored) -> getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l)
|
|
|
+ )
|
|
|
+ .andThenApply(deploymentStats -> {
|
|
|
+ // deployment stats for each matching deployment, not necessarily for all models
|
|
|
+ responseBuilder.setDeploymentStatsByDeploymentId(
|
|
|
+ deploymentStats.getStats()
|
|
|
+ .results()
|
|
|
+ .stream()
|
|
|
+ .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity()))
|
|
|
+ );
|
|
|
+ return deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
|
|
|
+ })
|
|
|
+
|
|
|
+ .<Map<String, TrainedModelSizeStats>>andThen(
|
|
|
+ executor,
|
|
|
+ null,
|
|
|
+ (l, numberOfAllocations) -> modelSizeStats(
|
|
|
+ responseBuilder.getExpandedModelIdsWithAliases(),
|
|
|
+ request.isAllowNoResources(),
|
|
|
+ parentTaskId,
|
|
|
+ l,
|
|
|
+ numberOfAllocations
|
|
|
+ )
|
|
|
+ )
|
|
|
+ .andThenAccept(responseBuilder::setModelSizeStatsByModelId)
|
|
|
+
|
|
|
+ .andThenApply(
|
|
|
+ ignored -> responseBuilder.build(
|
|
|
+ modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata)
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ .addListener(listener, executor, null);
|
|
|
}
|
|
|
|
|
|
static String addModelsUsedInMatchingDeployments(String idExpression, TrainedModelAssignmentMetadata assignmentMetadata) {
|