|  | @@ -6,13 +6,15 @@
 | 
	
		
			
				|  |  |   */
 | 
	
		
			
				|  |  |  package org.elasticsearch.xpack.ml.action;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import org.elasticsearch.ResourceNotFoundException;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.ActionListener;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.support.ActionFilters;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.support.HandledTransportAction;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.internal.Client;
 | 
	
		
			
				|  |  |  import org.elasticsearch.cluster.service.ClusterService;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.inject.Inject;
 | 
	
		
			
				|  |  | -import org.elasticsearch.core.TimeValue;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.util.concurrent.AtomicArray;
 | 
	
		
			
				|  |  | +import org.elasticsearch.core.Tuple;
 | 
	
		
			
				|  |  |  import org.elasticsearch.license.License;
 | 
	
		
			
				|  |  |  import org.elasticsearch.license.LicenseUtils;
 | 
	
		
			
				|  |  |  import org.elasticsearch.license.XPackLicenseState;
 | 
	
	
		
			
				|  | @@ -28,8 +30,11 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.MachineLearning;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
 | 
	
	
		
			
				|  | @@ -38,9 +43,10 @@ import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -import java.util.Collections;
 | 
	
		
			
				|  |  | -import java.util.Map;
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.Optional;
 | 
	
		
			
				|  |  | +import java.util.concurrent.atomic.AtomicInteger;
 | 
	
		
			
				|  |  | +import java.util.concurrent.atomic.AtomicReference;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.elasticsearch.core.Strings.format;
 | 
	
		
			
				|  |  |  import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 | 
	
	
		
			
				|  | @@ -132,19 +138,19 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  |          String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId()))
 | 
	
		
			
				|  |  |              .orElse(request.getModelId());
 | 
	
		
			
				|  |  | -        if (isAllocatedModel(concreteModelId)) {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        responseBuilder.setModelId(concreteModelId);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if (trainedModelAssignmentMetadata.isAssigned(concreteModelId)) {
 | 
	
		
			
				|  |  |              // It is important to use the resolved model ID here as the alias could change between transport calls.
 | 
	
		
			
				|  |  | -            inferAgainstAllocatedModel(request, concreteModelId, responseBuilder, parentTaskId, listener);
 | 
	
		
			
				|  |  | +            inferAgainstAllocatedModel(trainedModelAssignmentMetadata, request, concreteModelId, responseBuilder, parentTaskId, listener);
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  |              getModelAndInfer(request, responseBuilder, parentTaskId, (CancellableTask) task, listener);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private boolean isAllocatedModel(String modelId) {
 | 
	
		
			
				|  |  | -        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
 | 
	
		
			
				|  |  | -        return trainedModelAssignmentMetadata.isAssigned(modelId);
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      private void getModelAndInfer(
 | 
	
		
			
				|  |  |          Request request,
 | 
	
		
			
				|  |  |          Response.Builder responseBuilder,
 | 
	
	
		
			
				|  | @@ -169,75 +175,153 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              typedChainTaskExecutor.execute(ActionListener.wrap(inferenceResultsInterfaces -> {
 | 
	
		
			
				|  |  |                  model.release();
 | 
	
		
			
				|  |  | -                listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).setModelId(model.getModelId()).build());
 | 
	
		
			
				|  |  | +                listener.onResponse(responseBuilder.addInferenceResults(inferenceResultsInterfaces).build());
 | 
	
		
			
				|  |  |              }, e -> {
 | 
	
		
			
				|  |  |                  model.release();
 | 
	
		
			
				|  |  |                  listener.onFailure(e);
 | 
	
		
			
				|  |  |              }));
 | 
	
		
			
				|  |  | -        }, listener::onFailure);
 | 
	
		
			
				|  |  | +        }, e -> {
 | 
	
		
			
				|  |  | +            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
 | 
	
		
			
				|  |  | +                listener.onFailure(e);
 | 
	
		
			
				|  |  | +                return;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +            // The model was found, check if a more relevant error message can be returned
 | 
	
		
			
				|  |  | +            trainedModelProvider.getTrainedModel(
 | 
	
		
			
				|  |  | +                request.getModelId(),
 | 
	
		
			
				|  |  | +                GetTrainedModelsAction.Includes.empty(),
 | 
	
		
			
				|  |  | +                parentTaskId,
 | 
	
		
			
				|  |  | +                ActionListener.wrap(trainedModelConfig -> {
 | 
	
		
			
				|  |  | +                    if (trainedModelConfig.getModelType() == TrainedModelType.PYTORCH) {
 | 
	
		
			
				|  |  | +                        // The PyTorch model cannot be allocated if we got here
 | 
	
		
			
				|  |  | +                        listener.onFailure(
 | 
	
		
			
				|  |  | +                            ExceptionsHelper.conflictStatusException(
 | 
	
		
			
				|  |  | +                                "Model ["
 | 
	
		
			
				|  |  | +                                    + request.getModelId()
 | 
	
		
			
				|  |  | +                                    + "] must be deployed to use. Please deploy with the start trained model deployment API.",
 | 
	
		
			
				|  |  | +                                request.getModelId()
 | 
	
		
			
				|  |  | +                            )
 | 
	
		
			
				|  |  | +                        );
 | 
	
		
			
				|  |  | +                    } else {
 | 
	
		
			
				|  |  | +                        // return the original error
 | 
	
		
			
				|  |  | +                        listener.onFailure(e);
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +                }, listener::onFailure)
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // TODO should `getModelForInternalInference` be used here??
 | 
	
		
			
				|  |  |          modelLoadingService.getModelForPipeline(request.getModelId(), parentTaskId, getModelListener);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private void inferAgainstAllocatedModel(
 | 
	
		
			
				|  |  | +        TrainedModelAssignmentMetadata assignmentMeta,
 | 
	
		
			
				|  |  |          Request request,
 | 
	
		
			
				|  |  |          String concreteModelId,
 | 
	
		
			
				|  |  |          Response.Builder responseBuilder,
 | 
	
		
			
				|  |  |          TaskId parentTaskId,
 | 
	
		
			
				|  |  |          ActionListener<Response> listener
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  | -        TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<>(
 | 
	
		
			
				|  |  | -            client.threadPool().executor(ThreadPool.Names.SAME),
 | 
	
		
			
				|  |  | -            // run through all tasks
 | 
	
		
			
				|  |  | -            r -> true,
 | 
	
		
			
				|  |  | -            // Always fail immediately and return an error
 | 
	
		
			
				|  |  | -            ex -> true
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -        request.getObjectsToInfer()
 | 
	
		
			
				|  |  | -            .forEach(
 | 
	
		
			
				|  |  | -                stringObjectMap -> typedChainTaskExecutor.add(
 | 
	
		
			
				|  |  | -                    chainedTask -> inferSingleDocAgainstAllocatedModel(
 | 
	
		
			
				|  |  | -                        concreteModelId,
 | 
	
		
			
				|  |  | -                        request.getTimeout(),
 | 
	
		
			
				|  |  | -                        request.getUpdate(),
 | 
	
		
			
				|  |  | -                        stringObjectMap,
 | 
	
		
			
				|  |  | -                        parentTaskId,
 | 
	
		
			
				|  |  | -                        chainedTask
 | 
	
		
			
				|  |  | -                    )
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | +        TrainedModelAssignment assignment = assignmentMeta.getModelAssignment(concreteModelId);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if (assignment.getAssignmentState() == AssignmentState.STOPPING) {
 | 
	
		
			
				|  |  | +            String message = "Trained model [" + request.getModelId() + "] is STOPPING";
 | 
	
		
			
				|  |  | +            listener.onFailure(ExceptionsHelper.conflictStatusException(message));
 | 
	
		
			
				|  |  | +            return;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Get a list of nodes to send the requests to and the number of
 | 
	
		
			
				|  |  | +        // documents for each node.
 | 
	
		
			
				|  |  | +        var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments());
 | 
	
		
			
				|  |  | +        if (nodes.isEmpty()) {
 | 
	
		
			
				|  |  | +            logger.trace(() -> format("[%s] model not allocated to any node [%s]", assignment.getModelId()));
 | 
	
		
			
				|  |  | +            listener.onFailure(
 | 
	
		
			
				|  |  | +                ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes")
 | 
	
		
			
				|  |  |              );
 | 
	
		
			
				|  |  | +            return;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        typedChainTaskExecutor.execute(
 | 
	
		
			
				|  |  | -            ActionListener.wrap(
 | 
	
		
			
				|  |  | -                inferenceResults -> listener.onResponse(
 | 
	
		
			
				|  |  | -                    responseBuilder.setInferenceResults(inferenceResults).setModelId(concreteModelId).build()
 | 
	
		
			
				|  |  | -                ),
 | 
	
		
			
				|  |  | -                listener::onFailure
 | 
	
		
			
				|  |  | -            )
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | +        assert nodes.stream().mapToInt(Tuple::v2).sum() == request.numberOfDocuments()
 | 
	
		
			
				|  |  | +            : "mismatch; sum of node requests does not match number of documents in request";
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        AtomicInteger count = new AtomicInteger();
 | 
	
		
			
				|  |  | +        AtomicArray<List<InferenceResults>> results = new AtomicArray<>(nodes.size());
 | 
	
		
			
				|  |  | +        AtomicReference<Exception> failure = new AtomicReference<>();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        int startPos = 0;
 | 
	
		
			
				|  |  | +        int slot = 0;
 | 
	
		
			
				|  |  | +        for (var node : nodes) {
 | 
	
		
			
				|  |  | +            InferTrainedModelDeploymentAction.Request deploymentRequest;
 | 
	
		
			
				|  |  | +            if (request.getTextInput() == null) {
 | 
	
		
			
				|  |  | +                deploymentRequest = InferTrainedModelDeploymentAction.Request.forDocs(
 | 
	
		
			
				|  |  | +                    concreteModelId,
 | 
	
		
			
				|  |  | +                    request.getUpdate(),
 | 
	
		
			
				|  |  | +                    request.getObjectsToInfer().subList(startPos, startPos + node.v2()),
 | 
	
		
			
				|  |  | +                    request.getInferenceTimeout()
 | 
	
		
			
				|  |  | +                );
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                deploymentRequest = InferTrainedModelDeploymentAction.Request.forTextInput(
 | 
	
		
			
				|  |  | +                    concreteModelId,
 | 
	
		
			
				|  |  | +                    request.getUpdate(),
 | 
	
		
			
				|  |  | +                    request.getTextInput().subList(startPos, startPos + node.v2()),
 | 
	
		
			
				|  |  | +                    request.getInferenceTimeout()
 | 
	
		
			
				|  |  | +                );
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            deploymentRequest.setNodes(node.v1());
 | 
	
		
			
				|  |  | +            deploymentRequest.setParentTask(parentTaskId);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            startPos += node.v2();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            executeAsyncWithOrigin(
 | 
	
		
			
				|  |  | +                client,
 | 
	
		
			
				|  |  | +                ML_ORIGIN,
 | 
	
		
			
				|  |  | +                InferTrainedModelDeploymentAction.INSTANCE,
 | 
	
		
			
				|  |  | +                deploymentRequest,
 | 
	
		
			
				|  |  | +                collectingListener(count, results, failure, slot, nodes.size(), responseBuilder, listener)
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            slot++;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private void inferSingleDocAgainstAllocatedModel(
 | 
	
		
			
				|  |  | -        String modelId,
 | 
	
		
			
				|  |  | -        TimeValue timeValue,
 | 
	
		
			
				|  |  | -        InferenceConfigUpdate inferenceConfigUpdate,
 | 
	
		
			
				|  |  | -        Map<String, Object> doc,
 | 
	
		
			
				|  |  | -        TaskId parentTaskId,
 | 
	
		
			
				|  |  | -        ActionListener<InferenceResults> listener
 | 
	
		
			
				|  |  | +    private ActionListener<InferTrainedModelDeploymentAction.Response> collectingListener(
 | 
	
		
			
				|  |  | +        AtomicInteger count,
 | 
	
		
			
				|  |  | +        AtomicArray<List<InferenceResults>> results,
 | 
	
		
			
				|  |  | +        AtomicReference<Exception> failure,
 | 
	
		
			
				|  |  | +        int slot,
 | 
	
		
			
				|  |  | +        int totalNumberOfResponses,
 | 
	
		
			
				|  |  | +        Response.Builder responseBuilder,
 | 
	
		
			
				|  |  | +        ActionListener<Response> finalListener
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  | -        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
 | 
	
		
			
				|  |  | -            modelId,
 | 
	
		
			
				|  |  | -            inferenceConfigUpdate,
 | 
	
		
			
				|  |  | -            Collections.singletonList(doc),
 | 
	
		
			
				|  |  | -            timeValue
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -        request.setParentTask(parentTaskId);
 | 
	
		
			
				|  |  | -        executeAsyncWithOrigin(
 | 
	
		
			
				|  |  | -            client,
 | 
	
		
			
				|  |  | -            ML_ORIGIN,
 | 
	
		
			
				|  |  | -            InferTrainedModelDeploymentAction.INSTANCE,
 | 
	
		
			
				|  |  | -            request,
 | 
	
		
			
				|  |  | -            ActionListener.wrap(r -> listener.onResponse(r.getResults()), listener::onFailure)
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | +        return new ActionListener<>() {
 | 
	
		
			
				|  |  | +            @Override
 | 
	
		
			
				|  |  | +            public void onResponse(InferTrainedModelDeploymentAction.Response response) {
 | 
	
		
			
				|  |  | +                results.setOnce(slot, response.getResults());
 | 
	
		
			
				|  |  | +                if (count.incrementAndGet() == totalNumberOfResponses) {
 | 
	
		
			
				|  |  | +                    sendResponse();
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            @Override
 | 
	
		
			
				|  |  | +            public void onFailure(Exception e) {
 | 
	
		
			
				|  |  | +                failure.set(e);
 | 
	
		
			
				|  |  | +                if (count.incrementAndGet() == totalNumberOfResponses) {
 | 
	
		
			
				|  |  | +                    sendResponse();
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            private void sendResponse() {
 | 
	
		
			
				|  |  | +                if (results.nonNullLength() > 0) {
 | 
	
		
			
				|  |  | +                    for (int i = 0; i < results.length(); i++) {
 | 
	
		
			
				|  |  | +                        if (results.get(i) != null) {
 | 
	
		
			
				|  |  | +                            responseBuilder.addInferenceResults(results.get(i));
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +                    finalListener.onResponse(responseBuilder.build());
 | 
	
		
			
				|  |  | +                } else {
 | 
	
		
			
				|  |  | +                    finalListener.onFailure(failure.get());
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        };
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  }
 |