
Add Llama support to Inference Plugin (#130092)

* Refactor Hugging Face service settings and completion request methods for consistency

* Add Llama model support for embeddings and chat completions

* Refactor Llama request classes to improve secret settings handling

* Refactor DeltaParser in LlamaStreamingProcessor to improve argument handling

* Enhance Llama streaming processing by adding support for nullable object arrays

* [CI] Auto commit changes from spotless

* Fix error messages in LlamaActionCreator

* [CI] Auto commit changes from spotless

* Add detailed Javadoc comments to Llama classes for improved documentation

* Enhance LlamaChatCompletionResponseHandler to support mid-stream error handling and improve error response parsing

* Add Javadoc comments to Llama classes for improved documentation and clarity

* Fix checkstyle

* Update LlamaEmbeddingsRequest to use mediaTypeWithoutParameters for content type header

* Add unit tests for LlamaActionCreator and related models

* Add unit tests for LlamaChatCompletionServiceSettings to validate configuration parsing and serialization

* Add unit tests for LlamaEmbeddingsServiceSettings to validate configuration parsing and serialization

* Add unit tests for LlamaEmbeddingsServiceSettings to validate various configuration scenarios

* Add unit tests for LlamaChatCompletionResponseHandler to validate error response handling

* Refactor Llama embedding and chat completion tests for consistency and clarity

* Add unit tests for LlamaChatCompletionRequestEntity to validate message serialization

* Add unit tests for LlamaEmbeddingsRequest to validate request creation and truncation behavior

* Add unit tests for LlamaEmbeddingsRequestEntity to validate XContent serialization

* Add unit tests for LlamaErrorResponse to validate error handling from HTTP responses

* Add unit tests for LlamaChatCompletionServiceSettings to validate configuration parsing and serialization

* Add tests for LlamaService request configuration validation and error handling

* Fix error message formatting in LlamaServiceTests for better localization support

* Refactor Llama model classes to implement accept method for action visitors

* Hide Llama service from configuration API to enhance security and reduce exposure

* Refactor Llama model classes to remove modelId and update embedding request handling

* Refactor Llama request classes to use pattern matching for secret settings

* Update embeddings handler to use HuggingFace response entity

* Refactor Mistral model classes to remove modelId and update rate limit hashing

* Refactor Mistral action classes to remove taskSettings parameter and streamline action creation

* Refactor Llama and Mistral models to remove taskSettings parameter and simplify model instantiation

* Refactor Llama service tests to use Model instead of CustomModel and update similarity measure to DOT_PRODUCT

* Remove unused tests and imports from LlamaServiceTests

* Add chunking settings support to Llama embeddings model tests

* Add changelog

* Add support for version checks in Llama settings and define new transport version

* Refactor Llama model assertions and remove unused version support methods

* Refactor Llama service constructors to include ClusterService and improve error message handling

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Jan-Kazlouski-elastic committed 3 months ago
commit beb18a87c3
54 changed files with 4517 additions and 92 deletions
  1. docs/changelog/130092.yaml (+5 -0)
  2. libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java (+21 -0)
  3. server/src/main/java/org/elasticsearch/TransportVersions.java (+1 -0)
  4. server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java (+1 -1)
  5. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java (+1 -1)
  6. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java (+21 -1)
  7. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java (+2 -0)
  8. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java (+7 -0)
  9. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java (+1 -1)
  10. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java (+1 -8)
  11. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java (+1 -1)
  12. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java (+1 -1)
  13. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java (+1 -1)
  14. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java (+89 -0)
  15. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java (+423 -0)
  16. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java (+110 -0)
  17. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java (+34 -0)
  18. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java (+132 -0)
  19. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java (+180 -0)
  20. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java (+183 -0)
  21. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java (+29 -0)
  22. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java (+119 -0)
  23. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java (+29 -0)
  24. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java (+257 -0)
  25. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java (+96 -0)
  26. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java (+47 -0)
  27. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java (+96 -0)
  28. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java (+51 -0)
  29. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java (+59 -0)
  30. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java (+6 -6)
  31. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java (+10 -31)
  32. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java (+1 -2)
  33. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java (+1 -4)
  34. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java (+2 -8)
  35. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java (+4 -8)
  36. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java (+3 -2)
  37. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java (+1 -1)
  38. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java (+1 -1)
  39. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java (+1 -1)
  40. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java (+14 -9)
  41. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java (+840 -0)
  42. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java (+283 -0)
  43. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java (+142 -0)
  44. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java (+162 -0)
  45. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java (+198 -0)
  46. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java (+51 -0)
  47. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java (+479 -0)
  48. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java (+64 -0)
  49. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java (+82 -0)
  50. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java (+38 -0)
  51. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java (+101 -0)
  52. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java (+34 -0)
  53. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java (+0 -3)
  54. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java (+1 -1)

+ 5 - 0
docs/changelog/130092.yaml

@@ -0,0 +1,5 @@
+pr: 130092
+summary: "Added Llama provider support to the Inference Plugin"
+area: Machine Learning
+type: enhancement
+issues: []

+ 21 - 0
libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java

@@ -220,6 +220,27 @@ public final class ConstructingObjectParser<Value, Context> extends AbstractObje
         }
     }
 
+    /**
+     * Declare a field that is either an array of objects or an explicit null. A JSON null is parsed as a
+     * null value rather than raising a parse error, which makes the field safe to declare with
+     * {@link #optionalConstructorArg()} or {@link #constructorArg()}.
+     * @param consumer Consumer that will be passed as is to the {@link #declareField(BiConsumer, ContextParser, ParseField, ValueType)}.
+     * @param objectParser Parser that will parse the objects in the array, checking for nulls.
+     * @param field Field to declare.
+     */
+    @Override
+    public <T> void declareObjectArrayOrNull(
+        BiConsumer<Value, List<T>> consumer,
+        ContextParser<Context, T> objectParser,
+        ParseField field
+    ) {
+        declareField(
+            consumer,
+            (p, c) -> p.currentToken() == XContentParser.Token.VALUE_NULL ? null : parseArray(p, c, objectParser),
+            field,
+            ValueType.OBJECT_ARRAY_OR_NULL
+        );
+    }
+
     @Override
     public <T> void declareNamedObject(
         BiConsumer<Value, T> consumer,
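
For context, a minimal usage sketch of the new method, assuming a hypothetical Wrapper record with an "items" field (neither is part of this change):

    import org.elasticsearch.xcontent.ConstructingObjectParser;
    import org.elasticsearch.xcontent.ParseField;

    import java.util.List;
    import java.util.Map;

    import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

    // Accepts {"items": [{"k": 1}]} as well as an explicit {"items": null}.
    record Wrapper(List<Map<String, Object>> items) {
        @SuppressWarnings("unchecked")
        static final ConstructingObjectParser<Wrapper, Void> PARSER = new ConstructingObjectParser<>(
            "wrapper",
            args -> new Wrapper((List<Map<String, Object>>) args[0])
        );

        static {
            // Each array element is parsed as a generic map; a JSON null token yields a null list
            // instead of the parse error a plain object-array declaration would raise.
            PARSER.declareObjectArrayOrNull(optionalConstructorArg(), (p, c) -> p.map(), new ParseField("items"));
        }
    }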

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -343,6 +343,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
     public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
     public static final TransportVersion PROJECT_STATE_REGISTRY_ENTRY = def(9_124_0_00);
+    public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_125_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,
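
As a reminder of how such constants are consumed, a generic sketch of the usual gating pattern follows; the ExampleSettings record and someNewField are hypothetical, not code from this PR:

    import org.elasticsearch.TransportVersions;
    import org.elasticsearch.common.io.stream.StreamOutput;
    import org.elasticsearch.common.io.stream.Writeable;

    import java.io.IOException;

    // Hypothetical writeable carrying one field introduced with ML_INFERENCE_LLAMA_ADDED.
    record ExampleSettings(String someNewField) implements Writeable {
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            // Older nodes do not know about the field, so only write it to peers on or after the new version.
            if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_LLAMA_ADDED)) {
                out.writeOptionalString(someNewField);
            }
        }
    }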

+ 1 - 1
server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

@@ -121,7 +121,7 @@ public record UnifiedCompletionRequest(
      * - Key: {@link #MODEL_FIELD}, Value: modelId
      * - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
      */
-    public static Params withMaxCompletionTokensTokens(String modelId, Params params) {
+    public static Params withMaxCompletionTokens(String modelId, Params params) {
         return new DelegatingMapParams(
             Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)),
             params

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

@@ -119,7 +119,7 @@ public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationT
 
             assertThat(request, is(expected));
             assertThat(
-                Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
+                Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
                 is(XContentHelper.stripWhitespace(requestJson))
             );
         }

+ 21 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -106,6 +106,8 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbedd
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
 import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
@@ -175,6 +177,7 @@ public class InferenceNamedWriteablesProvider {
         addJinaAINamedWriteables(namedWriteables);
         addVoyageAINamedWriteables(namedWriteables);
         addCustomNamedWriteables(namedWriteables);
+        addLlamaNamedWriteables(namedWriteables);
 
         addUnifiedNamedWriteables(namedWriteables);
 
@@ -274,8 +277,25 @@ public class InferenceNamedWriteablesProvider {
                 MistralChatCompletionServiceSettings::new
             )
         );
+        // no task settings for Mistral
+    }
 
-        // note - no task settings for Mistral embeddings...
+    private static void addLlamaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                LlamaEmbeddingsServiceSettings.NAME,
+                LlamaEmbeddingsServiceSettings::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                LlamaChatCompletionServiceSettings.NAME,
+                LlamaChatCompletionServiceSettings::new
+            )
+        );
+        // no task settings for Llama
     }
 
     private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -133,6 +133,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
 import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
+import org.elasticsearch.xpack.inference.services.llama.LlamaService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
@@ -402,6 +403,7 @@ public class InferencePlugin extends Plugin
             context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
             context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
             context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
             ElasticsearchInternalService::new,
             context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
         );

+ 7 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
@@ -304,6 +305,12 @@ public final class ServiceUtils {
         return Strings.format("[%s] does not allow the setting [%s]", scope, settingName);
     }
 
+    public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
+        String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
+    }
+
     public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
         try {
             return createOptionalUri(url);
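
A brief usage sketch for the relocated helper, assuming a mutable serviceSettingsMap taken from the request body (the "url" field name matches ServiceFields.URL):

    // Illustrative only: pulls a required "url" string out of service_settings and validates it as a URI.
    ValidationException validationException = new ValidationException();
    URI uri = ServiceUtils.extractUri(serviceSettingsMap, "url", validationException);
    if (validationException.validationErrors().isEmpty() == false) {
        throw validationException;
    }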

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java

@@ -28,7 +28,7 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implement
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params));
+        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(modelId, params));
         builder.endObject();
 
         return builder;

+ 1 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java

@@ -31,11 +31,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSION
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
 
 public class HuggingFaceServiceSettings extends FilteredXContentObject implements ServiceSettings, HuggingFaceRateLimitServiceSettings {
     public static final String NAME = "hugging_face_service_settings";
@@ -70,12 +69,6 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
         return new HuggingFaceServiceSettings(uri, similarityMeasure, dims, maxInputTokens, rateLimitSettings);
     }
 
-    public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
-        String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
-
-        return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
-    }
-
     private final URI uri;
     private final SimilarityMeasure similarity;
     private final Integer dimensions;

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java

@@ -31,7 +31,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
-import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
 
 /**
  * Settings for the Hugging Face chat completion service.

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java

@@ -28,7 +28,7 @@ import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
-import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
 
 public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     implements

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java

@@ -27,7 +27,7 @@ import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
-import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
 
 public class HuggingFaceRerankServiceSettings extends FilteredXContentObject
     implements

+ 89 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama;
+
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
+import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Abstract class representing a Llama model for inference.
+ * This class extends RateLimitGroupingModel and provides common functionality for Llama models.
+ */
+public abstract class LlamaModel extends RateLimitGroupingModel {
+    protected URI uri;
+    protected RateLimitSettings rateLimitSettings;
+
+    /**
+     * Constructor for creating a LlamaModel with specified configurations and secrets.
+     *
+     * @param configurations the model configurations
+     * @param secrets the secret settings for the model
+     */
+    protected LlamaModel(ModelConfigurations configurations, ModelSecrets secrets) {
+        super(configurations, secrets);
+    }
+
+    /**
+     * Constructor for creating a LlamaModel with specified model, service settings, and secret settings.
+     * @param model the model configurations
+     * @param serviceSettings the settings for the inference service
+     */
+    protected LlamaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+    }
+
+    public URI uri() {
+        return this.uri;
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return this.rateLimitSettings;
+    }
+
+    @Override
+    public int rateLimitGroupingHash() {
+        return Objects.hash(getServiceSettings().modelId(), uri, getSecretSettings());
+    }
+
+    // Needed for testing only
+    public void setURI(String newUri) {
+        try {
+            this.uri = new URI(newUri);
+        } catch (URISyntaxException e) {
+            // deliberately ignored; this setter is only used from tests with known-valid URIs
+        }
+    }
+
+    /**
+     * Retrieves the secret settings from the provided map of secrets.
+     * If the map is present but empty, it returns an instance of EmptySecretSettings, because a Llama
+     * deployment has no out-of-the-box security settings and can be used without authentication.
+     *
+     * @param secrets the map containing secret settings
+     * @return an instance of SecretSettings
+     */
+    protected static SecretSettings retrieveSecretSettings(Map<String, Object> secrets) {
+        return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets);
+    }
+
+    protected abstract ExecutableAction accept(LlamaActionVisitor creator);
+}
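
A sketch of the fallback behavior described above, assuming subclass access (the method is protected) and the "api_key" key used by DefaultSecretSettings:

    // Empty map -> EmptySecretSettings.INSTANCE: the endpoint is usable without authentication.
    SecretSettings noAuth = retrieveSecretSettings(Map.of());
    // Non-empty map -> DefaultSecretSettings.fromMap, which reads the "api_key" entry.
    SecretSettings withKey = retrieveSecretSettings(new HashMap<>(Map.of("api_key", "my-key")));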

+ 423 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java

@@ -0,0 +1,423 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.util.LazyInitializable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SettingsConfiguration;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionCreator;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
+import static org.elasticsearch.inference.TaskType.COMPLETION;
+import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+
+/**
+ * LlamaService is an inference service for Llama models, supporting text embedding and chat completion tasks.
+ * It extends SenderService to handle HTTP requests and responses for Llama models.
+ */
+public class LlamaService extends SenderService {
+    public static final String NAME = "llama";
+    private static final String SERVICE_NAME = "Llama";
+    /**
+     * The optimal batch size depends on the hardware the model is deployed on.
+     * For Llama, use a conservatively small max batch size, as it is
+     * unknown how the model is deployed.
+     */
+    static final int EMBEDDING_MAX_BATCH_SIZE = 20;
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION);
+    private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
+        "llama chat completion",
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );
+
+    /**
+     * Constructor for creating a LlamaService with specified HTTP request sender factory and service components.
+     *
+     * @param factory the factory to create HTTP request senders
+     * @param serviceComponents the components required for the inference service
+     * @param context the context for the inference service factory
+     */
+    public LlamaService(
+        HttpRequestSender.Factory factory,
+        ServiceComponents serviceComponents,
+        InferenceServiceExtension.InferenceServiceFactoryContext context
+    ) {
+        this(factory, serviceComponents, context.clusterService());
+    }
+
+    public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
+        super(factory, serviceComponents, clusterService);
+    }
+
+    @Override
+    protected void doInfer(
+        Model model,
+        InferenceInputs inputs,
+        Map<String, Object> taskSettings,
+        TimeValue timeout,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents());
+        if (model instanceof LlamaModel llamaModel) {
+            llamaModel.accept(actionCreator).execute(inputs, timeout, listener);
+        } else {
+            listener.onFailure(createInvalidModelException(model));
+        }
+    }
+
+    @Override
+    protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
+        ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
+    }
+
+    /**
+     * Creates a LlamaModel based on the provided parameters.
+     *
+     * @param inferenceId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param serviceSettings the settings for the inference service
+     * @param chunkingSettings the settings for chunking, if applicable
+     * @param secretSettings the secret settings for the model, such as API keys or tokens
+     * @param failureMessage the message to use in case of failure
+     * @param context the context for parsing configuration settings
+     * @return a new instance of LlamaModel based on the provided parameters
+     */
+    protected LlamaModel createModel(
+        String inferenceId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
+        Map<String, Object> secretSettings,
+        String failureMessage,
+        ConfigurationParseContext context
+    ) {
+        switch (taskType) {
+            case TEXT_EMBEDDING:
+                return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context);
+            case CHAT_COMPLETION, COMPLETION:
+                return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
+            default:
+                throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+        }
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof LlamaEmbeddingsModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
+            var updatedServiceSettings = new LlamaEmbeddingsServiceSettings(
+                serviceSettings.modelId(),
+                serviceSettings.uri(),
+                embeddingSize,
+                similarityToUse,
+                serviceSettings.maxInputTokens(),
+                serviceSettings.rateLimitSettings()
+            );
+
+            return new LlamaEmbeddingsModel(embeddingsModel, updatedServiceSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
+    }
+
+    @Override
+    protected void doChunkedInfer(
+        Model model,
+        EmbeddingsInput inputs,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        TimeValue timeout,
+        ActionListener<List<ChunkedInference>> listener
+    ) {
+        if (model instanceof LlamaEmbeddingsModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        var llamaModel = (LlamaEmbeddingsModel) model;
+        var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents());
+
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+            inputs.getInputs(),
+            EMBEDDING_MAX_BATCH_SIZE,
+            llamaModel.getConfigurations().getChunkingSettings()
+        ).batchRequestsWithListeners(listener);
+
+        for (var request : batchedRequests) {
+            var action = llamaModel.accept(actionCreator);
+            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+        }
+    }
+
+    @Override
+    protected void doUnifiedCompletionInfer(
+        Model model,
+        UnifiedChatInput inputs,
+        TimeValue timeout,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        if (model instanceof LlamaChatCompletionModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        var llamaChatCompletionModel = (LlamaChatCompletionModel) model;
+        var overriddenModel = LlamaChatCompletionModel.of(llamaChatCompletionModel, inputs.getRequest());
+        var manager = new GenericRequestManager<>(
+            getServiceComponents().threadPool(),
+            overriddenModel,
+            UNIFIED_CHAT_COMPLETION_HANDLER,
+            unifiedChatInput -> new LlamaChatCompletionRequest(unifiedChatInput, overriddenModel),
+            UnifiedChatInput.class
+        );
+        var errorMessage = LlamaActionCreator.buildErrorMessage(CHAT_COMPLETION, model.getInferenceEntityId());
+        var action = new SenderExecutableAction(getSender(), manager, errorMessage);
+
+        action.execute(inputs, timeout, listener);
+    }
+
+    @Override
+    public Set<TaskType> supportedStreamingTasks() {
+        return EnumSet.of(COMPLETION, CHAT_COMPLETION);
+    }
+
+    @Override
+    public InferenceServiceConfiguration getConfiguration() {
+        return Configuration.get();
+    }
+
+    @Override
+    public EnumSet<TaskType> supportedTaskTypes() {
+        return SUPPORTED_TASK_TYPES;
+    }
+
+    @Override
+    public String name() {
+        return NAME;
+    }
+
+    @Override
+    public void parseRequestConfig(
+        String modelId,
+        TaskType taskType,
+        Map<String, Object> config,
+        ActionListener<Model> parsedModelListener
+    ) {
+        try {
+            Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+            Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+            ChunkingSettings chunkingSettings = null;
+            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
+
+            LlamaModel model = createModel(
+                modelId,
+                taskType,
+                serviceSettingsMap,
+                chunkingSettings,
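+                // in REQUEST context the API key, if any, arrives inside service_settings,
+                // so the same map is passed again as the secrets map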
+                serviceSettingsMap,
+                TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
+                ConfigurationParseContext.REQUEST
+            );
+
+            throwIfNotEmptyMap(config, NAME);
+            throwIfNotEmptyMap(serviceSettingsMap, NAME);
+            throwIfNotEmptyMap(taskSettingsMap, NAME);
+
+            parsedModelListener.onResponse(model);
+        } catch (Exception e) {
+            parsedModelListener.onFailure(e);
+        }
+    }
+
+    @Override
+    public Model parsePersistedConfigWithSecrets(
+        String modelId,
+        TaskType taskType,
+        Map<String, Object> config,
+        Map<String, Object> secrets
+    ) {
+        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+        Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
+
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
+        return createModelFromPersistent(
+            modelId,
+            taskType,
+            serviceSettingsMap,
+            chunkingSettings,
+            secretSettingsMap,
+            parsePersistedConfigErrorMsg(modelId, NAME)
+        );
+    }
+
+    private LlamaModel createModelFromPersistent(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
+        Map<String, Object> secretSettings,
+        String failureMessage
+    ) {
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettings,
+            chunkingSettings,
+            secretSettings,
+            failureMessage,
+            ConfigurationParseContext.PERSISTENT
+        );
+    }
+
+    @Override
+    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
+        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
+        return createModelFromPersistent(
+            modelId,
+            taskType,
+            serviceSettingsMap,
+            chunkingSettings,
+            null,
+            parsePersistedConfigErrorMsg(modelId, NAME)
+        );
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_LLAMA_ADDED;
+    }
+
+    @Override
+    public boolean hideFromConfigurationApi() {
+        // The Llama service is very configurable so we're going to hide it from being exposed in the service API.
+        return true;
+    }
+
+    /**
+     * Configuration class for the Llama inference service.
+     * It provides the settings and configurations required for the service.
+     */
+    public static class Configuration {
+        public static InferenceServiceConfiguration get() {
+            return CONFIGURATION.getOrCompute();
+        }
+
+        private Configuration() {}
+
+        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> CONFIGURATION = new LazyInitializable<>(
+            () -> {
+                var configurationMap = new HashMap<String, SettingsConfiguration>();
+
+                configurationMap.put(
+                    URL,
+                    new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.")
+                        .setLabel("URL")
+                        .setRequired(true)
+                        .setSensitive(false)
+                        .setUpdatable(false)
+                        .setType(SettingsConfigurationFieldType.STRING)
+                        .build()
+                );
+                configurationMap.put(
+                    MODEL_ID,
+                    new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription(
+                        "Refer to the Llama models documentation for the list of available models."
+                    )
+                        .setLabel("Model")
+                        .setRequired(true)
+                        .setSensitive(false)
+                        .setUpdatable(false)
+                        .setType(SettingsConfigurationFieldType.STRING)
+                        .build()
+                );
+                configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
+                configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
+
+                return new InferenceServiceConfiguration.Builder().setService(NAME)
+                    .setName(SERVICE_NAME)
+                    .setTaskTypes(SUPPORTED_TASK_TYPES)
+                    .setConfigurations(configurationMap)
+                    .build();
+            }
+        );
+    }
+}
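
To make the configuration surface concrete, here is a hedged sketch of creating an embeddings endpoint via parseRequestConfig; the endpoint name, URL, and model id are invented, and the keys mirror the configuration map above:

    Map<String, Object> config = new HashMap<>();
    config.put("service_settings", new HashMap<>(Map.of(
        "url", "http://localhost:8321/v1/embeddings",   // hypothetical Llama deployment
        "model_id", "my-embedding-model",               // hypothetical model id
        "api_key", "secret"                             // optional; omit for unauthenticated deployments
    )));

    llamaService.parseRequestConfig("my-llama-endpoint", TaskType.TEXT_EMBEDDING, config, ActionListener.wrap(
        model -> { /* a LlamaEmbeddingsModel, ready to be stored */ },
        e -> { /* validation failures, e.g. a missing url, arrive here */ }
    ));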

+ 110 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java

@@ -0,0 +1,110 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.action;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsResponseHandler;
+import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest;
+import org.elasticsearch.xpack.inference.services.llama.request.embeddings.LlamaEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
+
+import java.util.Objects;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
+
+/**
+ * Creates actions for Llama inference requests, handling both embeddings and completions.
+ * This class implements the {@link LlamaActionVisitor} interface to provide specific action creation methods.
+ */
+public class LlamaActionCreator implements LlamaActionVisitor {
+
+    private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Llama %s request from inference entity id [%s]";
+    private static final String COMPLETION_ERROR_PREFIX = "Llama completions";
+    private static final String USER_ROLE = "user";
+
+    private static final ResponseHandler EMBEDDINGS_HANDLER = new LlamaEmbeddingsResponseHandler(
+        "llama text embedding",
+        HuggingFaceEmbeddingsResponseEntity::fromResponse
+    );
+    private static final ResponseHandler COMPLETION_HANDLER = new LlamaCompletionResponseHandler(
+        "llama completion",
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );
+
+    private final Sender sender;
+    private final ServiceComponents serviceComponents;
+
+    /**
+     * Constructs a new LlamaActionCreator with the specified sender and service components.
+     *
+     * @param sender the sender to use for executing actions
+     * @param serviceComponents the service components providing necessary services
+     */
+    public LlamaActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    @Override
+    public ExecutableAction create(LlamaEmbeddingsModel model) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            model,
+            EMBEDDINGS_HANDLER,
+            embeddingsInput -> new LlamaEmbeddingsRequest(
+                serviceComponents.truncator(),
+                truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()),
+                model
+            ),
+            EmbeddingsInput.class
+        );
+
+        var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
+        return new SenderExecutableAction(sender, manager, errorMessage);
+    }
+
+    @Override
+    public ExecutableAction create(LlamaChatCompletionModel model) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            model,
+            COMPLETION_HANDLER,
+            inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
+            ChatCompletionInput.class
+        );
+
+        var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
+        return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
+    }
+
+    /**
+     * Builds an error message for failed requests.
+     *
+     * @param requestType the type of request that failed
+     * @param inferenceId the inference entity ID associated with the request
+     * @return a formatted error message
+     */
+    public static String buildErrorMessage(TaskType requestType, String inferenceId) {
+        return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
+    }
+}

+ 34 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.action;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
+
+/**
+ * Visitor interface for creating executable actions for Llama inference models.
+ * This interface defines methods to create actions for both embeddings and chat completion models.
+ */
+public interface LlamaActionVisitor {
+    /**
+     * Creates an executable action for the given Llama embeddings model.
+     *
+     * @param model the Llama embeddings model
+     * @return an executable action for the embeddings model
+     */
+    ExecutableAction create(LlamaEmbeddingsModel model);
+
+    /**
+     * Creates an executable action for the given Llama chat completion model.
+     *
+     * @param model the Llama chat completion model
+     * @return an executable action for the chat completion model
+     */
+    ExecutableAction create(LlamaChatCompletionModel model);
+}
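
Taken together with LlamaActionCreator, the visitor boils down to one double-dispatch call, as LlamaService.doInfer does (sender, serviceComponents, inputs, timeout, and listener assumed in scope):

    // The model picks the matching create(...) overload, so callers never switch on task type.
    ExecutableAction action = llamaModel.accept(new LlamaActionCreator(sender, serviceComponents));
    action.execute(inputs, timeout, listener);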

+ 132 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java

@@ -0,0 +1,132 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.llama.LlamaModel;
+import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor;
+
+import java.util.Map;
+
+/**
+ * Represents a Llama chat completion model for inference.
+ * This class extends the LlamaModel and provides specific configurations and settings for chat completion tasks.
+ */
+public class LlamaChatCompletionModel extends LlamaModel {
+
+    /**
+     * Constructor for creating a LlamaChatCompletionModel with specified parameters.
+     * @param inferenceEntityId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param service the name of the inference service
+     * @param serviceSettings the settings for the inference service, specific to chat completion
+     * @param secrets the secret settings for the model, such as API keys or tokens
+     * @param context the context for parsing configuration settings
+     */
+    public LlamaChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            LlamaChatCompletionServiceSettings.fromMap(serviceSettings, context),
+            retrieveSecretSettings(secrets)
+        );
+    }
+
+    /**
+     * Constructor for creating a LlamaChatCompletionModel with specified parameters.
+     * @param inferenceEntityId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param service the name of the inference service
+     * @param serviceSettings the settings for the inference service, specific to chat completion
+     * @param secrets the secret settings for the model, such as API keys or tokens
+     */
+    public LlamaChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        LlamaChatCompletionServiceSettings serviceSettings,
+        SecretSettings secrets
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
+            new ModelSecrets(secrets)
+        );
+        setPropertiesFromServiceSettings(serviceSettings);
+    }
+
+    /**
+     * Factory method to create a LlamaChatCompletionModel with overridden model settings based on the request.
+     * If the request does not specify a model, the original model is returned.
+     *
+     * @param model the original LlamaChatCompletionModel
+     * @param request the UnifiedCompletionRequest containing potential overrides
+     * @return a new LlamaChatCompletionModel with overridden settings or the original model if no overrides are specified
+     */
+    public static LlamaChatCompletionModel of(LlamaChatCompletionModel model, UnifiedCompletionRequest request) {
+        if (request.model() == null) {
+            // If no model id is specified in the request, return the original model
+            return model;
+        }
+
+        var originalModelServiceSettings = model.getServiceSettings();
+        var overriddenServiceSettings = new LlamaChatCompletionServiceSettings(
+            request.model(),
+            originalModelServiceSettings.uri(),
+            originalModelServiceSettings.rateLimitSettings()
+        );
+
+        return new LlamaChatCompletionModel(
+            model.getInferenceEntityId(),
+            model.getTaskType(),
+            model.getConfigurations().getService(),
+            overriddenServiceSettings,
+            model.getSecretSettings()
+        );
+    }
+
+    private void setPropertiesFromServiceSettings(LlamaChatCompletionServiceSettings serviceSettings) {
+        this.uri = serviceSettings.uri();
+        this.rateLimitSettings = serviceSettings.rateLimitSettings();
+    }
+
+    /**
+     * Returns the service settings specific to Llama chat completion.
+     *
+     * @return the LlamaChatCompletionServiceSettings associated with this model
+     */
+    @Override
+    public LlamaChatCompletionServiceSettings getServiceSettings() {
+        return (LlamaChatCompletionServiceSettings) super.getServiceSettings();
+    }
+
+    /**
+     * Accepts a visitor that creates an executable action for this Llama chat completion model.
+     *
+     * @param creator the visitor that creates the executable action
+     * @return an ExecutableAction representing this model
+     */
+    @Override
+    public ExecutableAction accept(LlamaActionVisitor creator) {
+        return creator.create(this);
+    }
+}
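
A minimal usage sketch of the of(...) override (imports elided). The endpoint URL, model ids, and API key are illustrative placeholders, and `unifiedRequest` is assumed to be in scope; none of these values come from this change:

// Illustrative placeholders throughout; `unifiedRequest` is assumed to be in scope.
var settings = new LlamaChatCompletionServiceSettings(
    "llama3.2:3b",
    "http://localhost:8321/v1/openai/v1/chat/completions",
    null // null falls back to DEFAULT_RATE_LIMIT_SETTINGS (3000 requests per minute)
);
var configured = new LlamaChatCompletionModel(
    "my-llama-endpoint",
    TaskType.CHAT_COMPLETION,
    "llama",
    settings,
    new DefaultSecretSettings(new SecureString("api-key".toCharArray()))
);
// Returns a copy carrying the request's model id when one is set, otherwise the same instance.
var effective = LlamaChatCompletionModel.of(configured, unifiedRequest);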

+ 180 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java

@@ -0,0 +1,180 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
+
+import java.util.Locale;
+import java.util.Optional;
+
+import static org.elasticsearch.core.Strings.format;
+
+/**
+ * Handles streaming chat completion responses and error parsing for Llama inference endpoints.
+ * This handler is designed to work with the unified Llama chat completion API.
+ */
+public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
+
+    private static final String LLAMA_ERROR = "llama_error";
+    private static final String STREAM_ERROR = "stream_error";
+
+    /**
+     * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser.
+     *
+     * @param requestType the type of request this handler will process
+     * @param parseFunction the function to parse the response
+     */
+    public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, LlamaErrorResponse::fromResponse);
+    }
+
+    /**
+     * Builds an exception from a failed Llama chat completion response.
+     *
+     * @param message the error message to include in the exception
+     * @param request the request that caused the error
+     * @param result the HTTP result containing the response
+     * @param errorResponse the error response parsed from the HTTP result
+     * @return an exception representing the error, specific to Llama chat completion
+     */
+    @Override
+    protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
+        assert request.isStreaming() : "Only streaming requests support this format";
+        var responseStatusCode = result.response().getStatusLine().getStatusCode();
+        if (request.isStreaming()) {
+            var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode);
+            var restStatus = toRestStatus(responseStatusCode);
+            return errorResponse instanceof LlamaErrorResponse
+                ? new UnifiedChatCompletionException(restStatus, errorMessage, LLAMA_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
+                : new UnifiedChatCompletionException(
+                    restStatus,
+                    errorMessage,
+                    createErrorType(errorResponse),
+                    restStatus.name().toLowerCase(Locale.ROOT)
+                );
+        } else {
+            return super.buildError(message, request, result, errorResponse);
+        }
+    }
+
+    /**
+     * Builds an exception for mid-stream errors encountered during Llama chat completion requests.
+     *
+     * @param request the request that caused the error
+     * @param message the error message
+     * @param e the exception that occurred, if any
+     * @return a UnifiedChatCompletionException representing the error
+     */
+    @Override
+    protected Exception buildMidStreamError(Request request, String message, Exception e) {
+        var errorResponse = StreamingLlamaErrorResponseEntity.fromString(message);
+        if (errorResponse instanceof StreamingLlamaErrorResponseEntity) {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format(
+                    "%s for request from inference entity id [%s]. Error message: [%s]",
+                    SERVER_ERROR_OBJECT,
+                    request.getInferenceEntityId(),
+                    errorResponse.getErrorMessage()
+                ),
+                LLAMA_ERROR,
+                STREAM_ERROR
+            );
+        } else if (e != null) {
+            return UnifiedChatCompletionException.fromThrowable(e);
+        } else {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                createErrorType(errorResponse),
+                STREAM_ERROR
+            );
+        }
+    }
+
+    /**
+     * StreamingLlamaErrorResponseEntity allows creation of {@link ErrorResponse} from a JSON string.
+     * This entity is used to parse error responses from streaming Llama requests.
+     * For non-streaming requests {@link LlamaErrorResponse} should be used.
+     * An example error response for a Bad Request error looks like:
+     * <pre><code>
+     *  {
+     *      "error": {
+     *          "message": "400: Invalid value: Model 'llama3.12:3b' not found"
+     *      }
+     *  }
+     * </code></pre>
+     */
+    private static class StreamingLlamaErrorResponseEntity extends ErrorResponse {
+        private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
+            LLAMA_ERROR,
+            true,
+            args -> Optional.ofNullable((LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity) args[0])
+        );
+        private static final ConstructingObjectParser<
+            LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity,
+            Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
+                LLAMA_ERROR,
+                true,
+                args -> new LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity(
+                    args[0] != null ? (String) args[0] : "unknown"
+                )
+            );
+
+        static {
+            ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
+
+            ERROR_PARSER.declareObjectOrNull(
+                ConstructingObjectParser.optionalConstructorArg(),
+                ERROR_BODY_PARSER,
+                null,
+                new ParseField("error")
+            );
+        }
+
+        /**
+         * Parses a streaming Llama error response from a JSON string.
+         *
+         * @param response the raw JSON string representing an error
+         * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
+         */
+        private static ErrorResponse fromString(String response) {
+            try (
+                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+                    .createParser(XContentParserConfiguration.EMPTY, response)
+            ) {
+                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
+            } catch (Exception e) {
+                // swallow the error
+            }
+
+            return ErrorResponse.UNDEFINED_ERROR;
+        }
+
+        /**
+         * Constructs a StreamingLlamaErrorResponseEntity with the specified error message.
+         *
+         * @param errorMessage the error message to include in the response entity
+         */
+        StreamingLlamaErrorResponseEntity(String errorMessage) {
+            super(errorMessage);
+        }
+    }
+}
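
For reference, a sketch of a mid-stream error payload this handler recognizes (shape taken from the Javadoc above) and where it ends up:

// A mid-stream SSE error body that StreamingLlamaErrorResponseEntity.fromString parses:
String midStreamError = """
    {
      "error": {
        "message": "400: Invalid value: Model 'llama3.12:3b' not found"
      }
    }""";
// buildMidStreamError surfaces this as a UnifiedChatCompletionException with type
// "llama_error" and code "stream_error"; unparsable payloads fall back to
// ErrorResponse.UNDEFINED_ERROR and are still reported with code "stream_error".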

+ 183 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java

@@ -0,0 +1,183 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.llama.LlamaService;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
+
+/**
+ * Represents the settings for a Llama chat completion service.
+ * This class encapsulates the model ID, URI, and rate limit settings for the Llama chat completion service.
+ */
+public class LlamaChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings {
+    public static final String NAME = "llama_completion_service_settings";
+    // Llama does not define a service-side rate limit, so we use a reasonable default of 3000 requests per minute
+    protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
+
+    private final String modelId;
+    private final URI uri;
+    private final RateLimitSettings rateLimitSettings;
+
+    /**
+     * Creates a new instance of LlamaChatCompletionServiceSettings from a map of settings.
+     *
+     * @param map the map containing the service settings
+     * @param context the context for parsing configuration settings
+     * @return a new instance of LlamaChatCompletionServiceSettings
+     * @throws ValidationException if required fields are missing or invalid
+     */
+    public static LlamaChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        var uri = extractUri(map, URL, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            LlamaService.NAME,
+            context
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new LlamaChatCompletionServiceSettings(model, uri, rateLimitSettings);
+    }
+
+    /**
+     * Constructs a new LlamaChatCompletionServiceSettings from a StreamInput.
+     *
+     * @param in the StreamInput to read from
+     * @throws IOException if an I/O error occurs during reading
+     */
+    public LlamaChatCompletionServiceSettings(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        this.uri = createUri(in.readString());
+        this.rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    /**
+     * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID, URI, and rate limit settings.
+     *
+     * @param modelId the ID of the model
+     * @param uri the URI of the service
+     * @param rateLimitSettings the rate limit settings for the service, can be null
+     */
+    public LlamaChatCompletionServiceSettings(String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) {
+        this.modelId = modelId;
+        this.uri = uri;
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    /**
+     * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID and URL.
+     * If the rate limit settings are null, the default settings are used.
+     *
+     * @param modelId the ID of the model
+     * @param url the URL of the service
+     * @param rateLimitSettings the rate limit settings for the service, can be null
+     */
+    public LlamaChatCompletionServiceSettings(String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) {
+        this(modelId, createUri(url), rateLimitSettings);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_LLAMA_ADDED;
+    }
+
+    @Override
+    public String modelId() {
+        return this.modelId;
+    }
+
+    /**
+     * Returns the URI of the Llama chat completion service.
+     *
+     * @return the URI of the service
+     */
+    public URI uri() {
+        return this.uri;
+    }
+
+    /**
+     * Returns the rate limit settings for the Llama chat completion service.
+     *
+     * @return the rate limit settings
+     */
+    public RateLimitSettings rateLimitSettings() {
+        return this.rateLimitSettings;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeString(uri.toString());
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        toXContentFragmentOfExposedFields(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID, modelId);
+        builder.field(URL, uri.toString());
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        LlamaChatCompletionServiceSettings that = (LlamaChatCompletionServiceSettings) o;
+        return Objects.equals(modelId, that.modelId)
+            && Objects.equals(uri, that.uri)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelId, uri, rateLimitSettings);
+    }
+}
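
A minimal parsing sketch, assuming a request-time context (imports elided; the url value is a placeholder). Note that fromMap consumes the entries it reads, so the map must be mutable:

Map<String, Object> map = new HashMap<>(Map.of(
    "model_id", "llama3.2:3b",
    "url", "http://localhost:8321/v1/openai/v1/chat/completions"
));
var settings = LlamaChatCompletionServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST);
// With no "rate_limit" entry, settings.rateLimitSettings() is the 3000-requests-per-minute default.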

+ 29 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java

@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
+
+/**
+ * Handles non-streaming completion responses for Llama models, extending the OpenAI completion response handler.
+ * This class is specifically designed to handle Llama's error response format.
+ */
+public class LlamaCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
+
+    /**
+     * Constructs a LlamaCompletionResponseHandler with the specified request type and response parser.
+     *
+     * @param requestType The type of request being handled (e.g., "llama completions").
+     * @param parseFunction The function to parse the response.
+     */
+    public LlamaCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, LlamaErrorResponse::fromResponse);
+    }
+}

+ 119 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java

@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.embeddings;
+
+import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.llama.LlamaModel;
+import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor;
+
+import java.util.Map;
+
+/**
+ * Represents a Llama embeddings model for inference.
+ * This class extends the LlamaModel and provides specific configurations and settings for embeddings tasks.
+ */
+public class LlamaEmbeddingsModel extends LlamaModel {
+
+    /**
+     * Constructor for creating a LlamaEmbeddingsModel with specified parameters.
+     *
+     * @param inferenceEntityId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param service the name of the inference service
+     * @param serviceSettings the settings for the inference service, specific to embeddings
+     * @param chunkingSettings the chunking settings for processing input data
+     * @param secrets the secret settings for the model, such as API keys or tokens
+     * @param context the context for parsing configuration settings
+     */
+    public LlamaEmbeddingsModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            LlamaEmbeddingsServiceSettings.fromMap(serviceSettings, context),
+            chunkingSettings,
+            retrieveSecretSettings(secrets)
+        );
+    }
+
+    /**
+     * Constructor for creating a LlamaEmbeddingsModel by copying an existing model and overriding its service settings.
+     *
+     * @param model the base LlamaEmbeddingsModel to copy properties from
+     * @param serviceSettings the settings for the inference service, specific to embeddings
+     */
+    public LlamaEmbeddingsModel(LlamaEmbeddingsModel model, LlamaEmbeddingsServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+        setPropertiesFromServiceSettings(serviceSettings);
+    }
+
+    /**
+     * Sets properties from the provided LlamaEmbeddingsServiceSettings.
+     *
+     * @param serviceSettings the service settings to extract properties from
+     */
+    private void setPropertiesFromServiceSettings(LlamaEmbeddingsServiceSettings serviceSettings) {
+        this.uri = serviceSettings.uri();
+        this.rateLimitSettings = serviceSettings.rateLimitSettings();
+    }
+
+    /**
+     * Constructor for creating a LlamaEmbeddingsModel with specified parameters.
+     *
+     * @param inferenceEntityId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param service the name of the inference service
+     * @param serviceSettings the settings for the inference service, specific to embeddings
+     * @param chunkingSettings the chunking settings for processing input data
+     * @param secrets the secret settings for the model, such as API keys or tokens
+     */
+    public LlamaEmbeddingsModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        LlamaEmbeddingsServiceSettings serviceSettings,
+        ChunkingSettings chunkingSettings,
+        SecretSettings secrets
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings),
+            new ModelSecrets(secrets)
+        );
+        setPropertiesFromServiceSettings(serviceSettings);
+    }
+
+    @Override
+    public LlamaEmbeddingsServiceSettings getServiceSettings() {
+        return (LlamaEmbeddingsServiceSettings) super.getServiceSettings();
+    }
+
+    /**
+     * Accepts a visitor to create an executable action for this Llama embeddings model.
+     *
+     * @param creator the visitor that creates the executable action
+     * @return an ExecutableAction representing the Llama embeddings model
+     */
+    @Override
+    public ExecutableAction accept(LlamaActionVisitor creator) {
+        return creator.create(this);
+    }
+}
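
A construction sketch for the embeddings model (imports elided), with placeholder names and a null ChunkingSettings for brevity:

var embeddingsSettings = new LlamaEmbeddingsServiceSettings(
    "llama-embed",                          // model id (placeholder)
    "http://localhost:8321/v1/embeddings",  // endpoint URL (placeholder)
    null, null, null, null                  // dimensions, similarity, max input tokens, rate limit
);
var embeddingsModel = new LlamaEmbeddingsModel(
    "my-llama-embeddings",
    TaskType.TEXT_EMBEDDING,
    "llama",
    embeddingsSettings,
    null, // ChunkingSettings; omitted here for brevity
    new DefaultSecretSettings(new SecureString("api-key".toCharArray()))
);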

+ 29 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java

@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.embeddings;
+
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler;
+
+/**
+ * Handles responses for Llama embeddings requests, parsing the response and handling errors.
+ * This class extends OpenAiResponseHandler to provide specific functionality for Llama embeddings.
+ */
+public class LlamaEmbeddingsResponseHandler extends OpenAiResponseHandler {
+
+    /**
+     * Constructs a new LlamaEmbeddingsResponseHandler with the specified request type and response parser.
+     *
+     * @param requestType the type of request this handler will process
+     * @param parseFunction the function to parse the response
+     */
+    public LlamaEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, LlamaErrorResponse::fromResponse, false);
+    }
+}

+ 257 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java

@@ -0,0 +1,257 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.llama.LlamaService;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
+
+/**
+ * Settings for the Llama embeddings service.
+ * This class encapsulates the configuration settings required to use Llama for generating embeddings.
+ */
+public class LlamaEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
+    public static final String NAME = "llama_embeddings_service_settings";
+    // Llama does not define a service-side rate limit, so we use a reasonable default of 3000 requests per minute
+    protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
+
+    private final String modelId;
+    private final URI uri;
+    private final Integer dimensions;
+    private final SimilarityMeasure similarity;
+    private final Integer maxInputTokens;
+    private final RateLimitSettings rateLimitSettings;
+
+    /**
+     * Creates a new instance of LlamaEmbeddingsServiceSettings from a map of settings.
+     *
+     * @param map the map containing the settings
+     * @param context the context for parsing configuration settings
+     * @return a new instance of LlamaEmbeddingsServiceSettings
+     * @throws ValidationException if any required fields are missing or invalid
+     */
+    public static LlamaEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        var uri = extractUri(map, URL, validationException);
+        var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        var maxInputTokens = extractOptionalPositiveInteger(
+            map,
+            MAX_INPUT_TOKENS,
+            ModelConfigurations.SERVICE_SETTINGS,
+            validationException
+        );
+        var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, LlamaService.NAME, context);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new LlamaEmbeddingsServiceSettings(model, uri, dimensions, similarity, maxInputTokens, rateLimitSettings);
+    }
+
+    /**
+     * Constructs a new LlamaEmbeddingsServiceSettings from a StreamInput.
+     *
+     * @param in the StreamInput to read from
+     * @throws IOException if an I/O error occurs during reading
+     */
+    public LlamaEmbeddingsServiceSettings(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        this.uri = createUri(in.readString());
+        this.dimensions = in.readOptionalVInt();
+        this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
+        this.maxInputTokens = in.readOptionalVInt();
+        this.rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    /**
+     * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters.
+     *
+     * @param modelId the identifier for the model
+     * @param uri the URI of the Llama service
+     * @param dimensions the number of dimensions for the embeddings, can be null
+     * @param similarity the similarity measure to use, can be null
+     * @param maxInputTokens the maximum number of input tokens, can be null
+     * @param rateLimitSettings the rate limit settings for the service, can be null
+     */
+    public LlamaEmbeddingsServiceSettings(
+        String modelId,
+        URI uri,
+        @Nullable Integer dimensions,
+        @Nullable SimilarityMeasure similarity,
+        @Nullable Integer maxInputTokens,
+        @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        this.modelId = modelId;
+        this.uri = uri;
+        this.dimensions = dimensions;
+        this.similarity = similarity;
+        this.maxInputTokens = maxInputTokens;
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    /**
+     * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters.
+     *
+     * @param modelId the identifier for the model
+     * @param url the URL of the Llama service
+     * @param dimensions the number of dimensions for the embeddings, can be null
+     * @param similarity the similarity measure to use, can be null
+     * @param maxInputTokens the maximum number of input tokens, can be null
+     * @param rateLimitSettings the rate limit settings for the service, can be null
+     */
+    public LlamaEmbeddingsServiceSettings(
+        String modelId,
+        String url,
+        @Nullable Integer dimensions,
+        @Nullable SimilarityMeasure similarity,
+        @Nullable Integer maxInputTokens,
+        @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_LLAMA_ADDED;
+    }
+
+    @Override
+    public String modelId() {
+        return this.modelId;
+    }
+
+    public URI uri() {
+        return this.uri;
+    }
+
+    @Override
+    public Integer dimensions() {
+        return this.dimensions;
+    }
+
+    @Override
+    public SimilarityMeasure similarity() {
+        return this.similarity;
+    }
+
+    @Override
+    public DenseVectorFieldMapper.ElementType elementType() {
+        return DenseVectorFieldMapper.ElementType.FLOAT;
+    }
+
+    /**
+     * Returns the maximum number of input tokens allowed for this service.
+     *
+     * @return the maximum input tokens, or null if not specified
+     */
+    public Integer maxInputTokens() {
+        return this.maxInputTokens;
+    }
+
+    /**
+     * Returns the rate limit settings for this service.
+     *
+     * @return the rate limit settings, never null
+     */
+    public RateLimitSettings rateLimitSettings() {
+        return this.rateLimitSettings;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeString(uri.toString());
+        out.writeOptionalVInt(dimensions);
+        out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
+        out.writeOptionalVInt(maxInputTokens);
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        toXContentFragmentOfExposedFields(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID, modelId);
+        builder.field(URL, uri.toString());
+
+        if (dimensions != null) {
+            builder.field(DIMENSIONS, dimensions);
+        }
+        if (similarity != null) {
+            builder.field(SIMILARITY, similarity);
+        }
+        if (maxInputTokens != null) {
+            builder.field(MAX_INPUT_TOKENS, maxInputTokens);
+        }
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        LlamaEmbeddingsServiceSettings that = (LlamaEmbeddingsServiceSettings) o;
+        return Objects.equals(modelId, that.modelId)
+            && Objects.equals(uri, that.uri)
+            && Objects.equals(dimensions, that.dimensions)
+            && Objects.equals(maxInputTokens, that.maxInputTokens)
+            && Objects.equals(similarity, that.similarity)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelId, uri, dimensions, maxInputTokens, similarity, rateLimitSettings);
+    }
+
+}
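
The optional embeddings fields parse from the same map; a sketch with every optional field present (imports elided, values are placeholders):

Map<String, Object> map = new HashMap<>();
map.put("model_id", "llama-embed");
map.put("url", "http://localhost:8321/v1/embeddings");
map.put("dimensions", 384);
map.put("similarity", "dot_product"); // parsed into SimilarityMeasure.DOT_PRODUCT
map.put("max_input_tokens", 512);
var settings = LlamaEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST);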

+ 96 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java

@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.completion;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+/**
+ * Llama Chat Completion Request
+ * This class is responsible for creating a request to the Llama chat completion model.
+ * It constructs an HTTP POST request with the necessary headers and body content.
+ */
+public class LlamaChatCompletionRequest implements Request {
+
+    private final LlamaChatCompletionModel model;
+    private final UnifiedChatInput chatInput;
+
+    /**
+     * Constructs a new LlamaChatCompletionRequest with the specified chat input and model.
+     *
+     * @param chatInput the chat input containing the messages and parameters for the completion request
+     * @param model the Llama chat completion model to be used for the request
+     */
+    public LlamaChatCompletionRequest(UnifiedChatInput chatInput, LlamaChatCompletionModel model) {
+        this.chatInput = Objects.requireNonNull(chatInput);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    /**
+     * Creates the HTTP POST request for this chat completion call, setting the JSON body,
+     * the content type header and, when default secret settings are present, a bearer auth header.
+     *
+     * @return the HttpRequest to send to the Llama chat completion endpoint
+     */
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(model.uri());
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(
+            Strings.toString(new LlamaChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
+        );
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
+        if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) {
+            httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey()));
+        }
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return model.uri();
+    }
+
+    @Override
+    public Request truncate() {
+        // No truncation for Llama chat completions
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        // No truncation for Llama chat completions
+        return null;
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public boolean isStreaming() {
+        return chatInput.stream();
+    }
+}
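
A request-building sketch; the three-argument UnifiedChatInput constructor mirrors the one used in this PR's tests and is an assumption here, as is the `model` variable:

var chatInput = new UnifiedChatInput(List.of("Hello, Llama"), "user", true /* stream */);
var httpRequest = new LlamaChatCompletionRequest(chatInput, model).createHttpRequest();
// Content-Type is application/json; an Authorization: Bearer header is added only when
// model.getSecretSettings() is a DefaultSecretSettings instance.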

+ 47 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.completion;
+
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * LlamaChatCompletionRequestEntity is responsible for creating the request entity for Llama chat completion.
+ * It implements ToXContentObject to allow serialization to XContent format.
+ */
+public class LlamaChatCompletionRequestEntity implements ToXContentObject {
+
+    private final LlamaChatCompletionModel model;
+    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
+
+    /**
+     * Constructs a LlamaChatCompletionRequestEntity with the specified unified chat input and model.
+     *
+     * @param unifiedChatInput the unified chat input containing messages and parameters for the completion request
+     * @param model the Llama chat completion model to be used for the request
+     */
+    public LlamaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, LlamaChatCompletionModel model) {
+        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
+        builder.endObject();
+        return builder;
+    }
+}

+ 96 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java

@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+/**
+ * Llama Embeddings Request
+ * This class is responsible for creating a request to the Llama embeddings model.
+ * It constructs an HTTP POST request with the necessary headers and body content.
+ */
+public class LlamaEmbeddingsRequest implements Request {
+    private final URI uri;
+    private final LlamaEmbeddingsModel model;
+    private final String inferenceEntityId;
+    private final Truncator.TruncationResult truncationResult;
+    private final Truncator truncator;
+
+    /**
+     * Constructs a new LlamaEmbeddingsRequest with the specified truncator, input, and model.
+     *
+     * @param truncator the truncator to handle input truncation
+     * @param input the input to be truncated
+     * @param model the Llama embeddings model to be used for the request
+     */
+    public LlamaEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, LlamaEmbeddingsModel model) {
+        this.uri = model.uri();
+        this.model = model;
+        this.inferenceEntityId = model.getInferenceEntityId();
+        this.truncator = truncator;
+        this.truncationResult = input;
+    }
+
+    /**
+     * Creates the HTTP POST request for this embeddings call, setting the JSON body,
+     * the content type header and, when default secret settings are present, a bearer auth header.
+     *
+     * @return the HttpRequest to send to the Llama embeddings endpoint
+     */
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(this.uri);
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(
+            Strings.toString(new LlamaEmbeddingsRequestEntity(model.getServiceSettings().modelId(), truncationResult.input()))
+                .getBytes(StandardCharsets.UTF_8)
+        );
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
+        if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) {
+            httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey()));
+        }
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return uri;
+    }
+
+    @Override
+    public Request truncate() {
+        var truncatedInput = truncator.truncate(truncationResult.input());
+        return new LlamaEmbeddingsRequest(truncator, truncatedInput, model);
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        return truncationResult.truncated().clone();
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return inferenceEntityId;
+    }
+}
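
A sketch of the truncate-and-retry path, assuming `truncator` and `embeddingsModel` are already built and that TruncationResult takes the input list and a truncation-flag array:

var initial = new Truncator.TruncationResult(List.of("a very long passage"), new boolean[] { false });
var request = new LlamaEmbeddingsRequest(truncator, initial, embeddingsModel);
// On a context-length failure the retry logic calls truncate(), which shortens the
// input via the shared Truncator and rebuilds the request.
Request shorter = request.truncate();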

+ 51 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
+
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * LlamaEmbeddingsRequestEntity is responsible for creating the request entity for Llama embeddings.
+ * It implements ToXContentObject to allow serialization to XContent format.
+ */
+public record LlamaEmbeddingsRequestEntity(String modelId, List<String> contents) implements ToXContentObject {
+
+    public static final String CONTENTS_FIELD = "contents";
+    public static final String MODEL_ID_FIELD = "model_id";
+
+    /**
+     * Constructs a LlamaEmbeddingsRequestEntity with the specified model ID and contents.
+     *
+     * @param modelId  the ID of the model to use for embeddings
+     * @param contents the list of contents to generate embeddings for
+     */
+    public LlamaEmbeddingsRequestEntity {
+        Objects.requireNonNull(modelId);
+        Objects.requireNonNull(contents);
+    }
+
+    /**
+     * Serializes this entity to XContent as an object with "model_id" and "contents" fields.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        builder.field(MODEL_ID_FIELD, modelId);
+        builder.field(CONTENTS_FIELD, contents);
+
+        builder.endObject();
+
+        return builder;
+    }
+}
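
Serializing the entity the same way the request class does yields the wire format; a sketch with placeholder values:

var entity = new LlamaEmbeddingsRequestEntity("llama-embed", List.of("first text", "second text"));
String json = Strings.toString(entity);
// json is {"model_id":"llama-embed","contents":["first text","second text"]}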

+ 59 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java

@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.response;
+
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+
+import java.nio.charset.StandardCharsets;
+
+/**
+ * LlamaErrorResponse is responsible for handling error responses from Llama inference services.
+ * It extends ErrorResponse to provide specific functionality for Llama errors.
+ * An example error response for a Not Found error looks like:
+ * <pre><code>
+ *  {
+ *      "detail": "Not Found"
+ *  }
+ * </code></pre>
+ * An example error response for a Bad Request error looks like:
+ * <pre><code>
+ *  {
+ *     "error": {
+ *         "detail": {
+ *             "errors": [
+ *                 {
+ *                     "loc": [
+ *                         "body",
+ *                         "model"
+ *                     ],
+ *                     "msg": "Field required",
+ *                     "type": "missing"
+ *                 }
+ *             ]
+ *         }
+ *     }
+ *  }
+ * </code></pre>
+ */
+public class LlamaErrorResponse extends ErrorResponse {
+
+    public LlamaErrorResponse(String message) {
+        super(message);
+    }
+
+    public static ErrorResponse fromResponse(HttpResult response) {
+        try {
+            String errorMessage = new String(response.body(), StandardCharsets.UTF_8);
+            return new LlamaErrorResponse(errorMessage);
+        } catch (Exception e) {
+            // swallow the error
+        }
+        return ErrorResponse.UNDEFINED_ERROR;
+    }
+}
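
A sketch of the error passthrough, mocking the HttpResponse as the unit tests in this PR do (the Mockito usage is an assumption of this example; imports elided):

byte[] body = "{\"detail\": \"Not Found\"}".getBytes(StandardCharsets.UTF_8);
ErrorResponse error = LlamaErrorResponse.fromResponse(new HttpResult(mock(HttpResponse.class), body));
// getErrorMessage() returns the raw JSON body verbatim; if reading the body throws,
// fromResponse falls back to ErrorResponse.UNDEFINED_ERROR.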

+ 6 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java

@@ -10,19 +10,21 @@ package org.elasticsearch.xpack.inference.services.mistral;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
+import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
 import java.net.URI;
 import java.net.URISyntaxException;
+import java.util.Objects;
 
 /**
  * Represents a Mistral model that can be used for inference tasks.
  * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
  */
 public abstract class MistralModel extends RateLimitGroupingModel {
-    protected String model;
     protected URI uri;
     protected RateLimitSettings rateLimitSettings;
 
@@ -34,10 +36,6 @@ public abstract class MistralModel extends RateLimitGroupingModel {
         super(model, serviceSettings);
     }
 
-    public String model() {
-        return this.model;
-    }
-
     public URI uri() {
         return this.uri;
     }
@@ -49,7 +47,7 @@ public abstract class MistralModel extends RateLimitGroupingModel {
 
     @Override
     public int rateLimitGroupingHash() {
-        return 0;
+        return Objects.hash(getServiceSettings().modelId(), getSecretSettings().apiKey());
     }
 
     // Needed for testing only
@@ -65,4 +63,6 @@ public abstract class MistralModel extends RateLimitGroupingModel {
     public DefaultSecretSettings getSecretSettings() {
         return (DefaultSecretSettings) super.getSecretSettings();
     }
+
+    public abstract ExecutableAction accept(MistralActionVisitor creator);
 }
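
The grouping-hash fix matters for bucketing: a sketch of the new behavior, with `modelA` and `modelB` assumed to share a model id and API key:

// Previously rateLimitGroupingHash() returned 0 for every Mistral model, collapsing all
// endpoints into one rate-limit bucket; now models group by (model id, API key).
assert modelA.rateLimitGroupingHash() == modelB.rateLimitGroupingHash();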

+ 10 - 31
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

@@ -108,16 +108,10 @@ public class MistralService extends SenderService {
     ) {
         var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
 
-        switch (model) {
-            case MistralEmbeddingsModel mistralEmbeddingsModel:
-                mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
-                break;
-            case MistralChatCompletionModel mistralChatCompletionModel:
-                mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
-                break;
-            default:
-                listener.onFailure(createInvalidModelException(model));
-                break;
+        if (model instanceof MistralModel mistralModel) {
+            mistralModel.accept(actionCreator).execute(inputs, timeout, listener);
+        } else {
+            listener.onFailure(createInvalidModelException(model));
         }
     }
 
@@ -172,7 +166,7 @@ public class MistralService extends SenderService {
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
-                var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
+                var action = mistralEmbeddingsModel.accept(actionCreator);
                 action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
             }
         } else {
@@ -217,7 +211,6 @@ public class MistralService extends SenderService {
                 modelId,
                 taskType,
                 serviceSettingsMap,
-                taskSettingsMap,
                 chunkingSettings,
                 serviceSettingsMap,
                 TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
@@ -242,7 +235,7 @@ public class MistralService extends SenderService {
         Map<String, Object> secrets
     ) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
-        Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+        removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
 
         ChunkingSettings chunkingSettings = null;
@@ -254,7 +247,6 @@ public class MistralService extends SenderService {
             modelId,
             taskType,
             serviceSettingsMap,
-            taskSettingsMap,
             chunkingSettings,
             secretSettingsMap,
             parsePersistedConfigErrorMsg(modelId, NAME)
@@ -264,7 +256,7 @@ public class MistralService extends SenderService {
     @Override
     public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
-        Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+        removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
         ChunkingSettings chunkingSettings = null;
         if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -275,7 +267,6 @@ public class MistralService extends SenderService {
             modelId,
             taskType,
             serviceSettingsMap,
-            taskSettingsMap,
             chunkingSettings,
             null,
             parsePersistedConfigErrorMsg(modelId, NAME)
@@ -296,7 +287,6 @@ public class MistralService extends SenderService {
         String modelId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
-        Map<String, Object> taskSettings,
         ChunkingSettings chunkingSettings,
         @Nullable Map<String, Object> secretSettings,
         String failureMessage,
@@ -304,16 +294,7 @@ public class MistralService extends SenderService {
     ) {
         switch (taskType) {
             case TEXT_EMBEDDING:
-                return new MistralEmbeddingsModel(
-                    modelId,
-                    taskType,
-                    NAME,
-                    serviceSettings,
-                    taskSettings,
-                    chunkingSettings,
-                    secretSettings,
-                    context
-                );
+                return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context);
             case CHAT_COMPLETION, COMPLETION:
                 return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
             default:
@@ -325,7 +306,6 @@ public class MistralService extends SenderService {
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
-        Map<String, Object> taskSettings,
         ChunkingSettings chunkingSettings,
         Map<String, Object> secretSettings,
         String failureMessage
@@ -334,7 +314,6 @@ public class MistralService extends SenderService {
             inferenceEntityId,
             taskType,
             serviceSettings,
-            taskSettings,
             chunkingSettings,
             secretSettings,
             failureMessage,
@@ -369,10 +348,10 @@ public class MistralService extends SenderService {
      */
     public static class Configuration {
         public static InferenceServiceConfiguration get() {
-            return configuration.getOrCompute();
+            return CONFIGURATION.getOrCompute();
         }
 
-        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
+        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> CONFIGURATION = new LazyInitializable<>(
             () -> {
                 var configurationMap = new HashMap<String, SettingsConfiguration>();
 

+ 1 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java

@@ -24,7 +24,6 @@ import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbe
 import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
 import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
 
-import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.core.Strings.format;
@@ -51,7 +50,7 @@ public class MistralActionCreator implements MistralActionVisitor {
     }
 
     @Override
-    public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
+    public ExecutableAction create(MistralEmbeddingsModel embeddingsModel) {
         var requestManager = new MistralEmbeddingsRequestManager(
             embeddingsModel,
             serviceComponents.truncator(),

+ 1 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java

@@ -11,8 +11,6 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
 
-import java.util.Map;
-
 /**
  * Interface for creating {@link ExecutableAction} instances for Mistral models.
  * <p>
@@ -25,10 +23,9 @@ public interface MistralActionVisitor {
      * Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}.
      *
      * @param embeddingsModel The model to create the action for.
-     * @param taskSettings    The task settings to use.
      * @return An {@link ExecutableAction} for the given model.
      */
-    ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
+    ExecutableAction create(MistralEmbeddingsModel embeddingsModel);
 
     /**
      * Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}.

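With the `taskSettings` parameter gone, the visitor is plain double dispatch: each concrete model's `accept` forwards to the matching `create` overload, which is why `MistralService` above gets away with a single `instanceof MistralModel` check. A minimal, self-contained sketch of the shape; the class names are hypothetical stand-ins, with `Runnable` in place of `ExecutableAction`:

```java
// Illustrative double-dispatch sketch, not the plugin's real classes.
interface ActionVisitor {
    Runnable create(EmbeddingsModel model);

    Runnable create(ChatCompletionModel model);
}

abstract class ServiceModel {
    abstract Runnable accept(ActionVisitor visitor);
}

class EmbeddingsModel extends ServiceModel {
    @Override
    Runnable accept(ActionVisitor visitor) {
        return visitor.create(this); // overload resolution picks the embeddings action
    }
}

class ChatCompletionModel extends ServiceModel {
    @Override
    Runnable accept(ActionVisitor visitor) {
        return visitor.create(this); // and here, the chat completion action
    }
}
```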
+ 2 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java

@@ -22,7 +22,6 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.Map;
-import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH;
 
@@ -95,23 +94,17 @@ public class MistralChatCompletionModel extends MistralModel {
         DefaultSecretSettings secrets
     ) {
         super(
-            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()),
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
             new ModelSecrets(secrets)
         );
         setPropertiesFromServiceSettings(serviceSettings);
     }
 
     private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) {
-        this.model = serviceSettings.modelId();
         this.rateLimitSettings = serviceSettings.rateLimitSettings();
         setEndpointUrl();
     }
 
-    @Override
-    public int rateLimitGroupingHash() {
-        return Objects.hash(model, getSecretSettings().apiKey());
-    }
-
     private void setEndpointUrl() {
         try {
             this.uri = new URI(API_COMPLETIONS_PATH);
@@ -131,6 +124,7 @@ public class MistralChatCompletionModel extends MistralModel {
      * @param creator The visitor that creates the executable action.
      * @return An ExecutableAction that can be executed.
      */
+    @Override
     public ExecutableAction accept(MistralActionVisitor creator) {
         return creator.create(this);
     }

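Besides dropping the per-model `rateLimitGroupingHash` override, the hunk swaps `new EmptyTaskSettings()` for the shared `EmptyTaskSettings.INSTANCE` constant: the type carries no state, so one immutable instance can back every model. A reduced sketch of the idiom (the real class has more to it):

```java
// Reduced sketch of the shared-constant idiom; illustrative only.
public final class EmptyTaskSettings {
    public static final EmptyTaskSettings INSTANCE = new EmptyTaskSettings();

    private EmptyTaskSettings() {
        // stateless, so a single instance suffices
    }
}
```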
+ 4 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java

@@ -12,7 +12,6 @@ import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
-import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -37,7 +36,6 @@ public class MistralEmbeddingsModel extends MistralModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        Map<String, Object> taskSettings,
         ChunkingSettings chunkingSettings,
         @Nullable Map<String, Object> secrets,
         ConfigurationParseContext context
@@ -47,7 +45,6 @@ public class MistralEmbeddingsModel extends MistralModel {
             taskType,
             service,
             MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context),
-            EmptyTaskSettings.INSTANCE,    // no task settings for Mistral embeddings
             chunkingSettings,
             DefaultSecretSettings.fromMap(secrets)
         );
@@ -59,7 +56,6 @@ public class MistralEmbeddingsModel extends MistralModel {
     }
 
     private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
-        this.model = serviceSettings.modelId();
         this.rateLimitSettings = serviceSettings.rateLimitSettings();
         setEndpointUrl();
     }
@@ -77,12 +73,11 @@ public class MistralEmbeddingsModel extends MistralModel {
         TaskType taskType,
         String service,
         MistralEmbeddingsServiceSettings serviceSettings,
-        TaskSettings taskSettings,
         ChunkingSettings chunkingSettings,
         DefaultSecretSettings secrets
     ) {
         super(
-            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings),
             new ModelSecrets(secrets)
         );
         setPropertiesFromServiceSettings(serviceSettings);
@@ -93,7 +88,8 @@ public class MistralEmbeddingsModel extends MistralModel {
         return (MistralEmbeddingsServiceSettings) super.getServiceSettings();
     }
 
-    public ExecutableAction accept(MistralActionVisitor creator, Map<String, Object> taskSettings) {
-        return creator.create(this, taskSettings);
+    @Override
+    public ExecutableAction accept(MistralActionVisitor creator) {
+        return creator.create(this);
     }
 }

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java

@@ -178,12 +178,13 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
         return Objects.equals(model, that.model)
             && Objects.equals(dimensions, that.dimensions)
             && Objects.equals(maxInputTokens, that.maxInputTokens)
-            && Objects.equals(similarity, that.similarity);
+            && Objects.equals(similarity, that.similarity)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(model, dimensions, maxInputTokens, similarity);
+        return Objects.hash(model, dimensions, maxInputTokens, similarity, rateLimitSettings);
     }
 
 }

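This closes an equals/hashCode gap: `rateLimitSettings` previously took part in neither method, so two settings objects differing only in their rate limits compared as equal. Whatever fields `equals` compares, `hashCode` must hash. A self-contained sketch of the contract, with a plain `int` standing in for `RateLimitSettings`:

```java
import java.util.Objects;

// Illustrative: equals() and hashCode() must cover the same fields, or
// "equal" objects can land in different hash buckets.
final class EmbeddingSettings {
    private final String model;
    private final Integer dimensions;
    private final int requestsPerMinute; // stand-in for rateLimitSettings

    EmbeddingSettings(String model, Integer dimensions, int requestsPerMinute) {
        this.model = model;
        this.dimensions = dimensions;
        this.requestsPerMinute = requestsPerMinute;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        return o instanceof EmbeddingSettings that
            && Objects.equals(model, that.model)
            && Objects.equals(dimensions, that.dimensions)
            && requestsPerMinute == that.requestsPerMinute;
    }

    @Override
    public int hashCode() {
        return Objects.hash(model, dimensions, requestsPerMinute); // same fields as equals()
    }
}
```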
+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java

@@ -42,7 +42,7 @@ public class MistralEmbeddingsRequest implements Request {
         HttpPost httpPost = new HttpPost(this.uri);
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.model(), truncationResult.input()))
+            Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.getServiceSettings().modelId(), truncationResult.input()))
                 .getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

@@ -198,7 +198,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
                 PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD));
                 PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD));
                 PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD));
-                PARSER.declareObjectArray(
+                PARSER.declareObjectArrayOrNull(
                     ConstructingObjectParser.optionalConstructorArg(),
                     (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p),
                     new ParseField(TOOL_CALLS_FIELD)

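The `OrNull` variant matters because some backends stream `"tool_calls": null` inside delta chunks (the Llama fixtures later in this diff do exactly that), and a plain `declareObjectArray` rejects an explicit JSON null. A reduced sketch of the declaration, with `Delta` and `ToolCall` as hypothetical stand-ins for the real parsed types:

```java
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;

import java.util.List;

// Hypothetical reduced parser; only the null-tolerant array declaration
// is the point here. Delta and ToolCall are illustrative stand-ins.
final class DeltaParsing {
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<Delta, Void> DELTA_PARSER = new ConstructingObjectParser<>(
        "delta",
        true,
        args -> new Delta((String) args[0], (List<ToolCall>) args[1])
    );

    static {
        DELTA_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("content"));
        // accepts a JSON array, an explicit null, or an absent field
        DELTA_PARSER.declareObjectArrayOrNull(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c) -> ToolCall.PARSER.parse(p, null),
            new ParseField("tool_calls")
        );
    }
}
```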
+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java

@@ -34,7 +34,7 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
         builder.startObject();
         unifiedRequestEntity.toXContent(
             builder,
-            UnifiedCompletionRequest.withMaxCompletionTokensTokens(model.getServiceSettings().modelId(), params)
+            UnifiedCompletionRequest.withMaxCompletionTokens(model.getServiceSettings().modelId(), params)
         );
 
         if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) {

+ 14 - 9
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java

@@ -23,7 +23,6 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.services.custom.CustomModel;
 import org.junit.After;
 import org.junit.Assume;
 import org.junit.Before;
@@ -141,7 +140,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
             return true;
         }
 
-        protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
+        protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
     }
 
     private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
@@ -151,7 +150,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
         }
 
         @Override
-        protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+        protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
             throw new UnsupportedOperationException("Update model tests are disabled");
         }
     };
@@ -351,11 +350,17 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
 
             assertThat(
                 exception.getMessage(),
-                containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType))
+                containsString(
+                    Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType)
+                )
             );
         }
     }
 
+    protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
+        return "service does not support task type [%s]";
+    }
+
     public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
         var parseConfigTestConfig = testConfiguration.commonConfig;
 
@@ -374,7 +379,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
                 persistedConfigMap.secrets()
             );
 
-            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+            parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
         }
     }
 
@@ -396,7 +401,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
                 persistedConfigMap.secrets()
             );
 
-            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+            parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
         }
     }
 
@@ -413,7 +418,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
 
             var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
 
-            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+            parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
         }
     }
 
@@ -430,7 +435,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
 
             var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
 
-            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+            parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
         }
     }
 
@@ -468,7 +473,7 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
                 () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt())
             );
 
-            assertThat(exception.getMessage(), containsString("Can't update embedding details for model of type:"));
+            assertThat(exception.getMessage(), containsString("Can't update embedding details for model"));
         }
     }
 

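The new `fetchPersistedConfigTaskTypeParsingErrorMessageFormat()` hook is a small template method: the base tests assert against a default format, and a service's test class overrides it when its service words persisted-config failures differently, as `LlamaServiceTests` does below. A minimal sketch of the shape, with illustrative names:

```java
// Minimal template-method sketch; not the real test classes.
abstract class BaseServiceTestSupport {
    protected String taskTypeParsingErrorMessageFormat() {
        return "service does not support task type [%s]";
    }

    protected String expectedParsingError(Object unsupportedTaskType) {
        // extra format args are ignored when the override has no placeholder
        return String.format(java.util.Locale.ROOT, taskTypeParsingErrorMessageFormat(), unsupportedTaskType);
    }
}

class LlamaLikeTestSupport extends BaseServiceTestSupport {
    @Override
    protected String taskTypeParsingErrorMessageFormat() {
        return "Failed to parse stored model [id] for [llama] service, please delete and add the service again";
    }
}
```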
+ 840 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java

@@ -0,0 +1,840 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionTestUtils;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests;
+import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.ExceptionsHelper.unwrapCause;
+import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests.createChatCompletionModel;
+import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap;
+import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.isA;
+import static org.mockito.Mockito.mock;
+
+public class LlamaServiceTests extends AbstractInferenceServiceTests {
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    public LlamaServiceTests() {
+        super(createTestConfiguration());
+    }
+
+    private static TestConfiguration createTestConfiguration() {
+        return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) {
+
+            @Override
+            protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+                return LlamaServiceTests.createService(threadPool, clientManager);
+            }
+
+            @Override
+            protected Map<String, Object> createServiceSettingsMap(TaskType taskType) {
+                return LlamaServiceTests.createServiceSettingsMap(taskType);
+            }
+
+            @Override
+            protected Map<String, Object> createTaskSettingsMap() {
+                return new HashMap<>();
+            }
+
+            @Override
+            protected Map<String, Object> createSecretSettingsMap() {
+                return LlamaServiceTests.createSecretSettingsMap();
+            }
+
+            @Override
+            protected void assertModel(Model model, TaskType taskType) {
+                LlamaServiceTests.assertModel(model, taskType);
+            }
+
+            @Override
+            protected EnumSet<TaskType> supportedStreamingTasks() {
+                return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+            }
+        }).enableUpdateModelTests(new UpdateModelConfiguration() {
+            @Override
+            protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+                return createInternalEmbeddingModel(similarityMeasure);
+            }
+        }).build();
+    }
+
+    private static void assertModel(Model model, TaskType taskType) {
+        switch (taskType) {
+            case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
+            case COMPLETION -> assertCompletionModel(model);
+            case CHAT_COMPLETION -> assertChatCompletionModel(model);
+            default -> fail("unexpected task type [" + taskType + "]");
+        }
+    }
+
+    private static void assertTextEmbeddingModel(Model model) {
+        var llamaModel = assertCommonModelFields(model);
+
+        assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING));
+    }
+
+    private static LlamaModel assertCommonModelFields(Model model) {
+        assertThat(model, instanceOf(LlamaModel.class));
+
+        var llamaModel = (LlamaModel) model;
+        assertThat(llamaModel.getServiceSettings().modelId(), is("model_id"));
+        assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com"));
+        assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE));
+        assertThat(
+            ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(),
+            Matchers.is(new SecureString("secret".toCharArray()))
+        );
+
+        return llamaModel;
+    }
+
+    private static void assertCompletionModel(Model model) {
+        var llamaModel = assertCommonModelFields(model);
+        assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION));
+    }
+
+    private static void assertChatCompletionModel(Model model) {
+        var llamaModel = assertCommonModelFields(model);
+        assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION));
+    }
+
+    public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        return new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty());
+    }
+
+    private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
+        Map<String, Object> settingsMap = new HashMap<>(
+            Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id")
+        );
+
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            settingsMap.putAll(
+                Map.of(
+                    ServiceFields.SIMILARITY,
+                    SimilarityMeasure.COSINE.toString(),
+                    ServiceFields.DIMENSIONS,
+                    1536,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    512
+                )
+            );
+        }
+
+        return settingsMap;
+    }
+
+    private static Map<String, Object> createSecretSettingsMap() {
+        return new HashMap<>(Map.of("api_key", "secret"));
+    }
+
+    private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) {
+        var inferenceId = "inference_id";
+
+        return new LlamaEmbeddingsModel(
+            inferenceId,
+            TaskType.TEXT_EMBEDDING,
+            LlamaService.NAME,
+            new LlamaEmbeddingsServiceSettings(
+                "model_id",
+                "http://www.abc.com",
+                1536,
+                similarityMeasure,
+                512,
+                new RateLimitSettings(10_000)
+            ),
+            ChunkingSettingsTests.createRandomChunkingSettings(),
+            new DefaultSecretSettings(new SecureString("secret".toCharArray()))
+        );
+    }
+
+    @Override
+    protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
+        return "Failed to parse stored model [id] for [llama] service, please delete and add the service again";
+    }
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationActionListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(LlamaEmbeddingsModel.class));
+
+                var embeddingsModel = (LlamaEmbeddingsModel) model;
+                assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+                assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
+                assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret"));
+            }, e -> fail("parse request should not fail " + e.getMessage()));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                getRequestConfigMap(
+                    getServiceSettingsMap("model", "url"),
+                    createRandomChunkingSettingsMap(),
+                    getSecretSettingsMap("secret")
+                ),
+                modelVerificationActionListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationActionListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(LlamaEmbeddingsModel.class));
+
+                var embeddingsModel = (LlamaEmbeddingsModel) model;
+                assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
+                assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
+                assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret"));
+            }, e -> fail("parse request should not fail " + e.getMessage()));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                getRequestConfigMap(getServiceSettingsMap("model", "url"), getSecretSettingsMap("secret")),
+                modelVerificationActionListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException {
+        var url = "url";
+        var secret = "secret";
+
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ValidationException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
+                    );
+                }
+            );
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.CHAT_COMPLETION,
+                getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsException_WithoutUrl() throws IOException {
+        var model = "model";
+        var secret = "secret";
+
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ValidationException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Validation Failed: 1: [service_settings] does not contain the required setting [url];")
+                    );
+                }
+            );
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.CHAT_COMPLETION,
+                getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testUnifiedCompletionInfer() throws Exception {
+        // The escapes are because the streaming response must be on a single line
+        String responseJson = """
+            data: {\
+                "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\
+                "choices": [{\
+                        "delta": {\
+                            "content": "Deep",\
+                            "function_call": null,\
+                            "refusal": null,\
+                            "role": "assistant",\
+                            "tool_calls": null\
+                        },\
+                        "finish_reason": null,\
+                        "index": 0,\
+                        "logprobs": null\
+                    }\
+                ],\
+                "created": 1750158492,\
+                "model": "llama3.2:3b",\
+                "object": "chat.completion.chunk",\
+                "service_tier": null,\
+                "system_fingerprint": "fp_ollama",\
+                "usage": null\
+            }
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
+            var model = createChatCompletionModel("model", getUrl(webServer), "secret");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.unifiedCompletionInfer(
+                model,
+                UnifiedCompletionRequest.of(
+                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
+                ),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT);
+            InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace("""
+                {
+                    "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",
+                    "choices": [{
+                            "delta": {
+                                "content": "Deep",
+                                "role": "assistant"
+                            },
+                            "index": 0
+                        }
+                    ],
+                    "model": "llama3.2:3b",
+                    "object": "chat.completion.chunk"
+                }
+                """));
+        }
+    }
+
+    public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception {
+        String responseJson = """
+            {
+                "detail": "Not Found"
+            }
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
+
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
+            var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret");
+            var latch = new CountDownLatch(1);
+            service.unifiedCompletionInfer(
+                model,
+                UnifiedCompletionRequest.of(
+                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
+                ),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
+                    try (var builder = XContentFactory.jsonBuilder()) {
+                        var t = unwrapCause(e);
+                        assertThat(t, isA(UnifiedChatCompletionException.class));
+                        ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+                            try {
+                                xContent.toXContent(builder, EMPTY_PARAMS);
+                            } catch (IOException ex) {
+                                throw new RuntimeException(ex);
+                            }
+                        });
+                        var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+                        assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace("""
+                            {
+                              "error" : {
+                                "code" : "not_found",
+                                "message" : "Resource not found at [%s] for request from inference entity id [id] status \
+                            [404]. Error message: [{\\n    \\"detail\\": \\"Not Found\\"\\n}\\n]",
+                                "type" : "llama_error"
+                              }
+                            }"""), getUrl(webServer))));
+                    } catch (IOException ex) {
+                        throw new RuntimeException(ex);
+                    }
+                }), latch::countDown)
+            );
+            assertTrue(latch.await(30, TimeUnit.SECONDS));
+        }
+    }
+
+    public void testMidStreamUnifiedCompletionError() throws Exception {
+        String responseJson = """
+            data: {"error": {"message": "400: Invalid value: Model 'llama3.12:3b' not found"}}
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+        testStreamError(XContentHelper.stripWhitespace("""
+            {
+                  "error": {
+                      "code": "stream_error",
+                      "message": "Received an error response for request from inference entity id [id].\
+             Error message: [400: Invalid value: Model 'llama3.12:3b' not found]",
+                      "type": "llama_error"
+                  }
+              }
+            """));
+    }
+
+    public void testInfer_StreamRequest() throws Exception {
+        String responseJson = """
+            data: {\
+                "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\
+                "choices": [{\
+                        "delta": {\
+                            "content": "Deep",\
+                            "function_call": null,\
+                            "refusal": null,\
+                            "role": "assistant",\
+                            "tool_calls": null\
+                        },\
+                        "finish_reason": null,\
+                        "index": 0,\
+                        "logprobs": null\
+                    }\
+                ],\
+                "created": 1750158492,\
+                "model": "llama3.2:3b",\
+                "object": "chat.completion.chunk",\
+                "service_tier": null,\
+                "system_fingerprint": "fp_ollama",\
+                "usage": null\
+            }
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+        streamCompletion().hasNoErrors().hasEvent("""
+            {"completion":[{"delta":"Deep"}]}""");
+    }
+
+    private void testStreamError(String expectedResponse) throws Exception {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
+            var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.unifiedCompletionInfer(
+                model,
+                UnifiedCompletionRequest.of(
+                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
+                ),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT);
+
+            InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
+                e = unwrapCause(e);
+                assertThat(e, isA(UnifiedChatCompletionException.class));
+                try (var builder = XContentFactory.jsonBuilder()) {
+                    ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+                        try {
+                            xContent.toXContent(builder, EMPTY_PARAMS);
+                        } catch (IOException ex) {
+                            throw new RuntimeException(ex);
+                        }
+                    });
+                    var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+
+                    assertThat(json, is(expectedResponse));
+                }
+            });
+        }
+    }
+
+    public void testInfer_StreamRequest_ErrorResponse() {
+        String responseJson = """
+            {
+                "detail": "Not Found"
+            }""";
+        webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
+
+        var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
+        assertThat(e.status(), equalTo(RestStatus.NOT_FOUND));
+        assertThat(e.getMessage(), equalTo(String.format(Locale.ROOT, """
+            Resource not found at [%s] for request from inference entity id [id] status [404]. Error message: [{
+                "detail": "Not Found"
+            }]""", getUrl(webServer))));
+    }
+
+    public void testInfer_StreamRequestRetry() throws Exception {
+        webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
+            {
+              "error": {
+                "message": "server busy"
+              }
+            }"""));
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
+            data: {\
+                "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\
+                "choices": [{\
+                        "delta": {\
+                            "content": "Deep",\
+                            "function_call": null,\
+                            "refusal": null,\
+                            "role": "assistant",\
+                            "tool_calls": null\
+                        },\
+                        "finish_reason": null,\
+                        "index": 0,\
+                        "logprobs": null\
+                    }\
+                ],\
+                "created": 1750158492,\
+                "model": "llama3.2:3b",\
+                "object": "chat.completion.chunk",\
+                "service_tier": null,\
+                "system_fingerprint": "fp_ollama",\
+                "usage": null\
+            }
+
+            """));
+
+        streamCompletion().hasNoErrors().hasEvent("""
+            {"completion":[{"delta":"Deep"}]}""");
+    }
+
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new LlamaService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) {
+            assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
+            assertFalse(service.canStream(TaskType.ANY));
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var secretSettings = getSecretSettingsMap("secret");
+            secretSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings);
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+        }
+    }
+
+    public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
+        var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("id", "url", "api_key");
+        model.setURI(getUrl(webServer));
+
+        testChunkedInfer(model);
+    }
+
+    public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
+        var model = LlamaEmbeddingsModelTests.createEmbeddingsModelWithChunkingSettings("id", "url", "api_key");
+        model.setURI(getUrl(webServer));
+
+        testChunkedInfer(model);
+    }
+
+    private void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
+
+            String responseJson = """
+                {
+                    "embeddings": [
+                        [
+                            0.010060793,
+                            -0.0017529363
+                        ],
+                        [
+                            0.110060793,
+                            -0.1017529363
+                        ]
+                    ]
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
+            service.chunkedInfer(
+                model,
+                null,
+                List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")),
+                new HashMap<>(),
+                InputType.INTERNAL_INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var results = listener.actionGet(TIMEOUT);
+
+            assertThat(results, hasSize(2));
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
+                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+                assertTrue(
+                    Arrays.equals(
+                        new float[] { 0.010060793f, -0.0017529363f },
+                        ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+                    )
+                );
+            }
+            {
+                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
+                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+                assertTrue(
+                    Arrays.equals(
+                        new float[] { 0.110060793f, -0.1017529363f },
+                        ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+                    )
+                );
+            }
+
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(
+                webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
+                equalTo(XContentType.JSON.mediaTypeWithoutParameters())
+            );
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), Matchers.is(2));
+            assertThat(requestMap.get("contents"), Matchers.is(List.of("abc", "def")));
+            assertThat(requestMap.get("model_id"), Matchers.is("id"));
+        }
+    }
+
+    public void testGetConfiguration() throws Exception {
+        try (var service = createService()) {
+            String content = XContentHelper.stripWhitespace("""
+                {
+                       "service": "llama",
+                       "name": "Llama",
+                       "task_types": ["text_embedding", "completion", "chat_completion"],
+                       "configurations": {
+                           "api_key": {
+                               "description": "API Key for the provider you're connecting to.",
+                               "label": "API Key",
+                               "required": true,
+                               "sensitive": true,
+                               "updatable": true,
+                               "type": "str",
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
+                           },
+                           "model_id": {
+                               "description": "Refer to the Llama models documentation for the list of available models.",
+                               "label": "Model",
+                               "required": true,
+                               "sensitive": false,
+                               "updatable": false,
+                               "type": "str",
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
+                           },
+                           "rate_limit.requests_per_minute": {
+                               "description": "Minimize the number of rate limit errors.",
+                               "label": "Rate Limit",
+                               "required": false,
+                               "sensitive": false,
+                               "updatable": false,
+                               "type": "int",
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
+                           },
+                           "url": {
+                               "description": "The URL endpoint to use for the requests.",
+                               "label": "URL",
+                               "required": true,
+                               "sensitive": false,
+                               "updatable": false,
+                               "type": "str",
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
+                           }
+                       }
+                   }
+                """);
+            InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes(
+                new BytesArray(content),
+                XContentType.JSON
+            );
+            boolean humanReadable = true;
+            BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable);
+            InferenceServiceConfiguration serviceConfiguration = service.getConfiguration();
+            assertToXContentEquivalent(
+                originalBytes,
+                toXContent(serviceConfiguration, XContentType.JSON, humanReadable),
+                XContentType.JSON
+            );
+        }
+    }
+
+    private InferenceEventsAssertion streamCompletion() throws Exception {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
+            var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                model,
+                null,
+                null,
+                null,
+                List.of("abc"),
+                true,
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
+        }
+    }
+
+    private LlamaService createService() {
+        return new LlamaService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
+    }
+
+    private Map<String, Object> getRequestConfigMap(
+        Map<String, Object> serviceSettings,
+        Map<String, Object> chunkingSettings,
+        Map<String, Object> secretSettings
+    ) {
+        var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings);
+        requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings);
+
+        return requestConfigMap;
+    }
+
+    private Map<String, Object> getRequestConfigMap(Map<String, Object> serviceSettings, Map<String, Object> secretSettings) {
+        var builtServiceSettings = new HashMap<>();
+        builtServiceSettings.putAll(serviceSettings);
+        builtServiceSettings.putAll(secretSettings);
+
+        return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
+    }
+
+    private static Map<String, Object> getEmbeddingsServiceSettingsMap() {
+        return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null);
+    }
+}

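`testMidStreamUnifiedCompletionError` above pins down a subtle contract: an SSE `data:` event whose payload carries an `error` object is surfaced as a `UnifiedChatCompletionException` (`type: llama_error`, `code: stream_error`) rather than treated as a completion chunk. A hedged sketch of just the detection step, independent of the real handler's API:

```java
import java.util.Map;
import java.util.Optional;

// Hedged sketch: spot an error object inside an already-parsed SSE event.
// The class and method names are illustrative, not the plugin's.
final class MidStreamErrors {
    static Optional<String> errorMessage(Map<String, Object> event) {
        if (event.get("error") instanceof Map<?, ?> errorBody && errorBody.get("message") instanceof String message) {
            return Optional.of(message);
        }
        return Optional.empty();
    }
}
```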
+ 283 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java

@@ -0,0 +1,283 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.action;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
+import org.elasticsearch.xpack.inference.InputTypeTests;
+import org.elasticsearch.xpack.inference.common.TruncatorTests;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
+import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
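+/**
+ * Tests that {@link LlamaActionCreator} builds embeddings and completion actions which send the expected HTTP requests
+ * and translate both well-formed and malformed responses.
+ */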
+public class LlamaActionCreatorTests extends ESTestCase {
+
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "embeddings": [
+                        [
+                            -0.0123,
+                            0.123
+                        ]
+                    ]
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+
+            assertEmbeddingsRequest();
+        }
+    }
+
+    public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException {
+        var settings = buildSettingsWithRetryFields(
+            TimeValue.timeValueMillis(1),
+            TimeValue.timeValueMinutes(1),
+            TimeValue.timeValueSeconds(0)
+        );
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
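+            // Syntactically valid JSON, but the wrong shape: in the array format every element must itself be an array of
+            // floats, so the nested object fails parsing.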
+            String responseJson = """
+                [
+                    {
+                        "embeddings": [
+                            [
+                                -0.0123,
+                                0.123
+                            ]
+                        ]
+                    }
+                ]
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));
+
+            var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
+            );
+
+            assertEmbeddingsRequest();
+        }
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9",
+                    "choices": [
+                        {
+                            "finish_reason": "length",
+                            "index": 0,
+                            "logprobs": null,
+                            "message": {
+                                "content": "Hello there, how may I assist you today?",
+                                "refusal": null,
+                                "role": "assistant",
+                                "annotations": null,
+                                "audio": null,
+                                "function_call": null,
+                                "tool_calls": null
+                            }
+                        }
+                    ],
+                    "created": 1750157476,
+                    "model": "llama3.2:3b",
+                    "object": "chat.completion",
+                    "service_tier": null,
+                    "system_fingerprint": "fp_ollama",
+                    "usage": {
+                        "completion_tokens": 10,
+                        "prompt_tokens": 30,
+                        "total_tokens": 40,
+                        "completion_tokens_details": null,
+                        "prompt_tokens_details": null
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(sender, createWithEmptySettings(threadPool));
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
+
+            assertCompletionRequest();
+        }
+    }
+
+    public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException {
+        var settings = buildSettingsWithRetryFields(
+            TimeValue.timeValueMillis(1),
+            TimeValue.timeValueMinutes(1),
+            TimeValue.timeValueSeconds(0)
+        );
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "invalid_field": "unexpected"
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(
+                sender,
+                new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
+            );
+
+            var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is("Failed to send Llama completion request from inference entity id [id]. Cause: Required [choices]")
+            );
+
+            assertCompletionRequest();
+        }
+    }
+
+    private PlainActionFuture<InferenceServiceResults> createEmbeddingsFuture(Sender sender, ServiceComponents serviceComponents) {
+        var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret");
+        var actionCreator = new LlamaActionCreator(sender, serviceComponents);
+        var action = actionCreator.create(model);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+        return listener;
+    }
+
+    private PlainActionFuture<InferenceServiceResults> createCompletionFuture(Sender sender, ServiceComponents serviceComponents) {
+        var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret");
+        var actionCreator = new LlamaActionCreator(sender, serviceComponents);
+        var action = actionCreator.create(model);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        return listener;
+    }
+
+    private void assertCompletionRequest() throws IOException {
+        assertCommonRequestProperties();
+
+        var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+        assertThat(requestMap.size(), is(4));
+        assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
+        assertThat(requestMap.get("model"), is("model"));
+        assertThat(requestMap.get("n"), is(1));
+        assertThat(requestMap.get("stream"), is(false));
+    }
+
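+    // The embeddings payload carries the inputs under "contents"; instanceOf is asserted before the unchecked cast.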
+    @SuppressWarnings("unchecked")
+    private void assertEmbeddingsRequest() throws IOException {
+        assertCommonRequestProperties();
+
+        var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+        assertThat(requestMap.size(), is(2));
+        assertThat(requestMap.get("contents"), instanceOf(List.class));
+        var inputList = (List<String>) requestMap.get("contents");
+        assertThat(inputList, contains("abc"));
+    }
+
+    private void assertCommonRequestProperties() {
+        assertThat(webServer.requests(), hasSize(1));
+        assertNull(webServer.requests().get(0).getUri().getQuery());
+        assertThat(
+            webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
+            equalTo(XContentType.JSON.mediaTypeWithoutParameters())
+        );
+        assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+    }
+}

+ 142 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java

@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+
+public class LlamaChatCompletionModelTests extends ESTestCase {
+
+    public static LlamaChatCompletionModel createCompletionModel(String modelId, String url, String apiKey) {
+        return new LlamaChatCompletionModel(
+            "id",
+            TaskType.COMPLETION,
+            "llama",
+            new LlamaChatCompletionServiceSettings(modelId, url, null),
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+
+    public static LlamaChatCompletionModel createChatCompletionModel(String modelId, String url, String apiKey) {
+        return new LlamaChatCompletionModel(
+            "id",
+            TaskType.CHAT_COMPLETION,
+            "llama",
+            new LlamaChatCompletionServiceSettings(modelId, url, null),
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+
+    public static LlamaChatCompletionModel createChatCompletionModelNoAuth(String modelId, String url) {
+        return new LlamaChatCompletionModel(
+            "id",
+            TaskType.CHAT_COMPLETION,
+            "llama",
+            new LlamaChatCompletionServiceSettings(modelId, url, null),
+            EmptySecretSettings.INSTANCE
+        );
+    }
+
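+    // LlamaChatCompletionModel.of(model, request) applies the per-request override: a model_id in the unified request
+    // replaces the configured one, otherwise the original model settings are kept.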
+    public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() {
+        var model = createCompletionModel("model_name", "url", "api_key");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            "model_name",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = LlamaChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel, is(model));
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
+        var model = createCompletionModel("model_name", "url", "api_key");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            "different_model",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = LlamaChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
+        var model = createCompletionModel(null, "url", "api_key");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            "different_model",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = LlamaChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
+        var model = createCompletionModel(null, "url", "api_key");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = LlamaChatCompletionModel.of(model, request);
+
+        assertNull(overriddenModel.getServiceSettings().modelId());
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
+        var model = createCompletionModel("model_name", "url", "api_key");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            null, // not overriding model
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = LlamaChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
+    }
+}

+ 162 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java

@@ -0,0 +1,162 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.apache.http.HttpResponse;
+import org.apache.http.StatusLine;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+
+import static org.elasticsearch.ExceptionsHelper.unwrapCause;
+import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.isA;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class LlamaChatCompletionResponseHandlerTests extends ESTestCase {
+    private final LlamaChatCompletionResponseHandler responseHandler = new LlamaChatCompletionResponseHandler(
+        "chat completions",
+        (a, b) -> mock()
+    );
+
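+    // Error responses are surfaced as UnifiedChatCompletionException and rendered as an OpenAI-style {"error": {...}}
+    // payload with type "llama_error".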
+    public void testFailNotFound() throws IOException {
+        var responseJson = XContentHelper.stripWhitespace("""
+            {
+                "detail": "Not Found"
+            }
+            """);
+
+        var errorJson = invalidResponseJson(responseJson, 404);
+
+        assertThat(errorJson, is(XContentHelper.stripWhitespace("""
+            {
+              "error" : {
+                "code" : "not_found",
+                "message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \
+            status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]",
+                "type" : "llama_error"
+              }
+            }""")));
+    }
+
+    public void testFailBadRequest() throws IOException {
+        var responseJson = XContentHelper.stripWhitespace("""
+            {
+                "error": {
+                    "detail": {
+                        "errors": [{
+                                "loc": [
+                                    "body",
+                                    "messages"
+                                ],
+                                "msg": "Field required",
+                                "type": "missing"
+                            }
+                        ]
+                    }
+                }
+            }
+            """);
+
+        var errorJson = invalidResponseJson(responseJson, 400);
+
+        assertThat(errorJson, is(XContentHelper.stripWhitespace("""
+            {
+                "error": {
+                    "code": "bad_request",
+                    "message": "Received a bad request status code for request from inference entity id [id] status [400].\
+             Error message: [{\\"error\\":{\\"detail\\":{\\"errors\\":[{\\"loc\\":[\\"body\\",\\"messages\\"],\\"msg\\":\\"Field\
+             required\\",\\"type\\":\\"missing\\"}]}}}]",
+                    "type": "llama_error"
+                }
+            }
+            """)));
+    }
+
+    public void testFailValidationWithInvalidJson() throws IOException {
+        var responseJson = """
+            what? this isn't a json
+            """;
+
+        var errorJson = invalidResponseJson(responseJson, 500);
+
+        assertThat(errorJson, is(XContentHelper.stripWhitespace("""
+            {
+                "error": {
+                    "code": "bad_request",
+                    "message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \
+            [what? this isn't a json\\n]",
+                    "type": "llama_error"
+                }
+            }
+            """)));
+    }
+
+    private String invalidResponseJson(String responseJson, int statusCode) throws IOException {
+        var exception = invalidResponse(responseJson, statusCode);
+        assertThat(exception, isA(RetryException.class));
+        assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
+        return toJson((UnifiedChatCompletionException) unwrapCause(exception));
+    }
+
+    private Exception invalidResponse(String responseJson, int statusCode) {
+        return expectThrows(
+            RetryException.class,
+            () -> responseHandler.validateResponse(
+                mock(),
+                mock(),
+                mockRequest(),
+                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
+                true
+            )
+        );
+    }
+
+    private static Request mockRequest() throws URISyntaxException {
+        var request = mock(Request.class);
+        when(request.getInferenceEntityId()).thenReturn("id");
+        when(request.isStreaming()).thenReturn(true);
+        when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions"));
+        return request;
+    }
+
+    private static HttpResponse mockErrorResponse(int statusCode) {
+        var statusLine = mock(StatusLine.class);
+        when(statusLine.getStatusCode()).thenReturn(statusCode);
+
+        var response = mock(HttpResponse.class);
+        when(response.getStatusLine()).thenReturn(statusLine);
+
+        return response;
+    }
+
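+    // Drains the exception's chunked XContent into a single JSON string for comparison.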
+    private String toJson(UnifiedChatCompletionException e) throws IOException {
+        try (var builder = XContentFactory.jsonBuilder()) {
+            e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+                try {
+                    xContent.toXContent(builder, EMPTY_PARAMS);
+                } catch (IOException ex) {
+                    throw new RuntimeException(ex);
+                }
+            });
+            return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+        }
+    }
+}

+ 198 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java

@@ -0,0 +1,198 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class LlamaChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<LlamaChatCompletionServiceSettings> {
+
+    public static final String MODEL_ID = "some model";
+    public static final String CORRECT_URL = "https://www.elastic.co";
+    public static final int RATE_LIMIT = 2;
+
+    public void testFromMap_AllFields_Success() {
+        var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    ServiceFields.MODEL_ID,
+                    MODEL_ID,
+                    ServiceFields.URL,
+                    CORRECT_URL,
+                    RateLimitSettings.FIELD_NAME,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT))));
+    }
+
+    public void testFromMap_MissingModelId_ThrowsException() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaChatCompletionServiceSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        ServiceFields.URL,
+                        CORRECT_URL,
+                        RateLimitSettings.FIELD_NAME,
+                        new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                    )
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
+        );
+    }
+
+    public void testFromMap_MissingUrl_ThrowsException() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaChatCompletionServiceSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        ServiceFields.MODEL_ID,
+                        MODEL_ID,
+                        RateLimitSettings.FIELD_NAME,
+                        new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                    )
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];")
+        );
+    }
+
+    public void testFromMap_MissingRateLimit_Success() {
+        var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null)));
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    ServiceFields.MODEL_ID,
+                    MODEL_ID,
+                    ServiceFields.URL,
+                    CORRECT_URL,
+                    RateLimitSettings.FIELD_NAME,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        serviceSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "model_id": "some model",
+                "url": "https://www.elastic.co",
+                "rate_limit": {
+                    "requests_per_minute": 2
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
+        var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        serviceSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "model_id": "some model",
+                "url": "https://www.elastic.co",
+                "rate_limit": {
+                    "requests_per_minute": 3000
+                }
+            }
+            """);
+        assertThat(xContentResult, is(expected));
+    }
+
+    @Override
+    protected Writeable.Reader<LlamaChatCompletionServiceSettings> instanceReader() {
+        return LlamaChatCompletionServiceSettings::new;
+    }
+
+    @Override
+    protected LlamaChatCompletionServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected LlamaChatCompletionServiceSettings mutateInstance(LlamaChatCompletionServiceSettings instance) throws IOException {
+        return randomValueOtherThan(instance, LlamaChatCompletionServiceSettingsTests::createRandom);
+    }
+
+    @Override
+    protected LlamaChatCompletionServiceSettings mutateInstanceForVersion(
+        LlamaChatCompletionServiceSettings instance,
+        TransportVersion version
+    ) {
+        return instance;
+    }
+
+    private static LlamaChatCompletionServiceSettings createRandom() {
+        var modelId = randomAlphaOfLength(8);
+        var url = randomAlphaOfLength(15);
+        return new LlamaChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom());
+    }
+
+    public static Map<String, Object> getServiceSettingsMap(String model, String url) {
+        var map = new HashMap<String, Object>();
+
+        map.put(ServiceFields.MODEL_ID, model);
+        map.put(ServiceFields.URL, url);
+
+        return map;
+    }
+}

+ 51 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.embeddings;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
+
+public class LlamaEmbeddingsModelTests extends ESTestCase {
+    public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String url, String apiKey) {
+        return new LlamaEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "llama",
+            new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null),
+            null,
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+
+    public static LlamaEmbeddingsModel createEmbeddingsModelWithChunkingSettings(String modelId, String url, String apiKey) {
+        return new LlamaEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "llama",
+            new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null),
+            createRandomChunkingSettings(),
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+
+    public static LlamaEmbeddingsModel createEmbeddingsModelNoAuth(String modelId, String url) {
+        return new LlamaEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "llama",
+            new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null),
+            null,
+            EmptySecretSettings.INSTANCE
+        );
+    }
+}

+ 479 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java

@@ -0,0 +1,479 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.embeddings;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+import org.hamcrest.CoreMatchers;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class LlamaEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<LlamaEmbeddingsServiceSettings> {
+    private static final String MODEL_ID = "some model";
+    private static final String CORRECT_URL = "https://www.elastic.co";
+    private static final int DIMENSIONS = 384;
+    private static final SimilarityMeasure SIMILARITY_MEASURE = SimilarityMeasure.DOT_PRODUCT;
+    private static final int MAX_INPUT_TOKENS = 128;
+    private static final int RATE_LIMIT = 2;
+
+    public void testFromMap_AllFields_Success() {
+        var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap(
+            buildServiceSettingsMap(
+                MODEL_ID,
+                CORRECT_URL,
+                SIMILARITY_MEASURE.toString(),
+                DIMENSIONS,
+                MAX_INPUT_TOKENS,
+                new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new LlamaEmbeddingsServiceSettings(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    DIMENSIONS,
+                    SIMILARITY_MEASURE,
+                    MAX_INPUT_TOKENS,
+                    new RateLimitSettings(RATE_LIMIT)
+                )
+            )
+        );
+    }
+
+    public void testFromMap_NoModelId_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    null,
+                    CORRECT_URL,
+                    SIMILARITY_MEASURE.toString(),
+                    DIMENSIONS,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
+        );
+    }
+
+    public void testFromMap_NoUrl_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    null,
+                    SIMILARITY_MEASURE.toString(),
+                    DIMENSIONS,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];")
+        );
+    }
+
+    public void testFromMap_EmptyUrl_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    "",
+                    SIMILARITY_MEASURE.toString(),
+                    DIMENSIONS,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;")
+        );
+    }
+
+    public void testFromMap_InvalidUrl_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    "^^^",
+                    SIMILARITY_MEASURE.toString(),
+                    DIMENSIONS,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                "Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. "
+                    + "Error: unable to parse url [^^^]. Reason: Illegal character in path;"
+            )
+        );
+    }
+
+    public void testFromMap_NoSimilarity_Success() {
+        var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap(
+            buildServiceSettingsMap(
+                MODEL_ID,
+                CORRECT_URL,
+                null,
+                DIMENSIONS,
+                MAX_INPUT_TOKENS,
+                new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new LlamaEmbeddingsServiceSettings(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    DIMENSIONS,
+                    null,
+                    MAX_INPUT_TOKENS,
+                    new RateLimitSettings(RATE_LIMIT)
+                )
+            )
+        );
+    }
+
+    public void testFromMap_InvalidSimilarity_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    "by_size",
+                    DIMENSIONS,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                "Validation Failed: 1: [service_settings] Invalid value [by_size] received. "
+                    + "[similarity] must be one of [cosine, dot_product, l2_norm];"
+            )
+        );
+    }
+
+    public void testFromMap_NoDimensions_Success() {
+        var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap(
+            buildServiceSettingsMap(
+                MODEL_ID,
+                CORRECT_URL,
+                SIMILARITY_MEASURE.toString(),
+                null,
+                MAX_INPUT_TOKENS,
+                new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new LlamaEmbeddingsServiceSettings(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    null,
+                    SIMILARITY_MEASURE,
+                    MAX_INPUT_TOKENS,
+                    new RateLimitSettings(RATE_LIMIT)
+                )
+            )
+        );
+    }
+
+    public void testFromMap_ZeroDimensions_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    SIMILARITY_MEASURE.toString(),
+                    0,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;")
+        );
+    }
+
+    public void testFromMap_NegativeDimensions_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    SIMILARITY_MEASURE.toString(),
+                    -10,
+                    MAX_INPUT_TOKENS,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. [dimensions] must be a positive integer;")
+        );
+    }
+
+    public void testFromMap_NoInputTokens_Success() {
+        var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap(
+            buildServiceSettingsMap(
+                MODEL_ID,
+                CORRECT_URL,
+                SIMILARITY_MEASURE.toString(),
+                DIMENSIONS,
+                null,
+                new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new LlamaEmbeddingsServiceSettings(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    DIMENSIONS,
+                    SIMILARITY_MEASURE,
+                    null,
+                    new RateLimitSettings(RATE_LIMIT)
+                )
+            )
+        );
+    }
+
+    public void testFromMap_ZeroInputTokens_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    SIMILARITY_MEASURE.toString(),
+                    DIMENSIONS,
+                    0,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;")
+        );
+    }
+
+    public void testFromMap_NegativeInputTokens_Failure() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> LlamaEmbeddingsServiceSettings.fromMap(
+                buildServiceSettingsMap(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    SIMILARITY_MEASURE.toString(),
+                    DIMENSIONS,
+                    -10,
+                    new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. [max_input_tokens] must be a positive integer;")
+        );
+    }
+
+    public void testFromMap_NoRateLimit_Success() {
+        var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap(
+            buildServiceSettingsMap(MODEL_ID, CORRECT_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, null),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new LlamaEmbeddingsServiceSettings(
+                    MODEL_ID,
+                    CORRECT_URL,
+                    DIMENSIONS,
+                    SIMILARITY_MEASURE,
+                    MAX_INPUT_TOKENS,
+                    new RateLimitSettings(3000)
+                )
+            )
+        );
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var entity = new LlamaEmbeddingsServiceSettings(
+            MODEL_ID,
+            CORRECT_URL,
+            DIMENSIONS,
+            SIMILARITY_MEASURE,
+            MAX_INPUT_TOKENS,
+            new RateLimitSettings(3)
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace("""
+            {
+                "model_id": "some model",
+                "url": "https://www.elastic.co",
+                "dimensions": 384,
+                "similarity": "dot_product",
+                "max_input_tokens": 128,
+                "rate_limit": {
+                    "requests_per_minute": 3
+                }
+            }
+            """)));
+    }
+
+    public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
+        var outputBuffer = new BytesStreamOutput();
+        var settings = new LlamaEmbeddingsServiceSettings(
+            MODEL_ID,
+            CORRECT_URL,
+            DIMENSIONS,
+            SIMILARITY_MEASURE,
+            MAX_INPUT_TOKENS,
+            new RateLimitSettings(3)
+        );
+        settings.writeTo(outputBuffer);
+
+        var outputBufferRef = outputBuffer.bytes();
+        var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array());
+
+        var settingsFromBuffer = new LlamaEmbeddingsServiceSettings(inputBuffer);
+
+        assertEquals(settings, settingsFromBuffer);
+    }
+
+    @Override
+    protected Writeable.Reader<LlamaEmbeddingsServiceSettings> instanceReader() {
+        return LlamaEmbeddingsServiceSettings::new;
+    }
+
+    @Override
+    protected LlamaEmbeddingsServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected LlamaEmbeddingsServiceSettings mutateInstance(LlamaEmbeddingsServiceSettings instance) throws IOException {
+        return randomValueOtherThan(instance, LlamaEmbeddingsServiceSettingsTests::createRandom);
+    }
+
+    private static LlamaEmbeddingsServiceSettings createRandom() {
+        var modelId = randomAlphaOfLength(8);
+        var url = randomAlphaOfLength(15);
+        var similarityMeasure = randomFrom(SimilarityMeasure.values());
+        var dimensions = randomIntBetween(32, 256);
+        var maxInputTokens = randomIntBetween(128, 256);
+        return new LlamaEmbeddingsServiceSettings(
+            modelId,
+            url,
+            dimensions,
+            similarityMeasure,
+            maxInputTokens,
+            RateLimitSettingsTests.createRandom()
+        );
+    }
+
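+    // Builds a service settings map containing only the non-null arguments, letting each test omit individual fields.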
+    public static HashMap<String, Object> buildServiceSettingsMap(
+        @Nullable String modelId,
+        @Nullable String url,
+        @Nullable String similarity,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens,
+        @Nullable HashMap<String, Integer> rateLimitSettings
+    ) {
+        HashMap<String, Object> result = new HashMap<>();
+        if (modelId != null) {
+            result.put(ServiceFields.MODEL_ID, modelId);
+        }
+        if (url != null) {
+            result.put(ServiceFields.URL, url);
+        }
+        if (similarity != null) {
+            result.put(ServiceFields.SIMILARITY, similarity);
+        }
+        if (dimensions != null) {
+            result.put(ServiceFields.DIMENSIONS, dimensions);
+        }
+        if (maxInputTokens != null) {
+            result.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens);
+        }
+        if (rateLimitSettings != null) {
+            result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings);
+        }
+        return result;
+    }
+}

+ 64 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.completion;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
+
+import java.io.IOException;
+import java.util.ArrayList;
+
+public class LlamaChatCompletionRequestEntityTests extends ESTestCase {
+    private static final String ROLE = "user";
+
+    public void testModelUserFieldsSerialization() throws IOException {
+        UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
+            new UnifiedCompletionRequest.ContentString("Hello, world!"),
+            ROLE,
+            null,
+            null
+        );
+        var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
+        messageList.add(message);
+
+        var unifiedRequest = UnifiedCompletionRequest.of(messageList);
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
+        LlamaChatCompletionModel model = LlamaChatCompletionModelTests.createChatCompletionModel("model", "url", "api-key");
+
+        LlamaChatCompletionRequestEntity entity = new LlamaChatCompletionRequestEntity(unifiedChatInput, model);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+        String expectedJson = """
+            {
+                "messages": [{
+                        "content": "Hello, world!",
+                        "role": "user"
+                    }
+                ],
+                "model": "model",
+                "n": 1,
+                "stream": true,
+                "stream_options": {
+                    "include_usage": true
+                }
+            }
+            """;
+        assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
+    }
+}

+ 82 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java

@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.completion;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class LlamaChatCompletionRequestTests extends ESTestCase {
+
+    public void testCreateRequest_WithStreaming() throws IOException {
+        String input = randomAlphaOfLength(15);
+        var request = createRequest("model", "url", "secret", input, true);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(request.getURI().toString(), is("url"));
+        assertThat(requestMap.get("stream"), is(true));
+        assertThat(requestMap.get("model"), is("model"));
+        assertThat(requestMap.get("n"), is(1));
+        assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
+        assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
+        assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+    }
+
+    public void testCreateRequest_NoStreaming_NoAuthorization() throws IOException {
+        String input = randomAlphaOfLength(15);
+        var request = createRequestWithNoAuth("model", "url", input, false);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(request.getURI().toString(), is("url"));
+        assertThat(requestMap.get("stream"), is(false));
+        assertThat(requestMap.get("model"), is("model"));
+        assertThat(requestMap.get("n"), is(1));
+        assertNull(requestMap.get("stream_options"));
+        assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
+        assertNull(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION));
+    }
+
+    public void testTruncate_DoesNotReduceInputTextSize() {
+        String input = randomAlphaOfLength(5);
+        var request = createRequest("model", "url", "secret", input, true);
+        assertThat(request.truncate(), is(request));
+    }
+
+    public void testTruncationInfo_ReturnsNull() {
+        var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true);
+        assertNull(request.getTruncationInfo());
+    }
+
+    public static LlamaChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) {
+        var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModel(modelId, url, apiKey);
+        return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
+    }
+
+    public static LlamaChatCompletionRequest createRequestWithNoAuth(String modelId, String url, String input, boolean stream) {
+        var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModelNoAuth(modelId, url);
+        return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
+    }
+}
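
Both request tests above pin down the same header rule: an Authorization header is attached only when the model carries an API key. A hypothetical helper capturing that rule (names are illustrative, not the production code):

    // Attach "Authorization: Bearer <secret>" only when an API key is configured;
    // the no-auth test asserts the header is absent otherwise.
    static void applyAuthHeader(HttpPost httpPost, @Nullable SecureString apiKey) {
        if (apiKey != null) {
            httpPost.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey);
        }
    }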

+ 38 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+public class LlamaEmbeddingsRequestEntityTests extends ESTestCase {
+
+    public void testXContent_Success() throws IOException {
+        var entity = new LlamaEmbeddingsRequestEntity("llama-embed", List.of("ABCD"));
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is(XContentHelper.stripWhitespace("""
+            {
+                "model_id": "llama-embed",
+                "contents": ["ABCD"]
+            }
+            """)));
+    }
+}
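
The expected JSON documents the Llama embeddings wire format: a "model_id" plus a "contents" array. A minimal sketch of an entity that would serialize this way, assuming a plain record over the two asserted fields (types from org.elasticsearch.xcontent; the production class may differ):

    // Illustrative record; "contents" is written as a JSON array of strings.
    record EmbeddingsEntitySketch(String modelId, List<String> contents) implements ToXContentObject {
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field("model_id", modelId);
            builder.field("contents", contents);
            builder.endObject();
            return builder;
        }
    }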

+ 101 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java

@@ -0,0 +1,101 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.common.TruncatorTests;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class LlamaEmbeddingsRequestTests extends ESTestCase {
+
+    public void testCreateRequest_WithAuth_Success() throws IOException {
+        var request = createRequest();
+        var httpRequest = request.createHttpRequest();
+        var httpPost = validateRequestUrlAndContentType(httpRequest);
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("contents"), is(List.of("ABCD")));
+        assertThat(requestMap.get("model_id"), is("llama-embed"));
+        assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey"));
+    }
+
+    public void testCreateRequest_NoAuth_Success() throws IOException {
+        var request = createRequestNoAuth();
+        var httpRequest = request.createHttpRequest();
+        var httpPost = validateRequestUrlAndContentType(httpRequest);
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("contents"), is(List.of("ABCD")));
+        assertThat(requestMap.get("model_id"), is("llama-embed"));
+        assertNull(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION));
+    }
+
+    public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
+        var request = createRequest();
+        var truncatedRequest = request.truncate();
+
+        var httpRequest = truncatedRequest.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("contents"), is(List.of("AB")));
+        assertThat(requestMap.get("model_id"), is("llama-embed"));
+    }
+
+    public void testTruncate_SetsTruncationInfo() {
+        var request = createRequest();
+        assertFalse(request.getTruncationInfo()[0]);
+
+        var truncatedRequest = request.truncate();
+        assertTrue(truncatedRequest.getTruncationInfo()[0]);
+    }
+
+    private HttpPost validateRequestUrlAndContentType(HttpRequest request) {
+        assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) request.httpRequestBase();
+        assertThat(httpPost.getURI().toString(), is("url"));
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters()));
+        return httpPost;
+    }
+
+    private static LlamaEmbeddingsRequest createRequest() {
+        var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey");
+        return new LlamaEmbeddingsRequest(
+            TruncatorTests.createTruncator(),
+            new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }),
+            embeddingsModel
+        );
+    }
+
+    private static LlamaEmbeddingsRequest createRequestNoAuth() {
+        var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModelNoAuth("llama-embed", "url");
+        return new LlamaEmbeddingsRequest(
+            TruncatorTests.createTruncator(),
+            new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }),
+            embeddingsModel
+        );
+    }
+
+}
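
testTruncate_ReducesInputTextSizeByHalf expects "ABCD" to shrink to "AB", i.e. truncation keeps the first half of each input. A hypothetical one-liner with the same effect (the plugin's Truncator may round and split differently):

    // Halves an input string the way the test's expectation implies ("ABCD" -> "AB").
    static String halve(String text) {
        return text.substring(0, text.length() / 2);
    }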

+ 34 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.mockito.Mockito.mock;
+
+public class LlamaErrorResponseTests extends ESTestCase {
+
+    public static final String ERROR_RESPONSE_JSON = """
+        {
+            "error": "A valid user token is required"
+        }
+        """;
+
+    public void testFromResponse() {
+        var errorResponse = LlamaErrorResponse.fromResponse(
+            new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8))
+        );
+        assertNotNull(errorResponse);
+        assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage());
+    }
+
+}
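
The assertion implies fromResponse performs no JSON parsing: the raw response body, whitespace included, becomes the error message verbatim. A sketch under that assumption (method and type names illustrative):

    // Treat the raw HTTP body as the error message rather than extracting
    // the "error" field from the JSON.
    static ErrorResponse fromResponseSketch(HttpResult result) {
        return new ErrorResponse(new String(result.body(), StandardCharsets.UTF_8));
    }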

+ 0 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ChunkingSettings;
-import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
@@ -37,7 +36,6 @@ public class MistralEmbeddingModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "mistral",
             new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings),
-            EmptyTaskSettings.INSTANCE,
             chunkingSettings,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
@@ -57,7 +55,6 @@ public class MistralEmbeddingModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "mistral",
             new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings),
-            EmptyTaskSettings.INSTANCE,
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java

@@ -49,7 +49,7 @@ public class MistralChatCompletionRequestTests extends ESTestCase {
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         assertThat(requestMap, aMapWithSize(4));
 
-        // We do not truncate for Hugging Face chat completions
+        // We do not truncate for Mistral chat completions
         assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
         assertThat(requestMap.get("model"), is("model"));
         assertThat(requestMap.get("n"), is(1));