@@ -19,6 +19,7 @@ import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
+import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.cache.RemovalNotification;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
@@ -37,12 +38,14 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

import java.util.ArrayDeque;
+import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
@@ -50,6 +53,7 @@ import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

/**
@@ -108,11 +112,14 @@ public class ModelLoadingService implements ClusterStateListener {
}
}

-
private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
private final TrainedModelStatsService modelStatsService;
private final Cache<String, ModelAndConsumer> localModelCache;
+ // Referenced models can be model aliases or IDs
private final Set<String> referencedModels = new HashSet<>();
+ private final Map<String, String> modelAliasToId = new HashMap<>();
+ private final Map<String, Set<String>> modelIdToModelAliases = new HashMap<>();
+ private final Map<String, Set<String>> modelIdToUpdatedModelAliases = new HashMap<>();
private final Map<String, Queue<ActionListener<LocalModel>>> loadingListeners = new HashMap<>();
private final TrainedModelProvider provider;
private final Set<String> shouldNotAudit;
@@ -148,8 +155,13 @@ public class ModelLoadingService implements ClusterStateListener {
this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
}

+ // for testing
+ String getModelId(String modelIdOrAlias) {
+ return modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
+ }
+
boolean isModelCached(String modelId) {
- return localModelCache.get(modelId) != null;
+ return localModelCache.get(modelAliasToId.getOrDefault(modelId, modelId)) != null;
}

/**
@@ -195,11 +207,12 @@ public class ModelLoadingService implements ClusterStateListener {
* The main difference being that models for search are always cached whereas pipeline models
* are only cached if they are referenced by an ingest pipeline
*
- * @param modelId the model to get
+ * @param modelIdOrAlias the model id or model alias to get
* @param consumer which feature is requesting the model
* @param modelActionListener the listener to alert when the model has been retrieved.
*/
- private void getModel(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
+ private void getModel(String modelIdOrAlias, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
+ final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
cachedModel.consumers.add(consumer);
@@ -210,12 +223,16 @@ public class ModelLoadingService implements ClusterStateListener {
return;
}
modelActionListener.onResponse(cachedModel.model);
- logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
+ logger.trace(() -> new ParameterizedMessage("[{}] (model_alias [{}]) loaded from cache", modelId, modelIdOrAlias));
return;
}

- if (loadModelIfNecessary(modelId, consumer, modelActionListener)) {
- logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
+ if (loadModelIfNecessary(modelIdOrAlias, consumer, modelActionListener)) {
+ logger.trace(() -> new ParameterizedMessage(
+ "[{}] (model_alias [{}]) is loading or loaded, added new listener to queue",
+ modelId,
+ modelIdOrAlias
+ ));
}
}
@@ -224,14 +241,15 @@ public class ModelLoadingService implements ClusterStateListener {
* else if the model is CURRENTLY being loaded the listener is added to be notified when it is loaded
* else the model load is initiated.
*
- * @param modelId The model to get
+ * @param modelIdOrAlias The model to get
* @param consumer The model consumer
* @param modelActionListener The listener
* @return If the model is cached or currently being loaded true is returned. If a new load is started
* false is returned to indicate a new load event
*/
- private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
+ private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
synchronized (loadingListeners) {
+ final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
cachedModel.consumers.add(consumer);
@@ -257,13 +275,21 @@ public class ModelLoadingService implements ClusterStateListener {
if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
// The model is requested by a pipeline but not referenced by any ingest pipelines.
// This means it is a simulate call and the model should not be cached
+ logger.trace(() -> new ParameterizedMessage(
+ "[{}] (model_alias [{}]) not actively loading, eager loading without cache",
+ modelId,
+ modelIdOrAlias
+ ));
loadWithoutCaching(modelId, modelActionListener);
} else {
- logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
+ logger.trace(() -> new ParameterizedMessage(
+ "[{}] (model_alias [{}]) attempting to load and cache",
+ modelId,
+ modelIdOrAlias
+ ));
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId, consumer);
}
-
return false;
} // synchronized (loadingListeners)
}
@@ -304,7 +330,6 @@ public class ModelLoadingService implements ClusterStateListener {
private void loadWithoutCaching(String modelId, ActionListener<LocalModel> modelActionListener) {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
- logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
@@ -377,34 +402,41 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getLicenseLevel(),
modelStatsService,
trainedModelCircuitBreaker);
- boolean modelAcquired = false;
+ final ModelAndConsumerLoader modelAndConsumerLoader = new ModelAndConsumerLoader(new ModelAndConsumer(loadedModel, consumer));
synchronized (loadingListeners) {
- listeners = loadingListeners.remove(modelId);
- // if there are no listeners, simply release and leave
- if (listeners == null) {
- loadedModel.release();
- return;
- }
-
+ populateNewModelAlias(modelId);
// If the model is referenced, that means it is currently in a pipeline somewhere
// Also, if the consumer is a search consumer, we should always cache it
- if (referencedModels.contains(modelId) || consumer.equals(Consumer.SEARCH)) {
- // temporarily increase the reference count before adding to
- // the cache in case the model is evicted before the listeners
- // are called in which case acquire() would throw.
- loadedModel.acquire();
- localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
+ if (referencedModels.contains(modelId)
+ || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels)
+ || consumer.equals(Consumer.SEARCH)) {
+ try {
+ // The local model may already be in cache. If it is, we don't bother adding it to cache.
+ // If it isn't, we flip an `isLoaded` flag, and increment the model counter to make sure if it is evicted
+ // between now and when the listeners access it, the circuit breaker reflects actual usage.
+ localModelCache.computeIfAbsent(modelId, modelAndConsumerLoader);
+ } catch (ExecutionException ee) {
+ logger.warn(() -> new ParameterizedMessage("[{}] threw when attempting add to cache", modelId), ee);
+ }
shouldNotAudit.remove(modelId);
- modelAcquired = true;
+ }
+ listeners = loadingListeners.remove(modelId);
+ // if there are no listeners, we should just exit
+ if (listeners == null) {
+ // If we newly added it into cache, release the model so that the circuit breaker can still accurately keep track
+ // of memory
+ if (modelAndConsumerLoader.isLoaded()) {
+ loadedModel.release();
+ }
+ return;
}
} // synchronized (loadingListeners)
for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
loadedModel.acquire();
listener.onResponse(loadedModel);
}
- // account for the acquire in the synchronized block above
- // We cannot simply utilize the same conditionals as `referencedModels` could have changed once we exited the synchronized block
- if (modelAcquired) {
+ // account for the acquire in the synchronized block above if the model was loaded into the cache
+ if (modelAndConsumerLoader.isLoaded()) {
loadedModel.release();
}
}
@@ -413,6 +445,7 @@ public class ModelLoadingService implements ClusterStateListener {
Queue<ActionListener<LocalModel>> listeners;
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
+ populateNewModelAlias(modelId);
if (listeners == null) {
return;
}
@@ -424,6 +457,20 @@ public class ModelLoadingService implements ClusterStateListener {
}
}

+ private void populateNewModelAlias(String modelId) {
+ Set<String> newModelAliases = modelIdToUpdatedModelAliases.remove(modelId);
+ if (newModelAliases != null && newModelAliases.isEmpty() == false) {
+ logger.trace(() -> new ParameterizedMessage(
+ "[{}] model is now loaded, setting new model_aliases {}",
+ modelId,
+ newModelAliases
+ ));
+ for (String modelAlias : newModelAliases) {
+ modelAliasToId.put(modelAlias, modelId);
+ }
+ }
+ }
+
private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
try {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
@@ -438,12 +485,15 @@ public class ModelLoadingService implements ClusterStateListener {
INFERENCE_MODEL_CACHE_TTL.getKey());
auditIfNecessary(notification.getKey(), msg);
}
-
- logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
- notification.getValue().model.getModelId()));
+ String modelId = modelAliasToId.getOrDefault(notification.getKey(), notification.getKey());
+ logger.trace(() -> new ParameterizedMessage(
+ "Persisting stats for evicted model [{}] (model_aliases {})",
+ modelId,
+ modelIdToModelAliases.getOrDefault(modelId, new HashSet<>())
+ ));

// If the model is no longer referenced, flush the stats to persist as soon as possible
- notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
+ notification.getValue().model.persistStats(referencedModels.contains(modelId) == false);
} finally {
notification.getValue().model.release();
}
@@ -451,46 +501,112 @@ public class ModelLoadingService implements ClusterStateListener {

@Override
public void clusterChanged(ClusterChangedEvent event) {
- // If ingest data has not changed or if the current node is not an ingest node, don't bother caching models
- if (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false ||
- event.state().nodes().getLocalNode().isIngestNode() == false) {
+ final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode();
+ // If we are not prefetching models and there were no model alias changes, don't bother handling the changes
+ if ((prefetchModels == false)
+ && (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false)
+ && (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME) == false)) {
return;
}

ClusterState state = event.state();
IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE);
- Set<String> allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata);
- if (allReferencedModelKeys.equals(referencedModels)) {
- return;
- }
- Set<String> referencedModelsBeforeClusterState = null;
+ Set<String> allReferencedModelKeys = event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) ?
+ getReferencedModelKeys(currentIngestMetadata) :
+ new HashSet<>(referencedModels);
+ Set<String> referencedModelsBeforeClusterState;
Set<String> loadingModelBeforeClusterState = null;
- Set<String> removedModels = null;
+ Set<String> removedModels;
+ Map<String, Set<String>> addedModelViaAliases = new HashMap<>();
+ Map<String, Set<String>> oldIdToAliases;
synchronized (loadingListeners) {
+ oldIdToAliases = new HashMap<>(modelIdToModelAliases);
+ Map<String, String> changedAliases = gatherLazyChangedAliasesAndUpdateModelAliases(
+ event,
+ prefetchModels,
+ allReferencedModelKeys
+ );
+
+ // if we are not prefetching, exit now.
+ if (prefetchModels == false) {
+ return;
+ }
+
referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
if (logger.isTraceEnabled()) {
loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
}
removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);

- // Remove all cached models that are not referenced by any processors
- // and are not used in search
- removedModels.forEach(modelId -> {
- ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
- if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
- localModelCache.invalidate(modelId);
- }
- });
// Remove the models that are no longer referenced
referencedModels.removeAll(removedModels);
shouldNotAudit.removeAll(removedModels);

+ // Remove all cached models that are not referenced by any processors
+ // and are not used in search
+ for (String modelAliasOrId : removedModels) {
+ String modelId = changedAliases.getOrDefault(modelAliasOrId, modelAliasToId.getOrDefault(modelAliasOrId, modelAliasOrId));
+ // If the "old" model_alias is referenced, we don't want to invalidate. This way the model that now has the model_alias
+ // can be loaded in first
+ boolean oldModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels,
+ oldIdToAliases.getOrDefault(modelId, Collections.emptySet()));
+ // If the model itself is referenced, we shouldn't evict.
+ boolean modelIsNotReferenced = referencedModels.contains(modelId) == false;
+ // If a model_alias change causes it to NOW be referenced, we shouldn't attempt to evict it
+ boolean newModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels,
+ modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet()));
+ if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) {
+ ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
+ if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
+ logger.trace("[{} ({})] invalidated from cache", modelId, modelAliasOrId);
+ localModelCache.invalidate(modelId);
+ }
+ }
+ }
// Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels
allReferencedModelKeys.removeAll(referencedModels);
+ for (String newlyReferencedModel : allReferencedModelKeys) {
+ // check if the model_alias has changed in this round
+ String modelId = changedAliases.getOrDefault(
+ newlyReferencedModel,
+ // If the model_alias hasn't changed, get the model id IF it is a model_alias, otherwise we assume it is an id
+ modelAliasToId.getOrDefault(
+ newlyReferencedModel,
+ newlyReferencedModel
+ )
+ );
+ // Verify that it isn't an old model id but just a new model_alias
+ if (referencedModels.contains(modelId) == false) {
+ addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(newlyReferencedModel);
+ }
+ }
+ // For any previously referenced model, the model_alias COULD have changed, so it is actually a NEWLY referenced model
+ for (Map.Entry<String, String> modelAliasAndId : changedAliases.entrySet()) {
+ String modelAlias = modelAliasAndId.getKey();
+ String modelId = modelAliasAndId.getValue();
+ if (referencedModels.contains(modelAlias)) {
+ // we need to load the underlying model since its model_alias is referenced
+ addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias);
+ // If we are in cache, keep the old translation for now, it will be updated later
+ String oldModelId = modelAliasToId.get(modelAlias);
+ if (oldModelId != null && localModelCache.get(oldModelId) != null) {
+ modelIdToUpdatedModelAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias);
+ } else {
+ // If we are not cached, might as well add the translation right away as new callers will have to load
+ // from disk anyways.
+ modelAliasToId.put(modelAlias, modelId);
+ }
+ } else {
+ // Add model_alias and id here, since the model_alias wasn't previously referenced,
+ // no reason to wait on updating the model_alias -> model_id mapping
+ modelAliasToId.put(modelAlias, modelId);
+ }
+ }
+ // Gather ALL currently referenced model ids
referencedModels.addAll(allReferencedModelKeys);

// Populate loadingListeners key so we know that we are currently loading the model
- for (String modelId : allReferencedModelKeys) {
+ for (String modelId : addedModelViaAliases.keySet()) {
loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>());
}
} // synchronized (loadingListeners)
@@ -503,9 +619,51 @@ public class ModelLoadingService implements ClusterStateListener {
logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState,
referencedModels);
}
+ if (oldIdToAliases.equals(modelIdToModelAliases) == false) {
+ logger.trace("model id to alias mappings changed. before {} after {}. Model alias to IDs {}",
+ oldIdToAliases,
+ modelIdToModelAliases,
+ modelAliasToId);
+ }
+ if (addedModelViaAliases.isEmpty() == false) {
+ logger.trace("adding new models via model_aliases and ids: {}", addedModelViaAliases);
+ }
+ if (modelIdToUpdatedModelAliases.isEmpty() == false) {
+ logger.trace("delayed model aliases to update {}", modelIdToUpdatedModelAliases);
+ }
}
removedModels.forEach(this::auditUnreferencedModel);
- loadModelsForPipeline(allReferencedModelKeys);
+ loadModelsForPipeline(addedModelViaAliases.keySet());
+ }
+
+ private Map<String, String> gatherLazyChangedAliasesAndUpdateModelAliases(ClusterChangedEvent event,
+ boolean prefetchModels,
+ Set<String> allReferencedModelKeys) {
+ Map<String, String> changedAliases = new HashMap<>();
+ if (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME)) {
+ final Map<String, ModelAliasMetadata.ModelAliasEntry> modelAliasesToIds = new HashMap<>(
+ ModelAliasMetadata.fromState(event.state()).modelAliases()
+ );
+ modelIdToModelAliases.clear();
+ for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> aliasToId : modelAliasesToIds.entrySet()) {
+ modelIdToModelAliases.computeIfAbsent(aliasToId.getValue().getModelId(), k -> new HashSet<>()).add(aliasToId.getKey());
+ String modelId = modelAliasToId.get(aliasToId.getKey());
+ if (modelId != null
+ && modelId.equals(aliasToId.getValue().getModelId()) == false) {
+ if (prefetchModels && allReferencedModelKeys.contains(aliasToId.getKey())) {
+ changedAliases.put(aliasToId.getKey(), aliasToId.getValue().getModelId());
+ } else {
+ modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId());
+ }
+ }
+ if (modelId == null) {
+ modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId());
+ }
+ }
+ Set<String> removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet());
+ modelAliasToId.keySet().removeAll(removedAliases);
+ }
+ return changedAliases;
}

private void auditIfNecessary(String modelId, MessageSupplier msg) {
@@ -600,4 +758,25 @@ public class ModelLoadingService implements ClusterStateListener {
});
}
}
+
+ private static class ModelAndConsumerLoader implements CacheLoader<String, ModelAndConsumer> {
+
+ private boolean loaded;
+ private final ModelAndConsumer modelAndConsumer;
+
+ ModelAndConsumerLoader(ModelAndConsumer modelAndConsumer) {
+ this.modelAndConsumer = modelAndConsumer;
+ }
+
+ boolean isLoaded() {
+ return loaded;
+ }
+
+ @Override
+ public ModelAndConsumer load(String key) throws Exception {
+ loaded = true;
+ modelAndConsumer.model.acquire();
+ return modelAndConsumer;
+ }
+ }
}