Browse Source

[8.17][ML] Enable built-in Inference Endpoints and default for Semantic Text (#116952)

David Kyle 11 months ago
parent
commit
db949d60bb
13 changed files with 39 additions and 98 deletions
  1. 5 0
      docs/changelog/116931.yaml
  2. 1 2
      test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
  3. 0 2
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java
  4. 3 3
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
  5. 0 21
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java
  6. 7 10
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
  7. 2 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  8. 4 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
  9. 1 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java
  10. 0 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java
  11. 14 30
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
  12. 1 1
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml
  13. 1 1
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml

+ 5 - 0
docs/changelog/116931.yaml

@@ -0,0 +1,5 @@
+pr: 116931
+summary: Enable built-in Inference Endpoints and default for Semantic Text
+area: "Machine Learning"
+type: enhancement
+issues: []

+ 1 - 2
test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

@@ -18,8 +18,7 @@ import org.elasticsearch.test.cluster.util.Version;
 public enum FeatureFlag {
     TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
     FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
-    SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
-    INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
+    SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
 
     public final String systemProperty;
     public final Version from;

+ 0 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

@@ -39,7 +39,6 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
 
     @SuppressWarnings("unchecked")
     public void testInferDeploysDefaultElser() throws IOException {
-        assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
         var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
         assertDefaultElserConfig(model);
 
@@ -70,7 +69,6 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
 
     @SuppressWarnings("unchecked")
     public void testInferDeploysDefaultE5() throws IOException {
-        assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
         var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
         assertDefaultE5Config(model);
 

+ 3 - 3
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -44,18 +44,18 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         }
 
         var getAllModels = getAllModels();
-        int numModels = DefaultElserFeatureFlag.isEnabled() ? 11 : 9;
+        int numModels = 11;
         assertThat(getAllModels, hasSize(numModels));
 
         var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
-        int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 6 : 5;
+        int numSparseModels = 6;
         assertThat(getSparseModels, hasSize(numSparseModels));
         for (var sparseModel : getSparseModels) {
             assertEquals("sparse_embedding", sparseModel.get("task_type"));
         }
 
         var getDenseModels = getModels("_all", TaskType.TEXT_EMBEDDING);
-        int numDenseModels = DefaultElserFeatureFlag.isEnabled() ? 5 : 4;
+        int numDenseModels = 5;
         assertThat(getDenseModels, hasSize(numDenseModels));
         for (var denseModel : getDenseModels) {
             assertEquals("text_embedding", denseModel.get("task_type"));

+ 0 - 21
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java

@@ -1,21 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference;
-
-import org.elasticsearch.common.util.FeatureFlag;
-
-public class DefaultElserFeatureFlag {
-
-    private DefaultElserFeatureFlag() {}
-
-    private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_default_elser");
-
-    public static boolean isEnabled() {
-        return FEATURE_FLAG.isEnabled();
-    }
-}

+ 7 - 10
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

@@ -13,7 +13,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
 
-import java.util.HashSet;
 import java.util.Set;
 
 /**
@@ -23,15 +22,13 @@ public class InferenceFeatures implements FeatureSpecification {
 
     @Override
     public Set<NodeFeature> getFeatures() {
-        var features = new HashSet<NodeFeature>();
-        features.add(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED);
-        features.add(RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED);
-        features.add(SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID);
-        features.add(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED);
-        if (DefaultElserFeatureFlag.isEnabled()) {
-            features.add(SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2);
-        }
-        return Set.copyOf(features);
+        return Set.of(
+            TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
+            RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
+            SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
+            TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED,
+            SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2
+        );
     }
 
     @Override

+ 2 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -228,10 +228,8 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
         // reference correctly
         var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
         registry.init(services.client());
-        if (DefaultElserFeatureFlag.isEnabled()) {
-            for (var service : registry.getServices().values()) {
-                service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
-            }
+        for (var service : registry.getServices().values()) {
+            service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
         }
         inferenceServiceRegistry.set(registry);
 

+ 4 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

@@ -57,7 +57,6 @@ import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
-import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -111,16 +110,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             INFERENCE_ID_FIELD,
             false,
             mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
-            DefaultElserFeatureFlag.isEnabled() ? DEFAULT_ELSER_2_INFERENCE_ID : null
+            DEFAULT_ELSER_2_INFERENCE_ID
         ).addValidator(v -> {
             if (Strings.isEmpty(v)) {
-                // If the default ELSER feature flag is enabled, the only way we get here is if the user explicitly sets the param to an
-                // empty value. However, if the feature flag is disabled, we can get here if the user didn't set the param.
-                // Adjust the error message appropriately.
-                String message = DefaultElserFeatureFlag.isEnabled()
-                    ? "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty"
-                    : "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must be specified";
-                throw new IllegalArgumentException(message);
+                throw new IllegalArgumentException(
+                    "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty"
+                );
             }
         });
 

+ 1 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java

@@ -15,10 +15,7 @@ import org.elasticsearch.rest.Scope;
 import org.elasticsearch.rest.ServerlessScope;
 import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
-import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
 
-import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
@@ -69,11 +66,6 @@ public class RestGetInferenceModelAction extends BaseRestHandler {
 
     @Override
     public Set<String> supportedCapabilities() {
-        Set<String> capabilities = new HashSet<>();
-        if (DefaultElserFeatureFlag.isEnabled()) {
-            capabilities.add(DEFAULT_ELSER_2_CAPABILITY);
-        }
-
-        return Collections.unmodifiableSet(capabilities);
+        return Set.of(DEFAULT_ELSER_2_CAPABILITY);
     }
 }

+ 0 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

@@ -35,7 +35,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
-import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 
 import java.io.IOException;
@@ -296,11 +295,6 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
         InferModelAction.Request request,
         ActionListener<InferModelAction.Response> listener
     ) {
-        if (DefaultElserFeatureFlag.isEnabled() == false) {
-            listener.onFailure(e);
-            return;
-        }
-
         if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
             this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
                 client.execute(InferModelAction.INSTANCE, request, listener);

+ 14 - 30
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

@@ -61,7 +61,6 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
-import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.model.TestModel;
 import org.junit.AssumptionViolatedException;
@@ -103,9 +102,6 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
     @Override
     protected void minimalMapping(XContentBuilder b) throws IOException {
         b.field("type", "semantic_text");
-        if (DefaultElserFeatureFlag.isEnabled() == false) {
-            b.field("inference_id", "test_model");
-        }
     }
 
     @Override
@@ -175,9 +171,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         DocumentMapper mapper = mapperService.documentMapper();
         assertEquals(Strings.toString(fieldMapping), mapper.mappingSource().toString());
         assertSemanticTextField(mapperService, fieldName, false);
-        if (DefaultElserFeatureFlag.isEnabled()) {
-            assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID);
-        }
+        assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID);
 
         ParsedDocument doc1 = mapper.parse(source(this::writeField));
         List<IndexableField> fields = doc1.rootDoc().getFields("field");
@@ -211,15 +205,13 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
             assertSerialization.accept(fieldMapping, mapperService);
         }
         {
-            if (DefaultElserFeatureFlag.isEnabled()) {
-                final XContentBuilder fieldMapping = fieldMapping(
-                    b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
-                );
-                final MapperService mapperService = createMapperService(fieldMapping);
-                assertSemanticTextField(mapperService, fieldName, false);
-                assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId);
-                assertSerialization.accept(fieldMapping, mapperService);
-            }
+            final XContentBuilder fieldMapping = fieldMapping(
+                b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
+            );
+            final MapperService mapperService = createMapperService(fieldMapping);
+            assertSemanticTextField(mapperService, fieldName, false);
+            assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId);
+            assertSerialization.accept(fieldMapping, mapperService);
         }
         {
             final XContentBuilder fieldMapping = fieldMapping(
@@ -246,26 +238,18 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
             );
         }
         {
-            final String expectedMessage = DefaultElserFeatureFlag.isEnabled()
-                ? "[inference_id] on mapper [field] of type [semantic_text] must not be empty"
-                : "[inference_id] on mapper [field] of type [semantic_text] must be specified";
             Exception e = expectThrows(
                 MapperParsingException.class,
                 () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, "")))
             );
-            assertThat(e.getMessage(), containsString(expectedMessage));
+            assertThat(e.getMessage(), containsString("[inference_id] on mapper [field] of type [semantic_text] must not be empty"));
         }
         {
-            if (DefaultElserFeatureFlag.isEnabled()) {
-                Exception e = expectThrows(
-                    MapperParsingException.class,
-                    () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "")))
-                );
-                assertThat(
-                    e.getMessage(),
-                    containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty")
-                );
-            }
+            Exception e = expectThrows(
+                MapperParsingException.class,
+                () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "")))
+            );
+            assertThat(e.getMessage(), containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty"));
         }
     }
 

+ 1 - 1
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml

@@ -551,7 +551,7 @@ setup:
 ---
 "Calculates embeddings using the default ELSER 2 endpoint":
   - requires:
-      reason: "default ELSER 2 inference ID is behind a feature flag"
+      reason: "default ELSER 2 inference ID is enabled via a capability"
       test_runner_features: [capabilities]
       capabilities:
         - method: GET

+ 1 - 1
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml

@@ -843,7 +843,7 @@ setup:
 ---
 "Query a field that uses the default ELSER 2 endpoint":
   - requires:
-      reason: "default ELSER 2 inference ID is behind a feature flag"
+      reason: "default ELSER 2 inference ID is enabled via a capability"
       test_runner_features: [capabilities]
       capabilities:
         - method: GET