Browse Source

Expose model registry to SemanticTextFieldMapper (#126635) (#126817)

This change integrates the new model registry with the `SemanticTextFieldMapper`, allowing inference IDs to be eagerly resolved at parse time.
It also preserves the existing lenient behavior: no error is thrown if the specified inference id does not exist, only a warning is logged.
Jim Ferenczi 6 months ago
parent
commit
7bedd0d585

+ 1 - 6
server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

@@ -242,10 +242,6 @@ public record MinimalServiceSettings(
         }
     }
 
-    public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
-        return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
-    }
-
     /**
      * Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
      */
@@ -253,7 +249,6 @@ public record MinimalServiceSettings(
         return taskType == other.taskType
             && Objects.equals(dimensions, other.dimensions)
             && similarity == other.similarity
-            && elementType == other.elementType
-            && (service == null || service.equals(other.service));
+            && elementType == other.elementType;
     }
 }

+ 18 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -199,6 +199,7 @@ public class InferencePlugin extends Plugin
     private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
     private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
     private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
+    private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
     private List<InferenceServiceExtension> inferenceServiceExtensions;
 
     public InferencePlugin(Settings settings) {
@@ -262,8 +263,8 @@ public class InferencePlugin extends Plugin
         var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
         amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
 
-        ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
-        services.clusterService().addListener(modelRegistry);
+        modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
+        services.clusterService().addListener(modelRegistry.get());
 
         if (inferenceServiceExtensions == null) {
             inferenceServiceExtensions = new ArrayList<>();
@@ -301,7 +302,7 @@ public class InferencePlugin extends Plugin
                     elasicInferenceServiceFactory.get(),
                     serviceComponents.get(),
                     inferenceServiceSettings,
-                    modelRegistry,
+                    modelRegistry.get(),
                     authorizationHandler
                 )
             )
@@ -319,18 +320,23 @@ public class InferencePlugin extends Plugin
         var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
         serviceRegistry.init(services.client());
         for (var service : serviceRegistry.getServices().values()) {
-            service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
+            service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
         }
         inferenceServiceRegistry.set(serviceRegistry);
 
-        var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
+        var actionFilter = new ShardBulkInferenceActionFilter(
+            services.clusterService(),
+            serviceRegistry,
+            modelRegistry.get(),
+            getLicenseState()
+        );
         shardBulkInferenceActionFilter.set(actionFilter);
 
         var meterRegistry = services.telemetryProvider().getMeterRegistry();
         var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
 
         components.add(serviceRegistry);
-        components.add(modelRegistry);
+        components.add(modelRegistry.get());
         components.add(httpClientManager);
         components.add(inferenceStats);
 
@@ -497,11 +503,16 @@ public class InferencePlugin extends Plugin
         return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
     }
 
+    // Overridable for tests
+    protected Supplier<ModelRegistry> getModelRegistry() {
+        return modelRegistry::get;
+    }
+
     @Override
     public Map<String, Mapper.TypeParser> getMappers() {
         return Map.of(
             SemanticTextFieldMapper.CONTENT_TYPE,
-            SemanticTextFieldMapper.PARSER,
+            SemanticTextFieldMapper.parser(getModelRegistry()),
             OffsetSourceFieldMapper.CONTENT_TYPE,
             OffsetSourceFieldMapper.PARSER
         );

+ 56 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

@@ -7,6 +7,8 @@
 
 package org.elasticsearch.xpack.inference.mapper;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -18,6 +20,7 @@ import org.apache.lucene.search.Weight;
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.join.ScoreMode;
 import org.apache.lucene.util.BitSet;
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -75,6 +78,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
 import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -89,6 +93,7 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import java.util.function.Supplier;
 
 import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
 import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
@@ -112,6 +117,7 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.Elasticse
  * A {@link FieldMapper} for semantic text fields.
  */
 public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
+    private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
     public static final NodeFeature SEMANTIC_TEXT_SEARCH_INFERENCE_ID = new NodeFeature("semantic_text.search_inference_id", true);
     public static final NodeFeature SEMANTIC_TEXT_DEFAULT_ELSER_2 = new NodeFeature("semantic_text.default_elser_2", true);
     public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
@@ -129,10 +135,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
     public static final String CONTENT_TYPE = "semantic_text";
     public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
 
-    public static final TypeParser PARSER = new TypeParser(
-        (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
-        List.of(validateParserContext(CONTENT_TYPE))
-    );
+    public static final TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
+        return new TypeParser(
+            (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
+            List.of(validateParserContext(CONTENT_TYPE))
+        );
+    }
 
     public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
         return (n, c) -> {
@@ -144,6 +152,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
     }
 
     public static class Builder extends FieldMapper.Builder {
+        private final ModelRegistry modelRegistry;
         private final boolean useLegacyFormat;
 
         private final Parameter<String> inferenceId = Parameter.stringParam(
@@ -201,14 +210,21 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             Builder builder = new Builder(
                 mapper.leafName(),
                 mapper.fieldType().getChunksField().bitsetProducer(),
-                mapper.fieldType().getChunksField().indexSettings()
+                mapper.fieldType().getChunksField().indexSettings(),
+                mapper.modelRegistry
             );
             builder.init(mapper);
             return builder;
         }
 
-        public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
+        public Builder(
+            String name,
+            Function<Query, BitSetProducer> bitSetProducer,
+            IndexSettings indexSettings,
+            ModelRegistry modelRegistry
+        ) {
             super(name);
+            this.modelRegistry = modelRegistry;
             this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
             this.inferenceFieldBuilder = c -> createInferenceField(
                 c,
@@ -266,9 +282,32 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
                 throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
             }
+
+            if (modelSettings.get() == null) {
+                try {
+                    var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
+                    if (resolvedModelSettings != null) {
+                        modelSettings.setValue(resolvedModelSettings);
+                    }
+                } catch (ResourceNotFoundException exc) {
+                    // We allow the inference ID to be unregistered at this point.
+                    // This will delay the creation of sub-fields, so indexing and querying for this field won't work
+                    // until the corresponding inference endpoint is created.
+                }
+            }
+
             if (modelSettings.get() != null) {
                 validateServiceSettings(modelSettings.get());
+            } else {
+                logger.warn(
+                    "The field [{}] references an unknown inference ID [{}]. "
+                        + "Indexing and querying this field will not work correctly until the corresponding "
+                        + "inference endpoint is created.",
+                    leafName(),
+                    inferenceId.get()
+                );
             }
+
             final String fullName = context.buildFullName(leafName());
 
             if (context.isInNestedContext()) {
@@ -289,7 +328,8 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
                     useLegacyFormat,
                     meta.getValue()
                 ),
-                builderParams(this, context)
+                builderParams(this, context),
+                modelRegistry
             );
         }
 
@@ -330,9 +370,17 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
         }
     }
 
-    private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
+    private final ModelRegistry modelRegistry;
+
+    private SemanticTextFieldMapper(
+        String simpleName,
+        MappedFieldType mappedFieldType,
+        BuilderParams builderParams,
+        ModelRegistry modelRegistry
+    ) {
         super(simpleName, mappedFieldType, builderParams);
         ensureMultiFields(builderParams.multiFields().iterator());
+        this.modelRegistry = modelRegistry;
     }
 
     private void ensureMultiFields(Iterator<FieldMapper> mappers) {

+ 16 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

@@ -139,7 +139,6 @@ public class ModelRegistry implements ClusterStateListener {
     private static final String MODEL_ID_FIELD = "model_id";
     private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
 
-    private final ClusterService clusterService;
     private final OriginSettingClient client;
     private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
 
@@ -147,10 +146,11 @@ public class ModelRegistry implements ClusterStateListener {
     private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
     private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
 
+    private volatile Metadata lastMetadata;
+
     public ModelRegistry(ClusterService clusterService, Client client) {
         this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
         this.defaultConfigIds = new ConcurrentHashMap<>();
-        this.clusterService = clusterService;
         var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
             @Override
             public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
@@ -222,11 +222,17 @@ public class ModelRegistry implements ClusterStateListener {
      * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
      */
     public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
+        synchronized (this) {
+            assert lastMetadata != null : "initial cluster state not set yet";
+            if (lastMetadata == null) {
+                throw new IllegalStateException("initial cluster state not set yet");
+            }
+        }
         var config = defaultConfigIds.get(inferenceEntityId);
         if (config != null) {
             return config.settings();
         }
-        var state = ModelRegistryMetadata.fromState(clusterService.state().metadata());
+        var state = ModelRegistryMetadata.fromState(lastMetadata);
         var existing = state.getMinimalServiceSettings(inferenceEntityId);
         if (state.isUpgraded() && existing == null) {
             throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
@@ -931,6 +937,13 @@ public class ModelRegistry implements ClusterStateListener {
 
     @Override
     public void clusterChanged(ClusterChangedEvent event) {
+        if (lastMetadata == null || event.metadataChanged()) {
+            // keep track of the last applied cluster state
+            synchronized (this) {
+                lastMetadata = event.state().metadata();
+            }
+        }
+
         if (event.localNodeMaster() == false) {
             return;
         }

+ 34 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

@@ -24,6 +24,7 @@ import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.join.QueryBitSetProducer;
 import org.apache.lucene.search.join.ScoreMode;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
+import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.Strings;
@@ -62,6 +63,9 @@ import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.LeafNestedDocuments;
 import org.elasticsearch.search.NestedDocuments;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -69,7 +73,10 @@ import org.elasticsearch.xpack.core.XPackClientPlugin;
 import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.model.TestModel;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.junit.After;
 import org.junit.AssumptionViolatedException;
+import org.junit.Before;
 
 import java.io.IOException;
 import java.util.Collection;
@@ -79,6 +86,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
@@ -100,10 +108,22 @@ import static org.hamcrest.Matchers.instanceOf;
 public class SemanticTextFieldMapperTests extends MapperTestCase {
     private final boolean useLegacyFormat;
 
+    private TestThreadPool threadPool;
+
     public SemanticTextFieldMapperTests(boolean useLegacyFormat) {
         this.useLegacyFormat = useLegacyFormat;
     }
 
+    @Before
+    private void startThreadPool() {
+        threadPool = createThreadPool();
+    }
+
+    @After
+    private void stopThreadPool() {
+        threadPool.close();
+    }
+
     @ParametersFactory
     public static Iterable<Object[]> parameters() throws Exception {
         return List.of(new Object[] { true }, new Object[] { false });
@@ -111,7 +131,20 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
 
     @Override
     protected Collection<? extends Plugin> getPlugins() {
-        return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin());
+        var clusterService = ClusterServiceUtils.createClusterService(threadPool);
+        var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
+        modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
+            @Override
+            public boolean localNodeMaster() {
+                return false;
+            }
+        });
+        return List.of(new InferencePlugin(Settings.EMPTY) {
+            @Override
+            protected Supplier<ModelRegistry> getModelRegistry() {
+                return () -> modelRegistry;
+            }
+        }, new XPackClientPlugin());
     }
 
     private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat) throws IOException {

+ 41 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

@@ -22,12 +22,14 @@ import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.compress.CompressedXContent;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.MapperService;
@@ -46,6 +48,9 @@ import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.AbstractQueryTestCase;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -60,6 +65,8 @@ import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
 
@@ -70,6 +77,7 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Supplier;
 
 import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
 import static org.apache.lucene.search.BooleanClause.Occur.MUST;
@@ -118,6 +126,24 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
         useSearchInferenceId = randomBoolean();
     }
 
+    @BeforeClass
+    public static void startModelRegistry() {
+        threadPool = new TestThreadPool(SemanticQueryBuilderTests.class.getName());
+        var clusterService = ClusterServiceUtils.createClusterService(threadPool);
+        modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
+        modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
+            @Override
+            public boolean localNodeMaster() {
+                return false;
+            }
+        });
+    }
+
+    @AfterClass
+    public static void stopModelRegistry() {
+        IOUtils.closeWhileHandlingException(threadPool);
+    }
+
     @Override
     @Before
     public void setUp() throws Exception {
@@ -127,7 +153,7 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
 
     @Override
     protected Collection<Class<? extends Plugin>> getPlugins() {
-        return List.of(XPackClientPlugin.class, InferencePlugin.class, FakeMlPlugin.class);
+        return List.of(XPackClientPlugin.class, InferencePluginWithModelRegistry.class, FakeMlPlugin.class);
     }
 
     @Override
@@ -394,4 +420,18 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
             return new MlInferenceNamedXContentProvider().getNamedWriteables();
         }
     }
+
+    private static TestThreadPool threadPool;
+    private static ModelRegistry modelRegistry;
+
+    public static class InferencePluginWithModelRegistry extends InferencePlugin {
+        public InferencePluginWithModelRegistry(Settings settings) {
+            super(settings);
+        }
+
+        @Override
+        protected Supplier<ModelRegistry> getModelRegistry() {
+            return () -> modelRegistry;
+        }
+    }
 }

+ 27 - 0
x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java

@@ -9,13 +9,20 @@ package org.elasticsearch.xpack.inference;
 
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
+import org.elasticsearch.client.Request;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
 import org.elasticsearch.test.cluster.local.distribution.DistributionType;
 import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
 import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
+import org.junit.After;
 import org.junit.ClassRule;
 
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
 public class InferenceRestIT extends ESClientYamlSuiteTestCase {
 
     @ClassRule
@@ -50,4 +57,24 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase {
     public static Iterable<Object[]> parameters() throws Exception {
         return ESClientYamlSuiteTestCase.createParameters();
     }
+
+    @After
+    public void cleanup() throws Exception {
+        for (var model : getAllModels()) {
+            var inferenceId = model.get("inference_id");
+            try {
+                var endpoint = Strings.format("_inference/%s?force", inferenceId);
+                adminClient().performRequest(new Request("DELETE", endpoint));
+            } catch (Exception ex) {
+                logger.warn(() -> "failed to delete inference endpoint " + inferenceId, ex);
+            }
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    static List<Map<String, Object>> getAllModels() throws IOException {
+        var request = new Request("GET", "_inference/_all");
+        var response = client().performRequest(request);
+        return (List<Map<String, Object>>) entityAsMap(response).get("endpoints");
+    }
 }