@@ -6,13 +6,15 @@
  */
 package org.elasticsearch.xpack.ml.action;

+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
-import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
@@ -28,8 +30,11 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -38,9 +43,10 @@ import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;

-import java.util.Collections;
-import java.util.Map;
+import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;

 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -132,19 +138,19 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
     ) {
         String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId()))
             .orElse(request.getModelId());
-        if (isAllocatedModel(concreteModelId)) {
+
+        responseBuilder.setModelId(concreteModelId);
+
+        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
+
+        if (trainedModelAssignmentMetadata.isAssigned(concreteModelId)) {
             // It is important to use the resolved model ID here as the alias could change between transport calls.
-            inferAgainstAllocatedModel(request, concreteModelId, responseBuilder, parentTaskId, listener);
+            inferAgainstAllocatedModel(trainedModelAssignmentMetadata, request, concreteModelId, responseBuilder, parentTaskId, listener);
         } else {
             getModelAndInfer(request, responseBuilder, parentTaskId, (CancellableTask) task, listener);
         }
     }

-    private boolean isAllocatedModel(String modelId) {
-        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
-        return trainedModelAssignmentMetadata.isAssigned(modelId);
-    }
-
     private void getModelAndInfer(
         Request request,
         Response.Builder responseBuilder,
@@ -169,75 +175,153 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re

             typedChainTaskExecutor.execute(ActionListener.wrap(inferenceResultsInterfaces -> {
                 model.release();
-                listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).setModelId(model.getModelId()).build());
+                listener.onResponse(responseBuilder.addInferenceResults(inferenceResultsInterfaces).build());
             }, e -> {
                 model.release();
                 listener.onFailure(e);
             }));
-        }, listener::onFailure);
+        }, e -> {
+            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                listener.onFailure(e);
+                return;
+            }
+
+            // The model was found, check if a more relevant error message can be returned
+            trainedModelProvider.getTrainedModel(
+                request.getModelId(),
+                GetTrainedModelsAction.Includes.empty(),
+                parentTaskId,
+                ActionListener.wrap(trainedModelConfig -> {
+                    if (trainedModelConfig.getModelType() == TrainedModelType.PYTORCH) {
+                        // The PyTorch model cannot be allocated if we got here
+                        listener.onFailure(
+                            ExceptionsHelper.conflictStatusException(
+                                "Model ["
+                                    + request.getModelId()
+                                    + "] must be deployed to use. Please deploy with the start trained model deployment API.",
+                                request.getModelId()
+                            )
+                        );
+                    } else {
+                        // return the original error
+                        listener.onFailure(e);
+                    }
+                }, listener::onFailure)
+            );
+        });
+
+        // TODO should `getModelForInternalInference` be used here??
         modelLoadingService.getModelForPipeline(request.getModelId(), parentTaskId, getModelListener);
     }

     private void inferAgainstAllocatedModel(
+        TrainedModelAssignmentMetadata assignmentMeta,
         Request request,
         String concreteModelId,
         Response.Builder responseBuilder,
         TaskId parentTaskId,
         ActionListener<Response> listener
     ) {
-        TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<>(
-            client.threadPool().executor(ThreadPool.Names.SAME),
-            // run through all tasks
-            r -> true,
-            // Always fail immediately and return an error
-            ex -> true
-        );
-        request.getObjectsToInfer()
-            .forEach(
-                stringObjectMap -> typedChainTaskExecutor.add(
-                    chainedTask -> inferSingleDocAgainstAllocatedModel(
-                        concreteModelId,
-                        request.getTimeout(),
-                        request.getUpdate(),
-                        stringObjectMap,
-                        parentTaskId,
-                        chainedTask
-                    )
-                )
+        TrainedModelAssignment assignment = assignmentMeta.getModelAssignment(concreteModelId);
+
+        if (assignment.getAssignmentState() == AssignmentState.STOPPING) {
+            String message = "Trained model [" + request.getModelId() + "] is STOPPING";
+            listener.onFailure(ExceptionsHelper.conflictStatusException(message));
+            return;
+        }
+
+        // Get a list of nodes to send the requests to and the number of
+        // documents for each node.
+        var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments());
+        if (nodes.isEmpty()) {
+            logger.trace(() -> format("[%s] model not allocated to any node", assignment.getModelId()));
+            listener.onFailure(
+                ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes")
             );
+            return;
+        }

-        typedChainTaskExecutor.execute(
-            ActionListener.wrap(
-                inferenceResults -> listener.onResponse(
-                    responseBuilder.setInferenceResults(inferenceResults).setModelId(concreteModelId).build()
-                ),
-                listener::onFailure
-            )
-        );
+        assert nodes.stream().mapToInt(Tuple::v2).sum() == request.numberOfDocuments()
+            : "mismatch; sum of node requests does not match number of documents in request";
+
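+        // Bookkeeping for the fan-out below: one result slot per node request, a counter to detect
+        // the final response, and a holder for any failure seen along the way.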
+        AtomicInteger count = new AtomicInteger();
+        AtomicArray<List<InferenceResults>> results = new AtomicArray<>(nodes.size());
+        AtomicReference<Exception> failure = new AtomicReference<>();
+
+        int startPos = 0;
+        int slot = 0;
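+        // Each node receives a contiguous slice of the request's documents; startPos tracks where
+        // the next slice begins and slot identifies which entry of the results array the node fills.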
+        for (var node : nodes) {
+            InferTrainedModelDeploymentAction.Request deploymentRequest;
+            if (request.getTextInput() == null) {
+                deploymentRequest = InferTrainedModelDeploymentAction.Request.forDocs(
+                    concreteModelId,
+                    request.getUpdate(),
+                    request.getObjectsToInfer().subList(startPos, startPos + node.v2()),
+                    request.getInferenceTimeout()
+                );
+            } else {
+                deploymentRequest = InferTrainedModelDeploymentAction.Request.forTextInput(
+                    concreteModelId,
+                    request.getUpdate(),
+                    request.getTextInput().subList(startPos, startPos + node.v2()),
+                    request.getInferenceTimeout()
+                );
+            }
+            deploymentRequest.setNodes(node.v1());
+            deploymentRequest.setParentTask(parentTaskId);
+
+            startPos += node.v2();
+
+            executeAsyncWithOrigin(
+                client,
+                ML_ORIGIN,
+                InferTrainedModelDeploymentAction.INSTANCE,
+                deploymentRequest,
+                collectingListener(count, results, failure, slot, nodes.size(), responseBuilder, listener)
+            );
+
+            slot++;
+        }
     }

-    private void inferSingleDocAgainstAllocatedModel(
-        String modelId,
-        TimeValue timeValue,
-        InferenceConfigUpdate inferenceConfigUpdate,
-        Map<String, Object> doc,
-        TaskId parentTaskId,
-        ActionListener<InferenceResults> listener
+    private ActionListener<InferTrainedModelDeploymentAction.Response> collectingListener(
+        AtomicInteger count,
+        AtomicArray<List<InferenceResults>> results,
+        AtomicReference<Exception> failure,
+        int slot,
+        int totalNumberOfResponses,
+        Response.Builder responseBuilder,
+        ActionListener<Response> finalListener
     ) {
-        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
-            modelId,
-            inferenceConfigUpdate,
-            Collections.singletonList(doc),
-            timeValue
-        );
-        request.setParentTask(parentTaskId);
-        executeAsyncWithOrigin(
-            client,
-            ML_ORIGIN,
-            InferTrainedModelDeploymentAction.INSTANCE,
-            request,
-            ActionListener.wrap(r -> listener.onResponse(r.getResults()), listener::onFailure)
-        );
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(InferTrainedModelDeploymentAction.Response response) {
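+                // Record this node's results in its slot; the last response to arrive sends the reply.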
+                results.setOnce(slot, response.getResults());
+                if (count.incrementAndGet() == totalNumberOfResponses) {
+                    sendResponse();
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                failure.set(e);
+                if (count.incrementAndGet() == totalNumberOfResponses) {
+                    sendResponse();
+                }
+            }
+
+            private void sendResponse() {
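+                // Concatenate per-node results in slot order so the original document order is
+                // preserved; if no node produced results, report the last recorded failure instead.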
+                if (results.nonNullLength() > 0) {
+                    for (int i = 0; i < results.length(); i++) {
+                        if (results.get(i) != null) {
+                            responseBuilder.addInferenceResults(results.get(i));
+                        }
+                    }
+                    finalListener.onResponse(responseBuilder.build());
+                } else {
+                    finalListener.onFailure(failure.get());
+                }
+            }
+        };
     }
 }