|
@@ -37,6 +37,7 @@ import org.elasticsearch.inference.Model;
|
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
|
import org.elasticsearch.inference.SimilarityMeasure;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
+import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.xcontent.ParseField;
|
|
@@ -49,13 +50,16 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
|
|
import org.elasticsearch.xpack.core.ml.MachineLearningField;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
|
@@ -1832,6 +1836,49 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
|
|
|
+ var model = new ElserInternalModel(
|
|
|
+ "inference_id",
|
|
|
+ TaskType.SPARSE_EMBEDDING,
|
|
|
+ "elasticsearch",
|
|
|
+ new ElserInternalServiceSettings(
|
|
|
+ new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
|
|
|
+ ),
|
|
|
+ new ElserMlNodeTaskSettings(),
|
|
|
+ null
|
|
|
+ );
|
|
|
+
|
|
|
+ var client = mock(Client.class);
|
|
|
+ when(client.threadPool()).thenReturn(threadPool);
|
|
|
+
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
|
|
|
+ var builder = GetTrainedModelsAction.Response.builder();
|
|
|
+ builder.setModels(List.of(mock(TrainedModelConfig.class)));
|
|
|
+ builder.setTotalCount(1);
|
|
|
+
|
|
|
+ listener.onResponse(builder.build());
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
|
|
|
+
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
|
|
|
+ listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
|
|
|
+
|
|
|
+ try (var service = createService(client)) {
|
|
|
+ var actionListener = new PlainActionFuture<Boolean>();
|
|
|
+ service.start(model, TimeValue.timeValueSeconds(30), actionListener);
|
|
|
+ var exception = expectThrows(
|
|
|
+ ElasticsearchStatusException.class,
|
|
|
+ () -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(exception.getMessage(), is("failed"));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private ElasticsearchInternalService createService(Client client) {
|
|
|
var cs = mock(ClusterService.class);
|
|
|
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
|