Преглед изворни кода

[ML] Inference API fixing bug where request timeout can be null (#107648)

* Adding test class

* Finishing a test

* Testing timeout from params

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Jonathan Buttner пре 1 година
родитељ
комит
23611b7c95

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

@@ -242,7 +242,7 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             private InputType inputType = InputType.UNSPECIFIED;
             private Map<String, Object> taskSettings = Map.of();
             private String query;
-            private TimeValue timeout;
+            private TimeValue timeout = DEFAULT_TIMEOUT;
 
             private Builder() {}
 

+ 5 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java

@@ -54,13 +54,11 @@ public class RestInferenceAction extends BaseRestHandler {
             requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser);
         }
 
-        if (restRequest.hasParam(InferenceAction.Request.TIMEOUT.getPreferredName())) {
-            var inferTimeout = restRequest.paramAsTime(
-                InferenceAction.Request.TIMEOUT.getPreferredName(),
-                InferenceAction.Request.DEFAULT_TIMEOUT
-            );
-            requestBuilder.setInferenceTimeout(inferTimeout);
-        }
+        var inferTimeout = restRequest.paramAsTime(
+            InferenceAction.Request.TIMEOUT.getPreferredName(),
+            InferenceAction.Request.DEFAULT_TIMEOUT
+        );
+        requestBuilder.setInferenceTimeout(inferTimeout);
         return channel -> client.execute(InferenceAction.INSTANCE, requestBuilder.build(), new RestToXContentListener<>(channel));
     }
 }

+ 82 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java

@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rest;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.test.rest.RestActionTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.junit.Before;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
+public class RestInferenceActionTests extends RestActionTestCase {
+
+    @Before
+    public void setUpAction() {
+        controller().registerHandler(new RestInferenceAction());
+    }
+
+    public void testUsesDefaultTimeout() {
+        SetOnce<Boolean> executeCalled = new SetOnce<>();
+        verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
+            assertThat(actionRequest, instanceOf(InferenceAction.Request.class));
+
+            var request = (InferenceAction.Request) actionRequest;
+            assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT));
+
+            executeCalled.set(true);
+            return createResponse();
+        }));
+
+        RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
+            .withPath("_inference/test")
+            .withContent(new BytesArray("{}"), XContentType.JSON)
+            .build();
+        dispatchRequest(inferenceRequest);
+        assertThat(executeCalled.get(), equalTo(true));
+    }
+
+    public void testUses3SecondTimeoutFromParams() {
+        SetOnce<Boolean> executeCalled = new SetOnce<>();
+        verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
+            assertThat(actionRequest, instanceOf(InferenceAction.Request.class));
+
+            var request = (InferenceAction.Request) actionRequest;
+            assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3)));
+
+            executeCalled.set(true);
+            return createResponse();
+        }));
+
+        RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
+            .withPath("_inference/test")
+            .withParams(new HashMap<>(Map.of("timeout", "3s")))
+            .withContent(new BytesArray("{}"), XContentType.JSON)
+            .build();
+        dispatchRequest(inferenceRequest);
+        assertThat(executeCalled.get(), equalTo(true));
+    }
+
+    private static InferenceAction.Response createResponse() {
+        return new InferenceAction.Response(
+            new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1))))
+        );
+    }
+}