ソースを参照

[Inference API] Auto-propagate product origin for every subclass of ElasticInferenceServiceRequest (#123141) (#123980)

Tim Grein 7 ヶ月 前
コミット
78087f7344
11 ファイル変更125 行追加19 行削除
  1. 5 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java
  2. 5 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java
  3. 6 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java
  4. 25 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java
  5. 7 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java
  6. 7 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java
  7. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java
  8. 5 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
  9. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java
  10. 61 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java
  11. 1 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java

@@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -43,6 +44,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
 
     private final InputType inputType;
 
+    private final String productOrigin;
+
     private static ResponseHandler createSparseEmbeddingsHandler() {
         return new ElasticInferenceServiceResponseHandler(
             String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
@@ -60,6 +63,7 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
         this.model = model;
         this.truncator = serviceComponents.truncator();
         this.traceContext = traceContext;
+        this.productOrigin = serviceComponents.threadPool().getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
         this.inputType = inputType;
     }
 
@@ -78,6 +82,7 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
             truncatedInput,
             model,
             traceContext,
+            productOrigin,
             inputType
         );
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));

+ 5 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -43,6 +44,7 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas
 
     private final ElasticInferenceServiceCompletionModel model;
     private final TraceContext traceContext;
+    private final String productOrigin;
 
     private ElasticInferenceServiceUnifiedCompletionRequestManager(
         ElasticInferenceServiceCompletionModel model,
@@ -52,6 +54,7 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas
         super(threadPool, model);
         this.model = model;
         this.traceContext = traceContext;
+        this.productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
     }
 
     @Override
@@ -65,7 +68,8 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas
         ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
             inferenceInputs.castTo(UnifiedChatInput.class),
             model,
-            traceContext
+            traceContext,
+            productOrigin
         );
 
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));

+ 6 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java

@@ -8,9 +8,9 @@
 package org.elasticsearch.xpack.inference.external.request.elastic;
 
 import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpRequestBase;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -20,12 +20,13 @@ import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.Objects;
 
-public class ElasticInferenceServiceAuthorizationRequest implements ElasticInferenceServiceRequest {
+public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenceServiceRequest {
 
     private final URI uri;
     private final TraceContextHandler traceContextHandler;
 
-    public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext) {
+    public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext, String productOrigin) {
+        super(productOrigin);
         this.uri = createUri(Objects.requireNonNull(url));
         this.traceContextHandler = new TraceContextHandler(traceContext);
     }
@@ -44,11 +45,11 @@ public class ElasticInferenceServiceAuthorizationRequest implements ElasticInfer
     }
 
     @Override
-    public HttpRequest createHttpRequest() {
+    public HttpRequestBase createHttpRequestBase() {
         var httpGet = new HttpGet(uri);
         traceContextHandler.propagateTraceContext(httpGet);
 
-        return new HttpRequest(httpGet, getInferenceEntityId());
+        return httpGet;
     }
 
     public TraceContext getTraceContext() {

+ 25 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java

@@ -7,6 +7,30 @@
 
 package org.elasticsearch.xpack.inference.external.request.elastic;
 
+import org.apache.http.client.methods.HttpRequestBase;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
 
-public interface ElasticInferenceServiceRequest extends Request {}
+public abstract class ElasticInferenceServiceRequest implements Request {
+
+    private final String productOrigin;
+
+    public ElasticInferenceServiceRequest(String productOrigin) {
+        this.productOrigin = productOrigin;
+    }
+
+    public String getProductOrigin() {
+        return productOrigin;
+    }
+
+    @Override
+    public final HttpRequest createHttpRequest() {
+        HttpRequestBase request = createHttpRequestBase();
+        // TODO: consider moving tracing here, too
+        request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, productOrigin);
+        return new HttpRequest(request, getInferenceEntityId());
+    }
+
+    protected abstract HttpRequestBase createHttpRequestBase();
+}

+ 7 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

@@ -9,13 +9,13 @@ package org.elasticsearch.xpack.inference.external.request.elastic;
 
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpRequestBase;
 import org.apache.http.entity.ByteArrayEntity;
 import org.apache.http.message.BasicHeader;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.common.Truncator;
-import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
@@ -26,7 +26,7 @@ import java.net.URI;
 import java.nio.charset.StandardCharsets;
 import java.util.Objects;
 
-public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {
+public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest {
 
     private final URI uri;
     private final ElasticInferenceServiceSparseEmbeddingsModel model;
@@ -40,8 +40,10 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
         Truncator.TruncationResult truncationResult,
         ElasticInferenceServiceSparseEmbeddingsModel model,
         TraceContext traceContext,
+        String productOrigin,
         InputType inputType
     ) {
+        super(productOrigin);
         this.truncator = truncator;
         this.truncationResult = truncationResult;
         this.model = Objects.requireNonNull(model);
@@ -51,7 +53,7 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
     }
 
     @Override
-    public HttpRequest createHttpRequest() {
+    public HttpRequestBase createHttpRequestBase() {
         var httpPost = new HttpPost(uri);
         var usageContext = inputTypeToUsageContext(inputType);
         var requestEntity = Strings.toString(
@@ -68,7 +70,7 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
         traceContextHandler.propagateTraceContext(httpPost);
         httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
 
-        return new HttpRequest(httpPost, getInferenceEntityId());
+        return httpPost;
     }
 
     public TraceContext getTraceContext() {
@@ -93,6 +95,7 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
             truncatedInput,
             model,
             traceContextHandler.traceContext(),
+            getProductOrigin(),
             inputType
         );
     }

+ 7 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java

@@ -9,12 +9,12 @@ package org.elasticsearch.xpack.inference.external.request.elastic;
 
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpRequestBase;
 import org.apache.http.entity.ByteArrayEntity;
 import org.apache.http.message.BasicHeader;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
-import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -24,7 +24,7 @@ import java.net.URI;
 import java.nio.charset.StandardCharsets;
 import java.util.Objects;
 
-public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {
+public class ElasticInferenceServiceUnifiedChatCompletionRequest extends ElasticInferenceServiceRequest {
 
     private final ElasticInferenceServiceCompletionModel model;
     private final UnifiedChatInput unifiedChatInput;
@@ -33,15 +33,17 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Requ
     public ElasticInferenceServiceUnifiedChatCompletionRequest(
         UnifiedChatInput unifiedChatInput,
         ElasticInferenceServiceCompletionModel model,
-        TraceContext traceContext
+        TraceContext traceContext,
+        String productOrigin
     ) {
+        super(productOrigin);
         this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
         this.model = Objects.requireNonNull(model);
         this.traceContextHandler = new TraceContextHandler(traceContext);
     }
 
     @Override
-    public HttpRequest createHttpRequest() {
+    public HttpRequestBase createHttpRequestBase() {
         var httpPost = new HttpPost(model.uri());
         var requestEntity = Strings.toString(
             new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
@@ -53,7 +55,7 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Requ
         traceContextHandler.propagateTraceContext(httpPost);
         httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
 
-        return new HttpRequest(httpPost, getInferenceEntityId());
+        return httpPost;
     }
 
     @Override

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

@@ -108,7 +108,8 @@ public class ElasticInferenceServiceAuthorizationHandler {
                 requestCompleteLatch.countDown();
             });
 
-            var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo());
+            var productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
+            var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), productOrigin);
 
             sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
         } catch (Exception e) {

+ 5 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

@@ -158,7 +158,11 @@ public class HttpRequestSenderTests extends ESTestCase {
             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            var request = new ElasticInferenceServiceAuthorizationRequest(getUrl(webServer), new TraceContext("", ""));
+            var request = new ElasticInferenceServiceAuthorizationRequest(
+                getUrl(webServer),
+                new TraceContext("", ""),
+                randomAlphaOfLength(10)
+            );
             var responseHandler = new ElasticInferenceServiceResponseHandler(
                 String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
                 ElasticInferenceServiceAuthorizationResponseEntity::fromResponse

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java

@@ -30,7 +30,7 @@ public class ElasticInferenceServiceAuthorizationRequestTests extends ESTestCase
 
         ElasticsearchStatusException exception = assertThrows(
             ElasticsearchStatusException.class,
-            () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext)
+            () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomAlphaOfLength(10))
         );
 
         assertThat(exception.status(), is(RestStatus.BAD_REQUEST));

+ 61 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java

@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic;
+
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.net.URI;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class ElasticInferenceServiceRequestTests extends ESTestCase {
+
+    public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() {
+        var productOrigin = "elastic";
+        var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest(productOrigin);
+        var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest();
+        var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
+
+        // Make sure this header only exists once
+        assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1));
+        assertThat(productOriginHeader.getValue(), equalTo(productOrigin));
+    }
+
+    private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest(String productOrigin) {
+        return new ElasticInferenceServiceRequest(productOrigin) {
+            @Override
+            protected HttpRequestBase createHttpRequestBase() {
+                return new HttpGet("http://localhost:8080");
+            }
+
+            @Override
+            public URI getURI() {
+                return null;
+            }
+
+            @Override
+            public Request truncate() {
+                return null;
+            }
+
+            @Override
+            public boolean[] getTruncationInfo() {
+                return new boolean[0];
+            }
+
+            @Override
+            public String getInferenceEntityId() {
+                return "";
+            }
+        };
+    }
+}

+ 1 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java

@@ -124,6 +124,7 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestTests extends ESTestC
             new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
             embeddingsModel,
             new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
+            randomAlphaOfLength(10),
             inputType
         );
     }