浏览代码

Adding request source for cohere (#104926)

Jonathan Buttner 1 年之前
父节点
当前提交
422e6f6b98

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java

@@ -62,6 +62,7 @@ public class CohereEmbeddingsRequest implements Request {
 
 
         httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
         httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
         httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
         httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
+        httpPost.setHeader(CohereUtils.createRequestSourceHeader());
 
 
         return new HttpRequest(httpPost, getInferenceEntityId());
         return new HttpRequest(httpPost, getInferenceEntityId());
     }
     }

+ 9 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereUtils.java

@@ -7,10 +7,19 @@
 
 
 package org.elasticsearch.xpack.inference.external.request.cohere;
 package org.elasticsearch.xpack.inference.external.request.cohere;
 
 
+import org.apache.http.Header;
+import org.apache.http.message.BasicHeader;
+
 public class CohereUtils {
 public class CohereUtils {
     public static final String HOST = "api.cohere.ai";
     public static final String HOST = "api.cohere.ai";
     public static final String VERSION_1 = "v1";
     public static final String VERSION_1 = "v1";
     public static final String EMBEDDINGS_PATH = "embed";
     public static final String EMBEDDINGS_PATH = "embed";
+    public static final String REQUEST_SOURCE_HEADER = "Request-Source";
+    public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch";
+
+    public static Header createRequestSourceHeader() {
+        return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);
+    }
 
 
     private CohereUtils() {}
     private CohereUtils() {}
 }
 }

+ 9 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
 import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
 import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
 import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
@@ -130,6 +131,10 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
                 equalTo(XContentType.JSON.mediaType())
                 equalTo(XContentType.JSON.mediaType())
             );
             );
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+            MatcherAssert.assertThat(
+                webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER),
+                equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE)
+            );
 
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             MatcherAssert.assertThat(
             MatcherAssert.assertThat(
@@ -210,6 +215,10 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
                 equalTo(XContentType.JSON.mediaType())
                 equalTo(XContentType.JSON.mediaType())
             );
             );
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+            MatcherAssert.assertThat(
+                webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER),
+                equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE)
+            );
 
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             MatcherAssert.assertThat(
             MatcherAssert.assertThat(

+ 16 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java

@@ -44,6 +44,10 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(CohereUtils.ELASTIC_REQUEST_SOURCE)
+        );
 
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"))));
         MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"))));
@@ -71,6 +75,10 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(CohereUtils.ELASTIC_REQUEST_SOURCE)
+        );
 
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         MatcherAssert.assertThat(
         MatcherAssert.assertThat(
@@ -114,6 +122,10 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(CohereUtils.ELASTIC_REQUEST_SOURCE)
+        );
 
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         MatcherAssert.assertThat(
         MatcherAssert.assertThat(
@@ -157,6 +169,10 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
         MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(CohereUtils.ELASTIC_REQUEST_SOURCE)
+        );
 
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "truncate", "none")));
         MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "truncate", "none")));