|
@@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger;
|
|
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.ResourceAlreadyExistsException;
|
|
|
+import org.elasticsearch.ResourceNotFoundException;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
|
|
@@ -26,11 +27,16 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.inject.Inject;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
+import org.elasticsearch.common.util.set.Sets;
|
|
|
import org.elasticsearch.core.TimeValue;
|
|
|
+import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.license.LicenseUtils;
|
|
|
import org.elasticsearch.license.XPackLicenseState;
|
|
|
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
+import org.elasticsearch.search.SearchHit;
|
|
|
+import org.elasticsearch.search.sort.SortBuilders;
|
|
|
+import org.elasticsearch.search.sort.SortOrder;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
@@ -47,11 +53,15 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
|
|
|
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
import org.elasticsearch.xpack.ml.MachineLearning;
|
|
|
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
|
|
|
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
|
|
|
import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
|
|
|
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
|
|
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
|
|
|
|
|
|
import java.util.Collections;
|
|
@@ -65,6 +75,7 @@ import java.util.function.Predicate;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
|
+import static org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction.MAX_NUM_NATIVE_DEFINITION_PARTS;
|
|
|
|
|
|
public class TransportStartTrainedModelDeploymentAction extends TransportMasterNodeAction<
|
|
|
StartTrainedModelDeploymentAction.Request,
|
|
@@ -174,26 +185,28 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
|
|
|
listener.onFailure(ExceptionsHelper.serverError("model [{}] does not have location", trainedModelConfig.getModelId()));
|
|
|
return;
|
|
|
}
|
|
|
-
|
|
|
- getModelBytes(trainedModelConfig, ActionListener.wrap(modelBytes -> {
|
|
|
- TaskParams taskParams = new TaskParams(
|
|
|
- trainedModelConfig.getModelId(),
|
|
|
- modelBytes,
|
|
|
- request.getInferenceThreads(),
|
|
|
- request.getModelThreads(),
|
|
|
- request.getQueueCapacity()
|
|
|
- );
|
|
|
- PersistentTasksCustomMetadata persistentTasks = clusterService.state()
|
|
|
- .getMetadata()
|
|
|
- .custom(PersistentTasksCustomMetadata.TYPE);
|
|
|
- memoryTracker.refresh(
|
|
|
- persistentTasks,
|
|
|
- ActionListener.wrap(
|
|
|
- aVoid -> trainedModelAllocationService.createNewModelAllocation(taskParams, waitForDeploymentToStart),
|
|
|
- listener::onFailure
|
|
|
- )
|
|
|
- );
|
|
|
- }, listener::onFailure));
|
|
|
+ validateModelDefinition(
|
|
|
+ trainedModelConfig,
|
|
|
+ ActionListener.wrap(validate -> getModelBytes(trainedModelConfig, ActionListener.wrap(modelBytes -> {
|
|
|
+ TaskParams taskParams = new TaskParams(
|
|
|
+ trainedModelConfig.getModelId(),
|
|
|
+ modelBytes,
|
|
|
+ request.getInferenceThreads(),
|
|
|
+ request.getModelThreads(),
|
|
|
+ request.getQueueCapacity()
|
|
|
+ );
|
|
|
+ PersistentTasksCustomMetadata persistentTasks = clusterService.state()
|
|
|
+ .getMetadata()
|
|
|
+ .custom(PersistentTasksCustomMetadata.TYPE);
|
|
|
+ memoryTracker.refresh(
|
|
|
+ persistentTasks,
|
|
|
+ ActionListener.wrap(
|
|
|
+ aVoid -> trainedModelAllocationService.createNewModelAllocation(taskParams, waitForDeploymentToStart),
|
|
|
+ listener::onFailure
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }, listener::onFailure)), listener::onFailure)
|
|
|
+ );
|
|
|
|
|
|
}, listener::onFailure);
|
|
|
|
|
@@ -270,6 +283,86 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
|
|
|
|
|
|
}
|
|
|
|
|
|
+ private void validateModelDefinition(TrainedModelConfig config, ActionListener<Void> listener) {
|
|
|
+ if (config.getLocation() instanceof IndexLocation == false) {
|
|
|
+ listener.onResponse(null);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ final String modelId = config.getModelId();
|
|
|
+ final String[] requiredSourceFields = new String[] {
|
|
|
+ TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName(),
|
|
|
+ TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(),
|
|
|
+ TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName(),
|
|
|
+ TrainedModelDefinitionDoc.EOS.getPreferredName() };
|
|
|
+ final Set<String> requiredSet = Set.of(requiredSourceFields);
|
|
|
+ String index = ((IndexLocation) config.getLocation()).getIndexName();
|
|
|
+ client.prepareSearch(index)
|
|
|
+ .setQuery(
|
|
|
+ QueryBuilders.constantScoreQuery(
|
|
|
+ QueryBuilders.boolQuery()
|
|
|
+ .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
|
|
|
+ .filter(
|
|
|
+ QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)
|
|
|
+ )
|
|
|
+ )
|
|
|
+ )
|
|
|
+ .setFetchSource(requiredSourceFields, new String[0])
|
|
|
+ .setSize(MAX_NUM_NATIVE_DEFINITION_PARTS)
|
|
|
+ .setTrackTotalHits(true)
|
|
|
+ .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()).order(SortOrder.ASC).unmappedType("long"))
|
|
|
+ .execute(ActionListener.wrap(response -> {
|
|
|
+ SearchHit[] hits = response.getHits().getHits();
|
|
|
+ if (hits.length == 0) {
|
|
|
+ listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ long summedLengths = 0;
|
|
|
+ for (SearchHit hit : hits) {
|
|
|
+ Map<String, Object> fields = hit.getSourceAsMap();
|
|
|
+ if (fields == null) {
|
|
|
+ listener.onFailure(
|
|
|
+ ExceptionsHelper.badRequestException(
|
|
|
+ "[{}] model definition [{}] is missing required fields {}, unable to be deployed",
|
|
|
+ modelId,
|
|
|
+ TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
|
|
|
+ List.of(requiredSourceFields)
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ Set<String> diff = Sets.difference(fields.keySet(), requiredSet);
|
|
|
+ if (diff.isEmpty() == false) {
|
|
|
+ listener.onFailure(
|
|
|
+ ExceptionsHelper.badRequestException(
|
|
|
+ "[{}] model definition [{}] is missing required fields {}, unable to be deployed",
|
|
|
+ modelId,
|
|
|
+ TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
|
|
|
+ diff
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ summedLengths += ((Number) fields.get(TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName())).longValue();
|
|
|
+ }
|
|
|
+ long totalLength = ((Number) hits[hits.length - 1].getSourceAsMap()
|
|
|
+ .get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())).longValue();
|
|
|
+ Boolean eos = (Boolean) hits[hits.length - 1].getSourceAsMap().get(TrainedModelDefinitionDoc.EOS.getPreferredName());
|
|
|
+ if (summedLengths != totalLength || eos == null || eos == false) {
|
|
|
+ listener.onFailure(ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ listener.onResponse(null);
|
|
|
+ }, e -> {
|
|
|
+ if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
|
|
+ Exception ex = new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId));
|
|
|
+ ex.addSuppressed(e);
|
|
|
+ listener.onFailure(ex);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ listener.onFailure(e);
|
|
|
+ }));
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) {
|
|
|
// We only delegate here to PersistentTasksService, but if there is a metadata writeblock,
|