|
@@ -12,27 +12,41 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
|
|
|
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
|
|
|
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
|
|
|
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
|
|
|
+import org.elasticsearch.action.search.SearchAction;
|
|
|
+import org.elasticsearch.action.search.SearchRequest;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
import org.elasticsearch.action.support.HandledTransportAction;
|
|
|
import org.elasticsearch.client.Client;
|
|
|
import org.elasticsearch.cluster.ClusterState;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
|
+import org.elasticsearch.common.document.DocumentField;
|
|
|
import org.elasticsearch.common.inject.Inject;
|
|
|
import org.elasticsearch.common.metrics.CounterMetric;
|
|
|
import org.elasticsearch.common.util.set.Sets;
|
|
|
import org.elasticsearch.core.Tuple;
|
|
|
+import org.elasticsearch.index.query.QueryBuilder;
|
|
|
+import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.ingest.IngestMetadata;
|
|
|
import org.elasticsearch.ingest.IngestService;
|
|
|
import org.elasticsearch.ingest.IngestStats;
|
|
|
import org.elasticsearch.ingest.Pipeline;
|
|
|
+import org.elasticsearch.search.SearchHit;
|
|
|
+import org.elasticsearch.search.sort.SortOrder;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
|
|
|
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
|
|
|
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
|
|
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
|
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
@@ -85,17 +99,17 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
|
|
|
final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state());
|
|
|
GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
|
|
|
|
|
|
- ActionListener<GetDeploymentStatsAction.Response> getDeploymentStats = ActionListener.wrap(
|
|
|
- deploymentStats -> listener.onResponse(
|
|
|
- responseBuilder.setDeploymentStatsByModelId(
|
|
|
- deploymentStats.getStats()
|
|
|
- .results()
|
|
|
- .stream()
|
|
|
- .collect(Collectors.toMap(AllocationStats::getModelId, Function.identity()))
|
|
|
- ).build()
|
|
|
- ),
|
|
|
- listener::onFailure
|
|
|
- );
|
|
|
+ ActionListener<Map<String, TrainedModelSizeStats>> modelSizeStatsListener = ActionListener.wrap(modelSizeStatsByModelId -> {
|
|
|
+ responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
|
|
|
+ listener.onResponse(responseBuilder.build());
|
|
|
+ }, listener::onFailure);
|
|
|
+
|
|
|
+ ActionListener<GetDeploymentStatsAction.Response> deploymentStatsListener = ActionListener.wrap(deploymentStats -> {
|
|
|
+ responseBuilder.setDeploymentStatsByModelId(
|
|
|
+ deploymentStats.getStats().results().stream().collect(Collectors.toMap(AllocationStats::getModelId, Function.identity()))
|
|
|
+ );
|
|
|
+ modelSizeStats(responseBuilder.getExpandedIdsWithAliases(), request.isAllowNoResources(), modelSizeStatsListener);
|
|
|
+ }, listener::onFailure);
|
|
|
|
|
|
ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(inferenceStats -> {
|
|
|
responseBuilder.setInferenceStatsByModelId(
|
|
@@ -106,7 +120,7 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
|
|
|
ML_ORIGIN,
|
|
|
GetDeploymentStatsAction.INSTANCE,
|
|
|
new GetDeploymentStatsAction.Request(request.getResourceId()),
|
|
|
- getDeploymentStats
|
|
|
+ deploymentStatsListener
|
|
|
);
|
|
|
}, listener::onFailure);
|
|
|
|
|
@@ -150,6 +164,77 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ private void modelSizeStats(
|
|
|
+ Map<String, Set<String>> expandedIdsWithAliases,
|
|
|
+ boolean allowNoResources,
|
|
|
+ ActionListener<Map<String, TrainedModelSizeStats>> listener
|
|
|
+ ) {
|
|
|
+ ActionListener<List<TrainedModelConfig>> modelsListener = ActionListener.wrap(models -> {
|
|
|
+ final List<String> pytorchModelIds = models.stream()
|
|
|
+ .filter(m -> m.getModelType() == TrainedModelType.PYTORCH)
|
|
|
+ .map(TrainedModelConfig::getModelId)
|
|
|
+ .toList();
|
|
|
+ definitionLengths(pytorchModelIds, ActionListener.wrap(pytorchTotalDefinitionLengthsByModelId -> {
|
|
|
+ Map<String, TrainedModelSizeStats> modelSizeStatsByModelId = new HashMap<>();
|
|
|
+ for (TrainedModelConfig model : models) {
|
|
|
+ if (model.getModelType() == TrainedModelType.PYTORCH) {
|
|
|
+ long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
|
|
|
+ modelSizeStatsByModelId.put(
|
|
|
+ model.getModelId(),
|
|
|
+ new TrainedModelSizeStats(
|
|
|
+ totalDefinitionLength,
|
|
|
+ totalDefinitionLength > 0L
|
|
|
+ ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(totalDefinitionLength)
|
|
|
+ : 0L
|
|
|
+ )
|
|
|
+ );
|
|
|
+ } else {
|
|
|
+ modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ listener.onResponse(modelSizeStatsByModelId);
|
|
|
+ }, listener::onFailure));
|
|
|
+ }, listener::onFailure);
|
|
|
+
|
|
|
+ trainedModelProvider.getTrainedModels(
|
|
|
+ expandedIdsWithAliases,
|
|
|
+ GetTrainedModelsAction.Includes.empty(),
|
|
|
+ allowNoResources,
|
|
|
+ modelsListener
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ private void definitionLengths(List<String> modelIds, ActionListener<Map<String, Long>> listener) {
|
|
|
+ QueryBuilder query = QueryBuilders.boolQuery()
|
|
|
+ .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME))
|
|
|
+ .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
|
|
|
+ .filter(QueryBuilders.termQuery(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(), 0));
|
|
|
+ SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
|
|
+ .setQuery(QueryBuilders.constantScoreQuery(query))
|
|
|
+ .setFetchSource(false)
|
|
|
+ .addDocValueField(TrainedModelConfig.MODEL_ID.getPreferredName())
|
|
|
+ .addDocValueField(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())
|
|
|
+ // First find the latest index
|
|
|
+ .addSort("_index", SortOrder.DESC)
|
|
|
+ .request();
|
|
|
+
|
|
|
+ executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
|
|
|
+ Map<String, Long> totalDefinitionLengthByModelId = new HashMap<>();
|
|
|
+ for (SearchHit hit : searchResponse.getHits().getHits()) {
|
|
|
+ DocumentField modelIdField = hit.field(TrainedModelConfig.MODEL_ID.getPreferredName());
|
|
|
+ if (modelIdField != null && modelIdField.getValue()instanceof String modelId) {
|
|
|
+ DocumentField totalDefinitionLengthField = hit.field(
|
|
|
+ TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()
|
|
|
+ );
|
|
|
+ if (totalDefinitionLengthField != null && totalDefinitionLengthField.getValue()instanceof Long totalDefinitionLength) {
|
|
|
+ totalDefinitionLengthByModelId.put(modelId, totalDefinitionLength);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ listener.onResponse(totalDefinitionLengthByModelId);
|
|
|
+ }, listener::onFailure));
|
|
|
+ }
|
|
|
+
|
|
|
static Map<String, IngestStats> inferenceIngestStatsByModelId(
|
|
|
NodesStatsResponse response,
|
|
|
ModelAliasMetadata currentMetadata,
|