|
@@ -26,6 +26,7 @@ import org.elasticsearch.common.settings.Setting;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
|
|
import org.elasticsearch.common.util.set.Sets;
|
|
|
+import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.core.Strings;
|
|
|
import org.elasticsearch.core.TimeValue;
|
|
|
import org.elasticsearch.ingest.IngestMetadata;
|
|
@@ -36,6 +37,7 @@ import org.elasticsearch.tasks.TaskId;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
|
@@ -51,12 +53,14 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
|
|
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
|
|
|
|
|
import java.util.ArrayDeque;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.EnumSet;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import java.util.Optional;
|
|
|
import java.util.Queue;
|
|
|
import java.util.Set;
|
|
|
import java.util.concurrent.ExecutionException;
|
|
@@ -110,11 +114,71 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
Setting.Property.NodeScope
|
|
|
);
|
|
|
|
|
|
- // The feature requesting the model
|
|
|
+ /**
|
|
|
+ * The cached model consumer. Various consumers dictate the model's usage and context
|
|
|
+ */
|
|
|
public enum Consumer {
|
|
|
- PIPELINE,
|
|
|
- SEARCH,
|
|
|
- INTERNAL
|
|
|
+ PIPELINE() {
|
|
|
+ @Override
|
|
|
+ public boolean inferenceConfigSupported(InferenceConfig config) {
|
|
|
+ return config == null || config.supportsIngestPipeline();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String exceptionName() {
|
|
|
+ return "ingest";
|
|
|
+ }
|
|
|
+ },
|
|
|
+ SEARCH_AGGS() {
|
|
|
+ @Override
|
|
|
+ public boolean inferenceConfigSupported(InferenceConfig config) {
|
|
|
+ return config == null || config.supportsPipelineAggregation();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String exceptionName() {
|
|
|
+ return "search(aggregation)";
|
|
|
+ }
|
|
|
+ },
|
|
|
+ SEARCH_RESCORER() {
|
|
|
+ @Override
|
|
|
+ public boolean inferenceConfigSupported(InferenceConfig config) {
|
|
|
+ // Null configs imply creation via target type. This is for backwards compatibility (BWC) with very old models.
|
|
|
+ // Consequently, if the config is null, we don't support learning to rank (LTR) with them.
|
|
|
+ return config != null && config.supportsSearchRescorer();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String exceptionName() {
|
|
|
+ return "search(rescorer)";
|
|
|
+ }
|
|
|
+ },
|
|
|
+ INTERNAL() {
|
|
|
+ @Override
|
|
|
+ public boolean inferenceConfigSupported(InferenceConfig config) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String exceptionName() {
|
|
|
+ return "internal";
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @param config The inference config for the model. It may be null for very old regression or classification models.
|
|
|
+ * @return Is this configuration type supported within this cache context?
|
|
|
+ */
|
|
|
+ public abstract boolean inferenceConfigSupported(@Nullable InferenceConfig config);
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @return The cache context name to use if an exception must be thrown due to the config not being supported
|
|
|
+ */
|
|
|
+ public abstract String exceptionName();
|
|
|
+
|
|
|
+ public boolean isAnyOf(Consumer... consumers) {
|
|
|
+ return Arrays.stream(consumers).anyMatch(c -> this == c);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
private static class ModelAndConsumer {
|
|
@@ -219,13 +283,23 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * Load the model for use by at search. Models requested by search are always cached.
|
|
|
+ * Load the model for use at search time through aggregations. Models requested by search are always cached.
|
|
|
+ *
|
|
|
+ * @param modelId the model to get
|
|
|
+ * @param modelActionListener the listener to alert when the model has been retrieved
|
|
|
+ */
|
|
|
+ public void getModelForAggregation(String modelId, ActionListener<LocalModel> modelActionListener) {
|
|
|
+ getModel(modelId, Consumer.SEARCH_AGGS, null, modelActionListener);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Load the model for use at search time for rescoring. Models requested by search are always cached.
|
|
|
*
|
|
|
* @param modelId the model to get
|
|
|
* @param modelActionListener the listener to alert when the model has been retrieved
|
|
|
*/
|
|
|
- public void getModelForSearch(String modelId, ActionListener<LocalModel> modelActionListener) {
|
|
|
- getModel(modelId, Consumer.SEARCH, null, modelActionListener);
|
|
|
+ public void getModelForLearnToRank(String modelId, ActionListener<LocalModel> modelActionListener) {
|
|
|
+ getModel(modelId, Consumer.SEARCH_RESCORER, null, modelActionListener);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -259,6 +333,18 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
|
|
|
ModelAndConsumer cachedModel = localModelCache.get(modelId);
|
|
|
if (cachedModel != null) {
|
|
|
+ // Even if the model is already cached, we don't want to use the model in an unsupported task
|
|
|
+ if (consumer.inferenceConfigSupported(cachedModel.model.getInferenceConfig()) == false) {
|
|
|
+ modelActionListener.onFailure(
|
|
|
+ modelUnsupportedInUsageContext(
|
|
|
+ modelId,
|
|
|
+ cachedModel.model.getTrainedModelType(),
|
|
|
+ cachedModel.model.getInferenceConfig(),
|
|
|
+ consumer
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
cachedModel.consumers.add(consumer);
|
|
|
try {
|
|
|
cachedModel.model.acquire();
|
|
@@ -314,7 +400,6 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
localModelToNotifyListener.set(cachedModel.model);
|
|
|
return true;
|
|
|
}
|
|
|
-
|
|
|
// Add the listener to the queue if the model is loading
|
|
|
Queue<ActionListener<LocalModel>> listeners = loadingListeners.computeIfPresent(
|
|
|
modelId,
|
|
@@ -330,7 +415,8 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
|
|
|
// The model is not currently being loaded (indicated by listeners check above).
|
|
|
// So start a new load outside of the synchronized block.
|
|
|
- if (Consumer.SEARCH != consumer && referencedModels.contains(modelId) == false) {
|
|
|
+ if (consumer.isAnyOf(Consumer.SEARCH_AGGS, Consumer.SEARCH_RESCORER) == false
|
|
|
+ && referencedModels.contains(modelId) == false) {
|
|
|
// The model is requested by a pipeline but not referenced by any ingest pipelines.
|
|
|
// This means it is a simulate call and the model should not be cached
|
|
|
logger.trace(
|
|
@@ -368,19 +454,19 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
// We don't want to cancel the loading if only ONE of them stops listening or closes connection
|
|
|
// TODO Is there a way to only signal a cancel if all the listener tasks cancel???
|
|
|
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null, ActionListener.wrap(trainedModelConfig -> {
|
|
|
- if (trainedModelConfig.isAllocateOnly()) {
|
|
|
- if (consumer == Consumer.SEARCH) {
|
|
|
- handleLoadFailure(
|
|
|
+ if (consumer.inferenceConfigSupported(trainedModelConfig.getInferenceConfig()) == false) {
|
|
|
+ handleLoadFailure(
|
|
|
+ modelId,
|
|
|
+ modelUnsupportedInUsageContext(
|
|
|
modelId,
|
|
|
- new ElasticsearchStatusException(
|
|
|
- "Trained model [{}] with type [{}] is currently not usable in search.",
|
|
|
- RestStatus.BAD_REQUEST,
|
|
|
- modelId,
|
|
|
- trainedModelConfig.getModelType()
|
|
|
- )
|
|
|
- );
|
|
|
- return;
|
|
|
- }
|
|
|
+ trainedModelConfig.getModelType(),
|
|
|
+ trainedModelConfig.getInferenceConfig(),
|
|
|
+ consumer
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (trainedModelConfig.isAllocateOnly()) {
|
|
|
handleLoadFailure(modelId, modelMustBeDeployedError(modelId));
|
|
|
return;
|
|
|
}
|
|
@@ -419,19 +505,21 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
|
|
|
// by a simulated pipeline
|
|
|
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), parentTaskId, ActionListener.wrap(trainedModelConfig -> {
|
|
|
+ // If the model is used in an unsupported context, fail here
|
|
|
+ if (consumer.inferenceConfigSupported(trainedModelConfig.getInferenceConfig()) == false) {
|
|
|
+ handleLoadFailure(
|
|
|
+ modelId,
|
|
|
+ modelUnsupportedInUsageContext(
|
|
|
+ modelId,
|
|
|
+ trainedModelConfig.getModelType(),
|
|
|
+ trainedModelConfig.getInferenceConfig(),
|
|
|
+ consumer
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
// If the model should be allocated, we should fail here
|
|
|
if (trainedModelConfig.isAllocateOnly()) {
|
|
|
- if (consumer == Consumer.SEARCH) {
|
|
|
- modelActionListener.onFailure(
|
|
|
- new ElasticsearchStatusException(
|
|
|
- "model [{}] with type [{}] is currently not usable in search.",
|
|
|
- RestStatus.BAD_REQUEST,
|
|
|
- modelId,
|
|
|
- trainedModelConfig.getModelType()
|
|
|
- )
|
|
|
- );
|
|
|
- return;
|
|
|
- }
|
|
|
modelActionListener.onFailure(modelMustBeDeployedError(modelId));
|
|
|
return;
|
|
|
}
|
|
@@ -457,6 +545,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
trainedModelConfig.getDefaultFieldMap(),
|
|
|
inferenceConfig,
|
|
|
trainedModelConfig.getLicenseLevel(),
|
|
|
+ trainedModelConfig.getModelType(),
|
|
|
modelStatsService,
|
|
|
trainedModelCircuitBreaker
|
|
|
)
|
|
@@ -500,7 +589,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private ElasticsearchStatusException modelMustBeDeployedError(String modelId) {
|
|
|
+ private static ElasticsearchStatusException modelMustBeDeployedError(String modelId) {
|
|
|
return new ElasticsearchStatusException(
|
|
|
"Model [{}] must be deployed to use. Please deploy with the start trained model deployment API.",
|
|
|
RestStatus.BAD_REQUEST,
|
|
@@ -508,6 +597,22 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ private static ElasticsearchStatusException modelUnsupportedInUsageContext(
|
|
|
+ String modelId,
|
|
|
+ TrainedModelType modelType,
|
|
|
+ InferenceConfig inferenceConfig,
|
|
|
+ Consumer consumer
|
|
|
+ ) {
|
|
|
+ return new ElasticsearchStatusException(
|
|
|
+ "Trained model [{}] with type [{}] and task [{}] is currently not usable in [{}].",
|
|
|
+ RestStatus.BAD_REQUEST,
|
|
|
+ modelId,
|
|
|
+ modelType,
|
|
|
+ Optional.ofNullable(inferenceConfig).map(InferenceConfig::getName).orElse("_unknown_"),
|
|
|
+ consumer.exceptionName()
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
private void handleLoadSuccess(
|
|
|
String modelId,
|
|
|
Consumer consumer,
|
|
@@ -526,6 +631,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
trainedModelConfig.getDefaultFieldMap(),
|
|
|
inferenceConfig,
|
|
|
trainedModelConfig.getLicenseLevel(),
|
|
|
+ Optional.ofNullable(trainedModelConfig.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE),
|
|
|
modelStatsService,
|
|
|
trainedModelCircuitBreaker
|
|
|
);
|
|
@@ -536,7 +642,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
// Also, if the consumer is a search consumer, we should always cache it
|
|
|
if (referencedModels.contains(modelId)
|
|
|
|| Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels)
|
|
|
- || consumer.equals(Consumer.SEARCH)) {
|
|
|
+ || consumer.equals(Consumer.SEARCH_AGGS)) {
|
|
|
try {
|
|
|
// The local model may already be in cache. If it is, we don't bother adding it to cache.
|
|
|
// If it isn't, we flip an `isLoaded` flag, and increment the model counter to make sure if it is evicted
|
|
@@ -699,7 +805,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|
|
);
|
|
|
if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) {
|
|
|
ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
|
|
|
- if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
|
|
|
+ if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH_AGGS) == false) {
|
|
|
logger.trace("[{} ({})] invalidated from cache", modelId, modelAliasOrId);
|
|
|
localModelCache.invalidate(modelId);
|
|
|
}
|