Browse Source

[ML] Fixing bug with TransportPutModelAction listener and adding timeout to request (#126805)

* Fixing bug with listener and adding timeout

* Update docs/changelog/126805.yaml

* Fixing tests

* Fixing writeTo
Jonathan Buttner 5 months ago
parent
commit
4c507e27d9

+ 5 - 0
docs/changelog/126805.yaml

@@ -0,0 +1,5 @@
+pr: 126805
+summary: Adding timeout to request for creating inference endpoint
+area: Machine Learning
+type: bug
+issues: []

+ 2 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -172,6 +172,7 @@ public class TransportVersions {
     public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
     public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
     public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
+    public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
     public static final TransportVersion V_9_0_0 = def(9_000_0_09);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -248,6 +249,7 @@ public class TransportVersions {
     public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION = def(9_071_0_00);
     public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
     public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
+    public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 24 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.core.inference.action;
 
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.ActionType;
@@ -15,6 +16,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -41,13 +43,15 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
         private final String inferenceEntityId;
         private final BytesReference content;
         private final XContentType contentType;
+        private final TimeValue timeout;
 
-        public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
+        public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
             super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
             this.taskType = taskType;
             this.inferenceEntityId = inferenceEntityId;
             this.content = content;
             this.contentType = contentType;
+            this.timeout = timeout;
         }
 
         public Request(StreamInput in) throws IOException {
@@ -56,6 +60,13 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
             this.taskType = TaskType.fromStream(in);
             this.content = in.readBytesReference();
             this.contentType = in.readEnum(XContentType.class);
+
+            if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
+                || in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
+                this.timeout = in.readTimeValue();
+            } else {
+                this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
+            }
         }
 
         public TaskType getTaskType() {
@@ -74,6 +85,10 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
             return contentType;
         }
 
+        public TimeValue getTimeout() {
+            return timeout;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
@@ -81,6 +96,11 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
             taskType.writeTo(out);
             out.writeBytesReference(content);
             XContentHelper.writeTo(out, contentType);
+
+            if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
+                || out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
+                out.writeTimeValue(timeout);
+            }
         }
 
         @Override
@@ -105,12 +125,13 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
             return taskType == request.taskType
                 && Objects.equals(inferenceEntityId, request.inferenceEntityId)
                 && Objects.equals(content, request.content)
-                && contentType == request.contentType;
+                && contentType == request.contentType
+                && Objects.equals(timeout, request.timeout);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(taskType, inferenceEntityId, content, contentType);
+            return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
         }
     }
 

+ 23 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java

@@ -34,13 +34,25 @@ public class PutInferenceModelActionTests extends ESTestCase {
 
     public void testValidate() {
         // valid model ID
-        var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
+        var request = new PutInferenceModelAction.Request(
+            TASK_TYPE,
+            MODEL_ID + "_-0",
+            BYTES,
+            X_CONTENT_TYPE,
+            InferenceAction.Request.DEFAULT_TIMEOUT
+        );
         ActionRequestValidationException validationException = request.validate();
         assertNull(validationException);
 
         // invalid model IDs
 
-        var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
+        var invalidRequest = new PutInferenceModelAction.Request(
+            TASK_TYPE,
+            "",
+            BYTES,
+            X_CONTENT_TYPE,
+            InferenceAction.Request.DEFAULT_TIMEOUT
+        );
         validationException = invalidRequest.validate();
         assertNotNull(validationException);
 
@@ -48,12 +60,19 @@ public class PutInferenceModelActionTests extends ESTestCase {
             TASK_TYPE,
             randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
             BYTES,
-            X_CONTENT_TYPE
+            X_CONTENT_TYPE,
+            InferenceAction.Request.DEFAULT_TIMEOUT
         );
         validationException = invalidRequest2.validate();
         assertNotNull(validationException);
 
-        var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
+        var invalidRequest3 = new PutInferenceModelAction.Request(
+            TASK_TYPE,
+            null,
+            BYTES,
+            X_CONTENT_TYPE,
+            InferenceAction.Request.DEFAULT_TIMEOUT
+        );
         validationException = invalidRequest3.validate();
         assertNotNull(validationException);
     }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

@@ -177,7 +177,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
             return;
         }
 
-        parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
+        parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
     }
 
     private void parseAndStoreModel(

+ 9 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java

@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.PUT;
+import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
 import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
 import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
 import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
@@ -49,8 +50,15 @@ public class RestPutInferenceModelAction extends BaseRestHandler {
             taskType = TaskType.ANY; // task type must be defined in the body
         }
 
+        var inferTimeout = parseTimeout(restRequest);
         var content = restRequest.requiredContent();
-        var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
+        var request = new PutInferenceModelAction.Request(
+            taskType,
+            inferenceEntityId,
+            content,
+            restRequest.getXContentType(),
+            inferTimeout
+        );
         return channel -> client.execute(
             PutInferenceModelAction.INSTANCE,
             request,

+ 20 - 26
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java

@@ -7,13 +7,16 @@
 
 package org.elasticsearch.xpack.inference.action;
 
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 
-public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
+public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
     @Override
     protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
         return PutInferenceModelAction.Request::new;
@@ -25,38 +28,29 @@ public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCa
             randomFrom(TaskType.values()),
             randomAlphaOfLength(6),
             randomBytesReference(50),
-            randomFrom(XContentType.values())
+            randomFrom(XContentType.values()),
+            randomTimeValue()
         );
     }
 
     @Override
     protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
-        return switch (randomIntBetween(0, 3)) {
-            case 0 -> new PutInferenceModelAction.Request(
-                TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
-                instance.getInferenceEntityId(),
-                instance.getContent(),
-                instance.getContentType()
-            );
-            case 1 -> new PutInferenceModelAction.Request(
-                instance.getTaskType(),
-                instance.getInferenceEntityId() + "foo",
-                instance.getContent(),
-                instance.getContentType()
-            );
-            case 2 -> new PutInferenceModelAction.Request(
-                instance.getTaskType(),
-                instance.getInferenceEntityId(),
-                randomBytesReference(instance.getContent().length() + 1),
-                instance.getContentType()
-            );
-            case 3 -> new PutInferenceModelAction.Request(
+        return randomValueOtherThan(instance, this::createTestInstance);
+    }
+
+    @Override
+    protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
+        if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
+            || version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
+            return instance;
+        } else {
+            return new PutInferenceModelAction.Request(
                 instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 instance.getContent(),
-                XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
+                instance.getContentType(),
+                InferenceAction.Request.DEFAULT_TIMEOUT
             );
-            default -> throw new IllegalStateException();
-        };
+        }
     }
 }