瀏覽代碼

[ML] Make task_type optional (#104483)

Makes the task_type element of the _inference API optional so that 
it is possible to GET, DELETE or POST to an inference entity without
providing the task type
David Kyle 1 年之前
父節點
當前提交
5f325187cb
共有 26 個文件被更改,包括 369 次插入和 82 次刪除
  1. 5 0
      docs/changelog/104483.yaml
  2. 7 6
      docs/reference/inference/delete-inference.asciidoc
  3. 5 3
      docs/reference/inference/get-inference.asciidoc
  4. 5 4
      docs/reference/inference/post-inference.asciidoc
  5. 15 3
      rest-api-spec/src/main/resources/rest-api-spec/api/inference.delete_model.json
  6. 16 4
      rest-api-spec/src/main/resources/rest-api-spec/api/inference.get_model.json
  7. 16 4
      rest-api-spec/src/main/resources/rest-api-spec/api/inference.inference.json
  8. 16 4
      rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_model.json
  9. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceModelAction.java
  10. 3 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
  11. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java
  12. 2 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java
  13. 52 7
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
  14. 29 1
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
  15. 1 1
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java
  16. 13 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java
  17. 39 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java
  18. 20 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java
  19. 16 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteInferenceModelAction.java
  20. 12 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java
  21. 16 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java
  22. 16 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java
  23. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java
  24. 5 5
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java
  25. 49 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelActionTests.java
  26. 4 5
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml

+ 5 - 0
docs/changelog/104483.yaml

@@ -0,0 +1,5 @@
+pr: 104483
+summary: Make `task_type` optional in `_inference` APIs
+area: Machine Learning
+type: enhancement
+issues: []

+ 7 - 6
docs/reference/inference/delete-inference.asciidoc

@@ -6,9 +6,9 @@ experimental[]
 
 Deletes an {infer} model deployment.
 
-IMPORTANT: The {infer} APIs enable you to use certain services, such as ELSER, 
-OpenAI, or Hugging Face, in your cluster. This is not the same feature that you 
-can use on an ML node with custom {ml} models. If you want to train and use your 
+IMPORTANT: The {infer} APIs enable you to use certain services, such as ELSER,
+OpenAI, or Hugging Face, in your cluster. This is not the same feature that you
+can use on an ML node with custom {ml} models. If you want to train and use your
 own model, use the <<ml-df-trained-models-apis>>.
 
 
@@ -16,6 +16,7 @@ own model, use the <<ml-df-trained-models-apis>>.
 [[delete-inference-api-request]]
 ==== {api-request-title}
 
+`DELETE /_inference/<model_id>`
 `DELETE /_inference/<task_type>/<model_id>`
 
 [discrete]
@@ -34,7 +35,7 @@ own model, use the <<ml-df-trained-models-apis>>.
 The unique identifier of the {infer} model to delete.
 
 <task_type>::
-(Required, string)
+(Optional, string)
 The type of {infer} task that the model performs.
 
 
@@ -42,7 +43,7 @@ The type of {infer} task that the model performs.
 [[delete-inference-api-example]]
 ==== {api-examples-title}
 
-The following API call deletes the `my-elser-model` {infer} model that can 
+The following API call deletes the `my-elser-model` {infer} model that can
 perform `sparse_embedding` tasks.
 
 
@@ -61,4 +62,4 @@ The API returns the following response:
   "acknowledged": true
 }
 ------------------------------------------------------------
-// NOTCONSOLE
+// NOTCONSOLE

+ 5 - 3
docs/reference/inference/get-inference.asciidoc

@@ -6,9 +6,9 @@ experimental[]
 
 Retrieves {infer} model information.
 
-IMPORTANT: The {infer} APIs enable you to use certain services, such as ELSER, 
-OpenAI, or Hugging Face, in your cluster. This is not the same feature that you 
-can use on an ML node with custom {ml} models. If you want to train and use your 
+IMPORTANT: The {infer} APIs enable you to use certain services, such as ELSER,
+OpenAI, or Hugging Face, in your cluster. This is not the same feature that you
+can use on an ML node with custom {ml} models. If you want to train and use your
 own model, use the <<ml-df-trained-models-apis>>.
 
 
@@ -18,6 +18,8 @@ own model, use the <<ml-df-trained-models-apis>>.
 
 `GET /_inference/_all`
 
+`GET /_inference/<model_id>`
+
 `GET /_inference/<task_type>/_all`
 
 `GET /_inference/<task_type>/<model_id>`

+ 5 - 4
docs/reference/inference/post-inference.asciidoc

@@ -6,9 +6,9 @@ experimental[]
 
 Performs an inference task on an input text by using an {infer} model.
 
-IMPORTANT: The {infer} APIs enable you to use certain services, such as ELSER, 
-OpenAI, or Hugging Face, in your cluster. This is not the same feature that you 
-can use on an ML node with custom {ml} models. If you want to train and use your 
+IMPORTANT: The {infer} APIs enable you to use certain services, such as ELSER,
+OpenAI, or Hugging Face, in your cluster. This is not the same feature that you
+can use on an ML node with custom {ml} models. If you want to train and use your
 own model, use the <<ml-df-trained-models-apis>>.
 
 
@@ -16,6 +16,7 @@ own model, use the <<ml-df-trained-models-apis>>.
 [[post-inference-api-request]]
 ==== {api-request-title}
 
+`POST /_inference/<model_id>`
 `POST /_inference/<task_type>/<model_id>`
 
 
@@ -46,7 +47,7 @@ The unique identifier of the {infer} model.
 
 
 `<task_type>`::
-(Required, string)
+(Optional, string)
 The type of {infer} task that the model performs.
 
 

+ 15 - 3
rest-api-spec/src/main/resources/rest-api-spec/api/inference.delete_model.json

@@ -12,16 +12,28 @@
     "url":{
       "paths":[
         {
-          "path":"/_inference/{task_type}/{model_id}",
+          "path": "/_inference/{inference_id}",
+          "methods": [
+            "DELETE"
+          ],
+          "parts": {
+            "inference_id": {
+              "type": "string",
+              "description": "The inference Id"
+            }
+          }
+        },
+        {
+          "path":"/_inference/{task_type}/{inference_id}",
           "methods":[
             "DELETE"
           ],
           "parts":{
             "task_type":{
               "type":"string",
-              "description":"The model task type"
+              "description":"The task type"
             },
-            "model_id":{
+            "inference_id":{
               "type":"string",
               "description":"The model Id"
             }

+ 16 - 4
rest-api-spec/src/main/resources/rest-api-spec/api/inference.get_model.json

@@ -12,18 +12,30 @@
     "url":{
       "paths":[
         {
-          "path":"/_inference/{task_type}/{model_id}",
+          "path":"/_inference/{inference_id}",
+          "methods":[
+            "GET"
+          ],
+          "parts":{
+            "inference_id":{
+              "type":"string",
+              "description":"The inference Id"
+            }
+          }
+        },
+        {
+          "path":"/_inference/{task_type}/{inference_id}",
           "methods":[
             "GET"
           ],
           "parts":{
             "task_type":{
               "type":"string",
-              "description":"The model task type"
+              "description":"The task type"
             },
-            "model_id":{
+            "inference_id":{
               "type":"string",
-              "description":"The model Id"
+              "description":"The inference Id"
             }
           }
         }

+ 16 - 4
rest-api-spec/src/main/resources/rest-api-spec/api/inference.inference.json

@@ -13,18 +13,30 @@
     "url":{
       "paths":[
         {
-          "path":"/_inference/{task_type}/{model_id}",
+          "path":"/_inference/{inference_id}",
+          "methods":[
+            "POST"
+          ],
+          "parts":{
+            "inference_id":{
+              "type":"string",
+              "description":"The inference Id"
+            }
+          }
+        },
+        {
+          "path":"/_inference/{task_type}/{inference_id}",
           "methods":[
             "POST"
           ],
           "parts":{
             "task_type":{
               "type":"string",
-              "description":"The model task type"
+              "description":"The task type"
             },
-            "model_id":{
+            "inference_id":{
               "type":"string",
-              "description":"The model Id"
+              "description":"The inference Id"
             }
           }
         }

+ 16 - 4
rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_model.json

@@ -13,18 +13,30 @@
     "url":{
       "paths":[
         {
-          "path":"/_inference/{task_type}/{model_id}",
+          "path":"/_inference/{inference_id}",
+          "methods":[
+            "PUT"
+          ],
+          "parts":{
+            "inference_id":{
+              "type":"string",
+              "description":"The inference Id"
+            }
+          }
+        },
+        {
+          "path":"/_inference/{task_type}/{inference_id}",
           "methods":[
             "PUT"
           ],
           "parts":{
             "task_type":{
               "type":"string",
-              "description":"The model task type"
+              "description":"The task type"
             },
-            "model_id":{
+            "inference_id":{
               "type":"string",
-              "description":"The model Id"
+              "description":"The inference Id"
             }
           }
         }

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceModelAction.java

@@ -31,9 +31,9 @@ public class DeleteInferenceModelAction extends ActionType<AcknowledgedResponse>
         private final String inferenceEntityId;
         private final TaskType taskType;
 
-        public Request(String inferenceEntityId, String taskType) {
+        public Request(String inferenceEntityId, TaskType taskType) {
             this.inferenceEntityId = inferenceEntityId;
-            this.taskType = TaskType.fromStringOrStatusException(taskType);
+            this.taskType = taskType;
         }
 
         public Request(StreamInput in) throws IOException {

+ 3 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

@@ -57,7 +57,7 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
         }
 
-        public static Request parseRequest(String inferenceEntityId, String taskType, XContentParser parser) {
+        public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) {
             Request.Builder builder = PARSER.apply(parser, null);
             builder.setInferenceEntityId(inferenceEntityId);
             builder.setTaskType(taskType);
@@ -197,13 +197,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 return this;
             }
 
-            public Builder setTaskType(String taskTypeStr) {
-                try {
-                    TaskType taskType = TaskType.fromString(taskTypeStr);
-                    this.taskType = Objects.requireNonNull(taskType);
-                } catch (IllegalArgumentException e) {
-                    throw new ElasticsearchStatusException("Unknown task_type [{}]", RestStatus.BAD_REQUEST, taskTypeStr);
-                }
+            public Builder setTaskType(TaskType taskType) {
+                this.taskType = taskType;
                 return this;
             }
 

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java

@@ -42,8 +42,8 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
         private final BytesReference content;
         private final XContentType contentType;
 
-        public Request(String taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
-            this.taskType = TaskType.fromStringOrStatusException(taskType);
+        public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
+            this.taskType = taskType;
             this.inferenceEntityId = inferenceEntityId;
             this.content = content;
             this.contentType = contentType;

+ 2 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java

@@ -19,14 +19,14 @@ import org.junit.Before;
 import java.util.Locale;
 
 public class PutInferenceModelActionTests extends ESTestCase {
-    public static String TASK_TYPE;
+    public static TaskType TASK_TYPE;
     public static String MODEL_ID;
     public static XContentType X_CONTENT_TYPE;
     public static BytesReference BYTES;
 
     @Before
     public void setup() throws Exception {
-        TASK_TYPE = TaskType.ANY.toString();
+        TASK_TYPE = TaskType.SPARSE_EMBEDDING;
         MODEL_ID = randomAlphaOfLengthBetween(1, 10).toLowerCase(Locale.ROOT);
         X_CONTENT_TYPE = randomFrom(XContentType.values());
         BYTES = new BytesArray(randomAlphaOfLengthBetween(1, 10));

+ 52 - 7
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
 import org.elasticsearch.test.cluster.local.distribution.DistributionType;
@@ -50,8 +51,14 @@ public class InferenceBaseRestTest extends ESRestTestCase {
     }
 
     static String mockServiceModelConfig() {
-        return """
+        return mockServiceModelConfig(null);
+    }
+
+    static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) {
+        var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
+        return Strings.format("""
             {
+              %s
               "service": "test_service",
               "service_settings": {
                 "model": "my_model",
@@ -61,11 +68,35 @@ public class InferenceBaseRestTest extends ESRestTestCase {
                 "temperature": 3
               }
             }
-            """;
+            """, taskType);
+    }
+
+    protected void deleteModel(String modelId) throws IOException {
+        var request = new Request("DELETE", "_inference/" + modelId);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+    }
+
+    protected void deleteModel(String modelId, TaskType taskType) throws IOException {
+        var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId));
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
     }
 
     protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
         String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
+        return putModelInternal(endpoint, modelConfig);
+    }
+
+    /**
+     * Task type should be in modelConfig
+     */
+    protected Map<String, Object> putModel(String modelId, String modelConfig) throws IOException {
+        String endpoint = Strings.format("_inference/%s", modelId);
+        return putModelInternal(endpoint, modelConfig);
+    }
+
+    private Map<String, Object> putModelInternal(String endpoint, String modelConfig) throws IOException {
         var request = new Request("PUT", endpoint);
         request.setJsonEntity(modelConfig);
         var response = client().performRequest(request);
@@ -73,24 +104,38 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return entityAsMap(response);
     }
 
+    protected Map<String, Object> getModel(String modelId) throws IOException {
+        var endpoint = Strings.format("_inference/%s", modelId);
+        return getAllModelInternal(endpoint);
+    }
+
     protected Map<String, Object> getModels(String modelId, TaskType taskType) throws IOException {
         var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
-        var request = new Request("GET", endpoint);
-        var response = client().performRequest(request);
-        assertOkOrCreated(response);
-        return entityAsMap(response);
+        return getAllModelInternal(endpoint);
     }
 
     protected Map<String, Object> getAllModels() throws IOException {
-        var endpoint = Strings.format("_inference/_all");
+        return getAllModelInternal("_inference/_all");
+    }
+
+    private Map<String, Object> getAllModelInternal(String endpoint) throws IOException {
         var request = new Request("GET", endpoint);
         var response = client().performRequest(request);
         assertOkOrCreated(response);
         return entityAsMap(response);
     }
 
+    protected Map<String, Object> inferOnMockService(String modelId, List<String> input) throws IOException {
+        var endpoint = Strings.format("_inference/%s", modelId);
+        return inferOnMockServiceInternal(endpoint, input);
+    }
+
     protected Map<String, Object> inferOnMockService(String modelId, TaskType taskType, List<String> input) throws IOException {
         var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
+        return inferOnMockServiceInternal(endpoint, input);
+    }
+
+    private Map<String, Object> inferOnMockServiceInternal(String endpoint, List<String> input) throws IOException {
         var request = new Request("POST", endpoint);
 
         var bodyBuilder = new StringBuilder("{\"input\": [");

+ 29 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -48,6 +48,13 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         var singleModel = (List<Map<String, Object>>) getModels("se_model_1", TaskType.SPARSE_EMBEDDING).get("models");
         assertThat(singleModel, hasSize(1));
         assertEquals("se_model_1", singleModel.get(0).get("model_id"));
+
+        for (int i = 0; i < 5; i++) {
+            deleteModel("se_model_" + i, TaskType.SPARSE_EMBEDDING);
+        }
+        for (int i = 0; i < 4; i++) {
+            deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
+        }
     }
 
     public void testGetModelWithWrongTaskType() throws IOException {
@@ -59,13 +66,34 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         );
     }
 
+    public void testDeleteModelWithWrongTaskType() throws IOException {
+        putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var e = expectThrows(ResponseException.class, () -> deleteModel("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
+        assertThat(
+            e.getMessage(),
+            containsString("Requested task type [text_embedding] does not match the model's task type [sparse_embedding]")
+        );
+    }
+
     @SuppressWarnings("unchecked")
     public void testGetModelWithAnyTaskType() throws IOException {
         String modelId = "sparse_embedding_model";
         putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var singleModel = (List<Map<String, Object>>) getModels(modelId, TaskType.ANY).get("models");
-        System.out.println("MODEL" + singleModel);
         assertEquals(modelId, singleModel.get(0).get("model_id"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
     }
+
+    @SuppressWarnings("unchecked")
+    public void testApisWithoutTaskType() throws IOException {
+        String modelId = "no_task_type_in_url";
+        putModel(modelId, mockServiceModelConfig(TaskType.SPARSE_EMBEDDING));
+        var singleModel = (List<Map<String, Object>>) getModel(modelId).get("models");
+        assertEquals(modelId, singleModel.get(0).get("model_id"));
+        assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
+
+        var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10)));
+        assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
+        deleteModel(modelId);
+    }
 }

+ 1 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java

@@ -29,7 +29,7 @@ public class MockInferenceServiceIT extends InferenceBaseRestTest {
         }
 
         // The response is randomly generated, the input can be anything
-        var inference = inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
+        var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10)));
         assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
     }
 

+ 13 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java

@@ -71,6 +71,19 @@ public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMa
         SubscribableListener.<ModelRegistry.UnparsedModel>newForked(modelConfigListener -> {
             modelRegistry.getModel(request.getInferenceEntityId(), modelConfigListener);
         }).<Boolean>andThen((l1, unparsedModel) -> {
+
+            if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
+                // specific task type in request does not match the models
+                l1.onFailure(
+                    new ElasticsearchStatusException(
+                        "Requested task type [{}] does not match the model's task type [{}]",
+                        RestStatus.BAD_REQUEST,
+                        request.getTaskType(),
+                        unparsedModel.taskType()
+                    )
+                );
+                return;
+            }
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isPresent()) {
                 service.get().stop(request.getInferenceEntityId(), l1);

+ 39 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

@@ -96,6 +96,8 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
     ) throws Exception {
 
         var requestAsMap = requestToMap(request);
+        var resolvedTaskType = resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));
+
         String serviceName = (String) requestAsMap.remove(ModelConfigurations.SERVICE);
         if (serviceName == null) {
             listener.onFailure(new ElasticsearchStatusException("Model configuration is missing a service", RestStatus.BAD_REQUEST));
@@ -151,7 +153,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
                     parseAndStoreModel(
                         service.get(),
                         request.getInferenceEntityId(),
-                        request.getTaskType(),
+                        resolvedTaskType,
                         requestAsMap,
                         // In Elastic cloud ml nodes run on Linux x86
                         Set.of("linux-x86_64"),
@@ -162,7 +164,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
                     parseAndStoreModel(
                         service.get(),
                         request.getInferenceEntityId(),
-                        request.getTaskType(),
+                        resolvedTaskType,
                         requestAsMap,
                         architectures,
                         delegate
@@ -171,7 +173,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
             }), client, threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME));
         } else {
             // Not an in cluster service, it does not care about the cluster platform
-            parseAndStoreModel(service.get(), request.getInferenceEntityId(), request.getTaskType(), requestAsMap, Set.of(), listener);
+            parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, Set.of(), listener);
         }
     }
 
@@ -235,4 +237,38 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
         // One such heuristic is where USE_AUTO_MACHINE_MEMORY_PERCENT == true
         return settings.get(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT);
     }
+
+    /**
+     * task_type can be specified as either a URL parameter or in the
+     * request body. Resolve which to use or throw if the settings are
+     * inconsistent
+     * @param urlTaskType Taken from the URL parameter. ANY means not specified.
+     * @param bodyTaskType Taken from the request body. Maybe null
+     * @return The resolved task type
+     */
+    static TaskType resolveTaskType(TaskType urlTaskType, String bodyTaskType) {
+        if (bodyTaskType == null) {
+            if (urlTaskType == TaskType.ANY) {
+                throw new ElasticsearchStatusException("model is missing required setting [task_type]", RestStatus.BAD_REQUEST);
+            } else {
+                return urlTaskType;
+            }
+        }
+
+        TaskType parsedBodyTask = TaskType.fromStringOrStatusException(bodyTaskType);
+        if (parsedBodyTask == TaskType.ANY) {
+            throw new ElasticsearchStatusException("task_type [any] is not valid type for inference", RestStatus.BAD_REQUEST);
+        }
+
+        if (parsedBodyTask.isAnyOrSame(urlTaskType) == false) {
+            throw new ElasticsearchStatusException(
+                "Cannot resolve conflicting task_type parameter in the request URL [{}] and the request body [{}]",
+                RestStatus.BAD_REQUEST,
+                urlTaskType.toString(),
+                bodyTaskType
+            );
+        }
+
+        return parsedBodyTask;
+    }
 }

+ 20 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java

@@ -0,0 +1,20 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rest;
+
+public final class Paths {
+
+    static final String INFERENCE_ID = "inference_id";
+    static final String TASK_TYPE_OR_INFERENCE_ID = "task_type_or_id";
+    static final String INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}";
+    static final String TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID + "}";
+
+    private Paths() {
+
+    }
+}

+ 16 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteInferenceModelAction.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.rest;
 
 import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.Scope;
@@ -18,9 +19,14 @@ import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.DELETE;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
 
 @ServerlessScope(Scope.PUBLIC)
 public class RestDeleteInferenceModelAction extends BaseRestHandler {
+
     @Override
     public String getName() {
         return "delete_inference_model_action";
@@ -28,13 +34,20 @@ public class RestDeleteInferenceModelAction extends BaseRestHandler {
 
     @Override
     public List<Route> routes() {
-        return List.of(new Route(DELETE, "_inference/{task_type}/{model_id}"));
+        return List.of(new Route(DELETE, INFERENCE_ID_PATH), new Route(DELETE, TASK_TYPE_INFERENCE_ID_PATH));
     }
 
     @Override
     protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) {
-        String taskType = restRequest.param("task_type");
-        String inferenceEntityId = restRequest.param("model_id");
+        String inferenceEntityId;
+        TaskType taskType;
+        if (restRequest.hasParam(INFERENCE_ID)) {
+            inferenceEntityId = restRequest.param(INFERENCE_ID);
+            taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
+        } else {
+            inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
+            taskType = TaskType.ANY;
+        }
 
         var request = new DeleteInferenceModelAction.Request(inferenceEntityId, taskType);
         return channel -> client.execute(DeleteInferenceModelAction.INSTANCE, request, new RestToXContentListener<>(channel));

+ 12 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java

@@ -19,9 +19,14 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.GET;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
 
 @ServerlessScope(Scope.PUBLIC)
 public class RestGetInferenceModelAction extends BaseRestHandler {
+
     @Override
     public String getName() {
         return "get_inference_model_action";
@@ -29,20 +34,23 @@ public class RestGetInferenceModelAction extends BaseRestHandler {
 
     @Override
     public List<Route> routes() {
-        return List.of(new Route(GET, "_inference/{task_type}/{model_id}"), new Route(GET, "_inference/_all"));
+        return List.of(new Route(GET, "_inference/_all"), new Route(GET, INFERENCE_ID_PATH), new Route(GET, TASK_TYPE_INFERENCE_ID_PATH));
     }
 
     @Override
     protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) {
         String inferenceEntityId = null;
         TaskType taskType = null;
-        if (restRequest.hasParam("task_type") == false && restRequest.hasParam("model_id") == false) {
+        if (restRequest.hasParam(TASK_TYPE_OR_INFERENCE_ID) == false && restRequest.hasParam(INFERENCE_ID) == false) {
             // _all models request
             inferenceEntityId = "_all";
             taskType = TaskType.ANY;
+        } else if (restRequest.hasParam(INFERENCE_ID)) {
+            inferenceEntityId = restRequest.param(INFERENCE_ID);
+            taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
         } else {
-            taskType = TaskType.fromStringOrStatusException(restRequest.param("task_type"));
-            inferenceEntityId = restRequest.param("model_id");
+            inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
+            taskType = TaskType.ANY;
         }
 
         var request = new GetInferenceModelAction.Request(inferenceEntityId, taskType);

+ 16 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.rest;
 
 import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.Scope;
@@ -19,6 +20,10 @@ import java.io.IOException;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.POST;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
 
 @ServerlessScope(Scope.PUBLIC)
 public class RestInferenceAction extends BaseRestHandler {
@@ -29,13 +34,21 @@ public class RestInferenceAction extends BaseRestHandler {
 
     @Override
     public List<Route> routes() {
-        return List.of(new Route(POST, "_inference/{task_type}/{model_id}"));
+        return List.of(new Route(POST, INFERENCE_ID_PATH), new Route(POST, TASK_TYPE_INFERENCE_ID_PATH));
     }
 
     @Override
     protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
-        String taskType = restRequest.param("task_type");
-        String inferenceEntityId = restRequest.param("model_id");
+        String inferenceEntityId;
+        TaskType taskType;
+        if (restRequest.hasParam(INFERENCE_ID)) {
+            inferenceEntityId = restRequest.param(INFERENCE_ID);
+            taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
+        } else {
+            inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
+            taskType = TaskType.ANY;
+        }
+
         try (var parser = restRequest.contentParser()) {
             var request = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser);
             return channel -> client.execute(InferenceAction.INSTANCE, request, new RestToXContentListener<>(channel));

+ 16 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.rest;
 
 import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.Scope;
@@ -15,10 +16,13 @@ import org.elasticsearch.rest.ServerlessScope;
 import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
 
-import java.io.IOException;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.PUT;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
+import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
+import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
 
 @ServerlessScope(Scope.PUBLIC)
 public class RestPutInferenceModelAction extends BaseRestHandler {
@@ -29,13 +33,20 @@ public class RestPutInferenceModelAction extends BaseRestHandler {
 
     @Override
     public List<Route> routes() {
-        return List.of(new Route(PUT, "_inference/{task_type}/{model_id}"));
+        return List.of(new Route(PUT, INFERENCE_ID_PATH), new Route(PUT, TASK_TYPE_INFERENCE_ID_PATH));
     }
 
     @Override
-    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
-        String taskType = restRequest.param("task_type");
-        String inferenceEntityId = restRequest.param("model_id");
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) {
+        String inferenceEntityId;
+        TaskType taskType;
+        if (restRequest.hasParam(INFERENCE_ID)) {
+            inferenceEntityId = restRequest.param(INFERENCE_ID);
+            taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
+        } else {
+            inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
+            taskType = TaskType.ANY; // task type must be defined in the body
+        }
 
         var request = new PutInferenceModelAction.Request(
             taskType,

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java

@@ -51,7 +51,7 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             }
             """;
         try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) {
-            var request = InferenceAction.Request.parseRequest("model_id", "sparse_embedding", parser);
+            var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser);
             assertThat(request.getInput(), contains("single text input"));
         }
 
@@ -61,7 +61,7 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             }
             """;
         try (var parser = createParser(JsonXContent.jsonXContent, multiInputRequest)) {
-            var request = InferenceAction.Request.parseRequest("model_id", "sparse_embedding", parser);
+            var request = InferenceAction.Request.parseRequest("model_id", TaskType.ANY, parser);
             assertThat(request.getInput(), contains("an array", "of", "inputs"));
         }
     }
@@ -73,7 +73,7 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             }
             """;
         try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) {
-            var request = InferenceAction.Request.parseRequest("model_id", "sparse_embedding", parser);
+            var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser);
             assertThat(request.getInputType(), is(InputType.UNSPECIFIED));
         }
     }

+ 5 - 5
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java

@@ -22,7 +22,7 @@ public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCa
     @Override
     protected PutInferenceModelAction.Request createTestInstance() {
         return new PutInferenceModelAction.Request(
-            randomFrom(TaskType.values()).toString(),
+            randomFrom(TaskType.values()),
             randomAlphaOfLength(6),
             randomBytesReference(50),
             randomFrom(XContentType.values())
@@ -33,25 +33,25 @@ public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCa
     protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
         return switch (randomIntBetween(0, 3)) {
             case 0 -> new PutInferenceModelAction.Request(
-                TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length].toString(),
+                TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
                 instance.getInferenceEntityId(),
                 instance.getContent(),
                 instance.getContentType()
             );
             case 1 -> new PutInferenceModelAction.Request(
-                instance.getTaskType().toString(),
+                instance.getTaskType(),
                 instance.getInferenceEntityId() + "foo",
                 instance.getContent(),
                 instance.getContentType()
             );
             case 2 -> new PutInferenceModelAction.Request(
-                instance.getTaskType().toString(),
+                instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 randomBytesReference(instance.getContent().length() + 1),
                 instance.getContentType()
             );
             case 3 -> new PutInferenceModelAction.Request(
-                instance.getTaskType().toString(),
+                instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 instance.getContent(),
                 XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]

+ 49 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelActionTests.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.action;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class TransportPutInferenceModelActionTests extends ESTestCase {
+
+    public void testResolveTaskType() {
+
+        assertEquals(TaskType.SPARSE_EMBEDDING, TransportPutInferenceModelAction.resolveTaskType(TaskType.SPARSE_EMBEDDING, null));
+        assertEquals(
+            TaskType.SPARSE_EMBEDDING,
+            TransportPutInferenceModelAction.resolveTaskType(TaskType.ANY, TaskType.SPARSE_EMBEDDING.toString())
+        );
+
+        var e = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> TransportPutInferenceModelAction.resolveTaskType(TaskType.ANY, null)
+        );
+        assertThat(e.getMessage(), containsString("model is missing required setting [task_type]"));
+
+        e = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> TransportPutInferenceModelAction.resolveTaskType(TaskType.ANY, TaskType.ANY.toString())
+        );
+        assertThat(e.getMessage(), containsString("task_type [any] is not valid type for inference"));
+
+        e = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> TransportPutInferenceModelAction.resolveTaskType(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING.toString())
+        );
+        assertThat(
+            e.getMessage(),
+            containsString(
+                "Cannot resolve conflicting task_type parameter in the request URL [sparse_embedding] and the request body [text_embedding]"
+            )
+        );
+    }
+}

+ 4 - 5
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml

@@ -3,8 +3,7 @@
   - do:
       catch: missing
       inference.get_model:
-        task_type: sparse_embedding
-        model_id: model_to_get
+        inference_id: model_to_get
   - match: { error.type: "resource_not_found_exception" }
   - match: { error.reason: "Model not found [model_to_get]" }
 
@@ -13,10 +12,10 @@
   - do:
       catch: bad_request
       inference.put_model:
-        task_type: bad
-        model_id: elser_model
+        inference_id: elser_model
         body: >
           {
+            "task_type": "bad",
             "service": "elser",
             "service_settings": {
               "num_allocations": 1,
@@ -33,7 +32,7 @@
       catch: bad_request
       inference.inference:
         task_type: bad
-        model_id: elser_model
+        inference_id: elser_model
         body: >
           {
             "input": "important text"