|
@@ -13,6 +13,10 @@ import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
import org.elasticsearch.action.support.HandledTransportAction;
|
|
|
+import org.elasticsearch.client.internal.node.NodeClient;
|
|
|
+import org.elasticsearch.cluster.node.DiscoveryNode;
|
|
|
+import org.elasticsearch.common.Randomness;
|
|
|
+import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.Writeable;
|
|
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
|
|
import org.elasticsearch.common.xcontent.ChunkedToXContent;
|
|
@@ -27,24 +31,42 @@ import org.elasticsearch.license.LicenseUtils;
|
|
|
import org.elasticsearch.license.XPackLicenseState;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
+import org.elasticsearch.transport.TransportException;
|
|
|
+import org.elasticsearch.transport.TransportResponseHandler;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.XPackField;
|
|
|
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
+import org.elasticsearch.xpack.inference.InferencePlugin;
|
|
|
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
|
|
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
|
|
|
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
|
|
|
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
|
|
|
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
|
|
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
|
|
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
|
|
|
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.Random;
|
|
|
+import java.util.concurrent.Executor;
|
|
|
import java.util.function.Supplier;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.elasticsearch.core.Strings.format;
|
|
|
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
|
|
|
+import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
|
|
|
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
|
|
|
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
|
|
|
|
|
|
+/**
|
|
|
+ * Base class for transport actions that handle inference requests.
|
|
|
+ * Works in conjunction with {@link InferenceServiceNodeLocalRateLimitCalculator} to
|
|
|
+ * route requests to specific nodes, iff they support "node-local" rate limiting, which is described in detail
|
|
|
+ * in {@link InferenceServiceNodeLocalRateLimitCalculator}.
|
|
|
+ *
|
|
|
+ * @param <Request> The specific type of inference request being handled
|
|
|
+ */
|
|
|
public abstract class BaseTransportInferenceAction<Request extends BaseInferenceActionRequest> extends HandledTransportAction<
|
|
|
Request,
|
|
|
InferenceAction.Response> {
|
|
@@ -57,6 +79,11 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
    private final InferenceServiceRegistry serviceRegistry;
    private final InferenceStats inferenceStats;
    private final StreamingTaskManager streamingTaskManager;
    // Decides whether a (service, task type) pair supports node-local rate limiting and,
    // if so, which nodes are currently responsible for handling such requests.
    private final InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
    // Used to resolve the local node id when checking whether a routing decision targets this node.
    private final NodeClient nodeClient;
    // Supplies the executor for handling responses of rerouted requests.
    private final ThreadPool threadPool;
    // Used to forward requests to the node selected by the rate limit assignment.
    private final TransportService transportService;
    // Source of randomness for picking one of the responsible nodes; obtained via Randomness.get().
    private final Random random;
|
|
|
|
|
|
public BaseTransportInferenceAction(
|
|
|
String inferenceActionName,
|
|
@@ -67,7 +94,10 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
InferenceServiceRegistry serviceRegistry,
|
|
|
InferenceStats inferenceStats,
|
|
|
StreamingTaskManager streamingTaskManager,
|
|
|
- Writeable.Reader<Request> requestReader
|
|
|
+ Writeable.Reader<Request> requestReader,
|
|
|
+ InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
|
|
|
+ NodeClient nodeClient,
|
|
|
+ ThreadPool threadPool
|
|
|
) {
|
|
|
super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
|
|
|
this.licenseState = licenseState;
|
|
@@ -75,8 +105,24 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
this.serviceRegistry = serviceRegistry;
|
|
|
this.inferenceStats = inferenceStats;
|
|
|
this.streamingTaskManager = streamingTaskManager;
|
|
|
+ this.inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator;
|
|
|
+ this.nodeClient = nodeClient;
|
|
|
+ this.threadPool = threadPool;
|
|
|
+ this.transportService = transportService;
|
|
|
+ this.random = Randomness.get();
|
|
|
}
|
|
|
|
|
|
    /**
     * Hook for subclasses to reject request/endpoint combinations whose task types are
     * incompatible for the concrete action.
     *
     * @return true if the request's task type is invalid for the resolved inference endpoint
     */
    protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);

    /**
     * Builds the exception to report when {@link #isInvalidTaskTypeForInferenceEndpoint}
     * returns true for the given request and endpoint.
     */
    protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel);

    /**
     * Performs the actual inference call against the resolved service and model,
     * delivering results (or a failure) to the given listener.
     */
    protected abstract void doInference(
        Model model,
        Request request,
        InferenceService service,
        ActionListener<InferenceServiceResults> listener
    );
|
|
|
@Override
|
|
|
protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
|
|
|
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
|
|
@@ -87,31 +133,32 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
var timer = InferenceTimer.start();
|
|
|
|
|
|
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
|
|
|
- var service = serviceRegistry.getService(unparsedModel.service());
|
|
|
+ var serviceName = unparsedModel.service();
|
|
|
+
|
|
|
try {
|
|
|
- validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
|
|
|
- validationHelper(
|
|
|
- () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false,
|
|
|
- () -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType())
|
|
|
- );
|
|
|
- validationHelper(
|
|
|
- () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel),
|
|
|
- () -> createInvalidTaskTypeException(request, unparsedModel)
|
|
|
- );
|
|
|
+ validateRequest(request, unparsedModel);
|
|
|
} catch (Exception e) {
|
|
|
recordMetrics(unparsedModel, timer, e);
|
|
|
listener.onFailure(e);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- var model = service.get()
|
|
|
- .parsePersistedConfigWithSecrets(
|
|
|
+ var service = serviceRegistry.getService(serviceName).get();
|
|
|
+ var routingDecision = determineRouting(serviceName, request, unparsedModel);
|
|
|
+
|
|
|
+ if (routingDecision.currentNodeShouldHandleRequest()) {
|
|
|
+ var model = service.parsePersistedConfigWithSecrets(
|
|
|
unparsedModel.inferenceEntityId(),
|
|
|
unparsedModel.taskType(),
|
|
|
unparsedModel.settings(),
|
|
|
unparsedModel.secrets()
|
|
|
);
|
|
|
- inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
|
|
|
+ inferOnServiceWithMetrics(model, request, service, timer, listener);
|
|
|
+ } else {
|
|
|
+ // Reroute request
|
|
|
+ request.setHasBeenRerouted(true);
|
|
|
+ rerouteRequest(request, listener, routingDecision.targetNode);
|
|
|
+ }
|
|
|
}, e -> {
|
|
|
try {
|
|
|
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
|
|
@@ -124,15 +171,95 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
|
|
|
}
|
|
|
|
|
|
+ private void validateRequest(Request request, UnparsedModel unparsedModel) {
|
|
|
+ var serviceName = unparsedModel.service();
|
|
|
+ var requestTaskType = request.getTaskType();
|
|
|
+ var service = serviceRegistry.getService(serviceName);
|
|
|
+
|
|
|
+ validationHelper(service::isEmpty, () -> unknownServiceException(serviceName, request.getInferenceEntityId()));
|
|
|
+ validationHelper(
|
|
|
+ () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false,
|
|
|
+ () -> requestModelTaskTypeMismatchException(requestTaskType, unparsedModel.taskType())
|
|
|
+ );
|
|
|
+ validationHelper(
|
|
|
+ () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel),
|
|
|
+ () -> createInvalidTaskTypeException(request, unparsedModel)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) {
|
|
|
+ if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled() == false) {
|
|
|
+ return NodeRoutingDecision.handleLocally();
|
|
|
+ }
|
|
|
+
|
|
|
+ var modelTaskType = unparsedModel.taskType();
|
|
|
+
|
|
|
+ // Rerouting not supported or request was already rerouted
|
|
|
+ if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false
|
|
|
+ || request.hasBeenRerouted()) {
|
|
|
+ return NodeRoutingDecision.handleLocally();
|
|
|
+ }
|
|
|
+
|
|
|
+ var rateLimitAssignment = inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceName, modelTaskType);
|
|
|
+
|
|
|
+ // No assignment yet
|
|
|
+ if (rateLimitAssignment == null) {
|
|
|
+ return NodeRoutingDecision.handleLocally();
|
|
|
+ }
|
|
|
+
|
|
|
+ var responsibleNodes = rateLimitAssignment.responsibleNodes();
|
|
|
+
|
|
|
+ // Empty assignment
|
|
|
+ if (responsibleNodes == null || responsibleNodes.isEmpty()) {
|
|
|
+ return NodeRoutingDecision.handleLocally();
|
|
|
+ }
|
|
|
+
|
|
|
+ var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size()));
|
|
|
+ String localNodeId = nodeClient.getLocalNodeId();
|
|
|
+
|
|
|
+ // The drawn node is the current node
|
|
|
+ if (nodeToHandleRequest.getId().equals(localNodeId)) {
|
|
|
+ return NodeRoutingDecision.handleLocally();
|
|
|
+ }
|
|
|
+
|
|
|
+ // Reroute request
|
|
|
+ return NodeRoutingDecision.routeTo(nodeToHandleRequest);
|
|
|
+ }
|
|
|
+
|
|
|
private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
|
|
|
if (validationFailure.get()) {
|
|
|
throw exceptionCreator.get();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);
|
|
|
-
|
|
|
- protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel);
|
|
|
+ private void rerouteRequest(Request request, ActionListener<InferenceAction.Response> listener, DiscoveryNode nodeToHandleRequest) {
|
|
|
+ transportService.sendRequest(
|
|
|
+ nodeToHandleRequest,
|
|
|
+ InferenceAction.NAME,
|
|
|
+ request,
|
|
|
+ new TransportResponseHandler<InferenceAction.Response>() {
|
|
|
+ @Override
|
|
|
+ public Executor executor() {
|
|
|
+ return threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void handleResponse(InferenceAction.Response response) {
|
|
|
+ listener.onResponse(response);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void handleException(TransportException exp) {
|
|
|
+ listener.onFailure(exp);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public InferenceAction.Response read(StreamInput in) throws IOException {
|
|
|
+ return new InferenceAction.Response(in);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ );
|
|
|
+ }
|
|
|
|
|
|
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
|
|
|
try {
|
|
@@ -185,13 +312,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- protected abstract void doInference(
|
|
|
- Model model,
|
|
|
- Request request,
|
|
|
- InferenceService service,
|
|
|
- ActionListener<InferenceServiceResults> listener
|
|
|
- );
|
|
|
-
|
|
|
private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) {
|
|
|
var supportedTasks = service.supportedStreamingTasks();
|
|
|
if (supportedTasks.isEmpty()) {
|
|
@@ -259,4 +379,14 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
super.onComplete();
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
|
|
|
+ static NodeRoutingDecision handleLocally() {
|
|
|
+ return new NodeRoutingDecision(true, null);
|
|
|
+ }
|
|
|
+
|
|
|
+ static NodeRoutingDecision routeTo(DiscoveryNode node) {
|
|
|
+ return new NodeRoutingDecision(false, node);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|