Browse Source

[ML] Adding response parsers for custom service (#127179) (#127478)

* Adding response parsers for custom service

* [CI] Auto commit changes from spotless

* Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java



* Refactoring to include field names in exceptions

* Adding list entry index to error message and field names

* Addressing feedback for validation exception

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: David Kyle <david.kyle@elastic.co>
Jonathan Buttner 5 months ago
parent
commit
7e9a7c3684
18 changed files with 2654 additions and 23 deletions
  1. 64 10
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java
  2. 12 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java
  3. 19 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
  4. 149 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java
  5. 105 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
  6. 19 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java
  7. 105 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
  8. 46 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/NoopResponseParser.java
  9. 187 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java
  10. 165 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java
  11. 104 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
  12. 51 13
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java
  13. 111 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParserTests.java
  14. 304 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java
  15. 145 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java
  16. 456 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java
  17. 349 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java
  18. 263 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java

+ 64 - 10
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java

@@ -10,8 +10,10 @@ package org.elasticsearch.xpack.inference.common;
 import org.elasticsearch.common.Strings;
 
 import java.util.ArrayList;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.regex.Pattern;
 
 /**
@@ -78,6 +80,8 @@ import java.util.regex.Pattern;
  *   [1, 2]
  * ]
  * }
+ *
+ * The array field names would be {@code ["embeddings", "embedding"}
  * </pre>
  *
  * This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array.
@@ -123,10 +127,28 @@ public class MapPathExtractor {
     private static final String DOLLAR = "$";
 
     // default for testing
-    static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)");
-    static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)");
+    static final Pattern DOT_FIELD_PATTERN = Pattern.compile("^\\.([^.\\[]+)(.*)");
+    static final Pattern ARRAY_WILDCARD_PATTERN = Pattern.compile("^\\[\\*\\](.*)");
+    public static final String UNKNOWN_FIELD_NAME = "unknown";
+
+    /**
+     * A result object that tries to match up the field names parsed from the passed in path and the result
+     * extracted from the passed in map.
+     * @param extractedObject represents the extracted result from the map
+     * @param traversedFields a list of field names in order as they're encountered while navigating through the nested objects
+     */
+    public record Result(Object extractedObject, List<String> traversedFields) {
+        public String getArrayFieldName(int index) {
+            // if the index is out of bounds we'll return a default value
+            if (traversedFields.size() <= index || index < 0) {
+                return UNKNOWN_FIELD_NAME;
+            }
+
+            return traversedFields.get(index);
+        }
+    }
 
-    public static Object extract(Map<String, Object> data, String path) {
+    public static Result extract(Map<String, Object> data, String path) {
         if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) {
             return null;
         }
@@ -139,16 +161,41 @@ public class MapPathExtractor {
             throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath));
         }
 
-        return navigate(data, cleanedPath);
+        var fieldNames = new LinkedHashSet<String>();
+
+        return new Result(navigate(data, cleanedPath, new FieldNameInfo("", "", fieldNames)), fieldNames.stream().toList());
     }
 
-    private static Object navigate(Object current, String remainingPath) {
-        if (current == null || remainingPath == null || remainingPath.isEmpty()) {
+    private record FieldNameInfo(String currentPath, String fieldName, Set<String> traversedFields) {
+        void addTraversedField(String fieldName) {
+            traversedFields.add(createPath(fieldName));
+        }
+
+        void addCurrentField() {
+            traversedFields.add(currentPath);
+        }
+
+        FieldNameInfo descend(String newFieldName) {
+            var newLocation = createPath(newFieldName);
+            return new FieldNameInfo(newLocation, newFieldName, traversedFields);
+        }
+
+        private String createPath(String newFieldName) {
+            if (Strings.isNullOrEmpty(currentPath)) {
+                return newFieldName;
+            } else {
+                return currentPath + "." + newFieldName;
+            }
+        }
+    }
+
+    private static Object navigate(Object current, String remainingPath, FieldNameInfo fieldNameInfo) {
+        if (current == null || Strings.isNullOrEmpty(remainingPath)) {
             return current;
         }
 
-        var dotFieldMatcher = dotFieldPattern.matcher(remainingPath);
-        var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath);
+        var dotFieldMatcher = DOT_FIELD_PATTERN.matcher(remainingPath);
+        var arrayWildcardMatcher = ARRAY_WILDCARD_PATTERN.matcher(remainingPath);
 
         if (dotFieldMatcher.matches()) {
             String field = dotFieldMatcher.group(1);
@@ -168,7 +215,12 @@ public class MapPathExtractor {
                     throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field));
                 }
 
-                return navigate(currentMap.get(field), nextPath);
+                // Handle the case where the path was $.result.text or $.result[*].key
+                if (Strings.isNullOrEmpty(nextPath)) {
+                    fieldNameInfo.addTraversedField(field);
+                }
+
+                return navigate(currentMap.get(field), nextPath, fieldNameInfo.descend(field));
             } else {
                 throw new IllegalArgumentException(
                     Strings.format(
@@ -182,10 +234,12 @@ public class MapPathExtractor {
         } else if (arrayWildcardMatcher.matches()) {
             String nextPath = arrayWildcardMatcher.group(1);
             if (current instanceof List<?> list) {
+                fieldNameInfo.addCurrentField();
+
                 List<Object> results = new ArrayList<>();
 
                 for (Object item : list) {
-                    Object result = navigate(item, nextPath);
+                    Object result = navigate(item, nextPath, fieldNameInfo);
                     if (result != null) {
                         results.add(result);
                     }

+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java

@@ -34,4 +34,16 @@ public class ErrorResponse {
     public boolean errorStructureFound() {
         return errorStructureFound;
     }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        ErrorResponse that = (ErrorResponse) o;
+        return errorStructureFound == that.errorStructureFound && Objects.equals(errorMessage, that.errorMessage);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(errorMessage, errorStructureFound);
+    }
 }

+ 19 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+public class CustomServiceSettings {
+    public static final String NAME = "custom_service_settings";
+    public static final String URL = "url";
+    public static final String HEADERS = "headers";
+    public static final String REQUEST = "request";
+    public static final String REQUEST_CONTENT = "content";
+    public static final String RESPONSE = "response";
+    public static final String JSON_PARSER = "json_parser";
+    public static final String ERROR_PARSER = "error_parser";
+}

+ 149 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java

@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.BiFunction;
+
+public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser {
+
+    @Override
+    public InferenceServiceResults parse(HttpResult response) throws IOException {
+        try (
+            XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
+                .createParser(XContentParserConfiguration.EMPTY, response.body())
+        ) {
+            var map = jsonParser.map();
+
+            return transform(map);
+        }
+    }
+
+    protected abstract T transform(Map<String, Object> extractedField);
+
+    static List<?> validateList(Object obj, String fieldName) {
+        validateNonNull(obj, fieldName);
+
+        if (obj instanceof List<?> == false) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Extracted field [%s] is an invalid type, expected a list but received [%s]",
+                    fieldName,
+                    obj.getClass().getSimpleName()
+                )
+            );
+        }
+
+        return (List<?>) obj;
+    }
+
+    static void validateNonNull(Object obj, String fieldName) {
+        Objects.requireNonNull(obj, Strings.format("Failed to parse field [%s], extracted field was null", fieldName));
+    }
+
+    static Map<String, Object> validateMap(Object obj, String fieldName) {
+        validateNonNull(obj, fieldName);
+
+        if (obj instanceof Map<?, ?> == false) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Extracted field [%s] is an invalid type, expected a map but received [%s]",
+                    fieldName,
+                    obj.getClass().getSimpleName()
+                )
+            );
+        }
+
+        var keys = ((Map<?, ?>) obj).keySet();
+        for (var key : keys) {
+            if (key instanceof String == false) {
+                throw new IllegalStateException(
+                    Strings.format(
+                        "Extracted field [%s] map has an invalid key type. Expected a string but received [%s]",
+                        fieldName,
+                        key.getClass().getSimpleName()
+                    )
+                );
+            }
+        }
+
+        @SuppressWarnings("unchecked")
+        var result = (Map<String, Object>) obj;
+        return result;
+    }
+
+    static List<Float> convertToListOfFloats(Object obj, String fieldName) {
+        return castList(validateList(obj, fieldName), BaseCustomResponseParser::toFloat, fieldName);
+    }
+
+    static Float toFloat(Object obj, String fieldName) {
+        return toNumber(obj, fieldName).floatValue();
+    }
+
+    private static Number toNumber(Object obj, String fieldName) {
+        if (obj instanceof Number == false) {
+            throw new IllegalArgumentException(
+                Strings.format("Unable to convert field [%s] of type [%s] to Number", fieldName, obj.getClass().getSimpleName())
+            );
+        }
+
+        return ((Number) obj);
+    }
+
+    static List<Integer> convertToListOfIntegers(Object obj, String fieldName) {
+        return castList(validateList(obj, fieldName), BaseCustomResponseParser::toInteger, fieldName);
+    }
+
+    private static Integer toInteger(Object obj, String fieldName) {
+        return toNumber(obj, fieldName).intValue();
+    }
+
+    static <T> List<T> castList(List<?> items, BiFunction<Object, String, T> converter, String fieldName) {
+        validateNonNull(items, fieldName);
+
+        List<T> resultList = new ArrayList<>();
+        for (int i = 0; i < items.size(); i++) {
+            try {
+                resultList.add(converter.apply(items.get(i), fieldName));
+            } catch (Exception e) {
+                throw new IllegalStateException(Strings.format("Failed to parse list entry [%d], error: %s", i, e.getMessage()), e);
+            }
+        }
+
+        return resultList;
+    }
+
+    static <T> T toType(Object obj, Class<T> type, String fieldName) {
+        validateNonNull(obj, fieldName);
+
+        if (type.isInstance(obj) == false) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Unable to convert field [%s] of type [%s] to [%s]",
+                    fieldName,
+                    obj.getClass().getSimpleName(),
+                    type.getSimpleName()
+                )
+            );
+        }
+
+        return type.cast(obj);
+    }
+}

+ 105 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.common.MapPathExtractor;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
+
+public class CompletionResponseParser extends BaseCustomResponseParser<ChatCompletionResults> {
+
+    public static final String NAME = "completion_response_parser";
+    public static final String COMPLETION_PARSER_RESULT = "completion_result";
+
+    private final String completionResultPath;
+
+    public static CompletionResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
+        var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new CompletionResponseParser(path);
+    }
+
+    public CompletionResponseParser(String completionResultPath) {
+        this.completionResultPath = Objects.requireNonNull(completionResultPath);
+    }
+
+    public CompletionResponseParser(StreamInput in) throws IOException {
+        this.completionResultPath = in.readString();
+    }
+
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(completionResultPath);
+    }
+
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(JSON_PARSER);
+        {
+            builder.field(COMPLETION_PARSER_RESULT, completionResultPath);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        CompletionResponseParser that = (CompletionResponseParser) o;
+        return Objects.equals(completionResultPath, that.completionResultPath);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(completionResultPath);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public ChatCompletionResults transform(Map<String, Object> map) {
+        var result = MapPathExtractor.extract(map, completionResultPath);
+        var extractedField = result.extractedObject();
+
+        validateNonNull(extractedField, completionResultPath);
+
+        if (extractedField instanceof List<?> extractedList) {
+            var completionList = castList(extractedList, (obj, fieldName) -> toType(obj, String.class, fieldName), completionResultPath);
+            return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList());
+        } else if (extractedField instanceof String extractedString) {
+            return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(extractedString)));
+        } else {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Extracted field [%s] from path [%s] is an invalid type, expected a list or a string but received [%s]",
+                    result.getArrayFieldName(0),
+                    completionResultPath,
+                    extractedField.getClass().getSimpleName()
+                )
+            );
+        }
+    }
+}

+ 19 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java

@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+
+public interface CustomResponseParser extends ToXContentFragment, NamedWriteable {
+    InferenceServiceResults parse(HttpResult response) throws IOException;
+}

+ 105 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.MapPathExtractor;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER;
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType;
+
+public class ErrorResponseParser implements ToXContentFragment, Function<HttpResult, ErrorResponse> {
+
+    public static final String MESSAGE_PATH = "path";
+
+    private final String messagePath;
+
+    public static ErrorResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
+        var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new ErrorResponseParser(path);
+    }
+
+    public ErrorResponseParser(String messagePath) {
+        this.messagePath = Objects.requireNonNull(messagePath);
+    }
+
+    public ErrorResponseParser(StreamInput in) throws IOException {
+        this.messagePath = in.readString();
+    }
+
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(messagePath);
+    }
+
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(ERROR_PARSER);
+        {
+            builder.field(MESSAGE_PATH, messagePath);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ErrorResponseParser that = (ErrorResponseParser) o;
+        return Objects.equals(messagePath, that.messagePath);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(messagePath);
+    }
+
+    @Override
+    public ErrorResponse apply(HttpResult httpResult) {
+        try (
+            XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
+                .createParser(XContentParserConfiguration.EMPTY, httpResult.body())
+        ) {
+            var map = jsonParser.map();
+
+            // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic
+            // if we find the top level error field we'll return a response with an empty message but indicate
+            // that we found the structure of the error object. Here if we're missing the final field we will return
+            // a ErrorResponse.UNDEFINED_ERROR which will indicate that we did not find the structure even if for example
+            // the outer error field does exist, but it doesn't contain the nested field we were looking for.
+            // If in the future we want the previous behavior, we can add a new message_path field or something and have
+            // the current path field point to the field that indicates whether we found an error object.
+            var errorText = toType(MapPathExtractor.extract(map, messagePath).extractedObject(), String.class, messagePath);
+            return new ErrorResponse(errorText);
+        } catch (Exception e) {
+            // swallow the error
+        }
+
+        return ErrorResponse.UNDEFINED_ERROR;
+    }
+}

+ 46 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/NoopResponseParser.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+
+public record NoopResponseParser() implements CustomResponseParser {
+
+    public static final String NAME = "noop_response_parser";
+    public static final NoopResponseParser INSTANCE = new NoopResponseParser();
+
+    public static NoopResponseParser fromMap() {
+        return new NoopResponseParser();
+    }
+
+    public NoopResponseParser(StreamInput in) {
+        this();
+    }
+
+    public void writeTo(StreamOutput out) throws IOException {}
+
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceServiceResults parse(HttpResult result) {
+        return null;
+    }
+}

+ 187 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java

@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.common.MapPathExtractor;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
+
+public class RerankResponseParser extends BaseCustomResponseParser<RankedDocsResults> {
+
+    public static final String NAME = "rerank_response_parser";
+    public static final String RERANK_PARSER_SCORE = "relevance_score";
+    public static final String RERANK_PARSER_INDEX = "reranked_index";
+    public static final String RERANK_PARSER_DOCUMENT_TEXT = "document_text";
+
+    private final String relevanceScorePath;
+    private final String rerankIndexPath;
+    private final String documentTextPath;
+
+    public static RerankResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
+
+        var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException);
+        var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException);
+        var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new RerankResponseParser(relevanceScore, rerankIndex, documentText);
+    }
+
+    public RerankResponseParser(String relevanceScorePath) {
+        this(relevanceScorePath, null, null);
+    }
+
+    public RerankResponseParser(String relevanceScorePath, @Nullable String rerankIndexPath, @Nullable String documentTextPath) {
+        this.relevanceScorePath = Objects.requireNonNull(relevanceScorePath);
+        this.rerankIndexPath = rerankIndexPath;
+        this.documentTextPath = documentTextPath;
+    }
+
+    public RerankResponseParser(StreamInput in) throws IOException {
+        this.relevanceScorePath = in.readString();
+        this.rerankIndexPath = in.readOptionalString();
+        this.documentTextPath = in.readOptionalString();
+    }
+
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(relevanceScorePath);
+        out.writeOptionalString(rerankIndexPath);
+        out.writeOptionalString(documentTextPath);
+    }
+
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(JSON_PARSER);
+        {
+            builder.field(RERANK_PARSER_SCORE, relevanceScorePath);
+            if (rerankIndexPath != null) {
+                builder.field(RERANK_PARSER_INDEX, rerankIndexPath);
+            }
+
+            if (documentTextPath != null) {
+                builder.field(RERANK_PARSER_DOCUMENT_TEXT, documentTextPath);
+            }
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        RerankResponseParser that = (RerankResponseParser) o;
+        return Objects.equals(relevanceScorePath, that.relevanceScorePath)
+            && Objects.equals(rerankIndexPath, that.rerankIndexPath)
+            && Objects.equals(documentTextPath, that.documentTextPath);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(relevanceScorePath, rerankIndexPath, documentTextPath);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public RankedDocsResults transform(Map<String, Object> map) {
+        var scores = extractScores(map);
+        var indices = extractIndices(map);
+        var documents = extractDocuments(map);
+
+        if (indices != null && indices.size() != scores.size()) {
+            throw new IllegalStateException(
+                Strings.format(
+                    "The number of index fields [%d] was not the same as the number of scores [%d]",
+                    indices.size(),
+                    scores.size()
+                )
+            );
+        }
+
+        if (documents != null && documents.size() != scores.size()) {
+            throw new IllegalStateException(
+                Strings.format(
+                    "The number of document fields [%d] was not the same as the number of scores [%d]",
+                    documents.size(),
+                    scores.size()
+                )
+            );
+        }
+
+        var rankedDocs = new ArrayList<RankedDocsResults.RankedDoc>();
+        for (int i = 0; i < scores.size(); i++) {
+            var index = indices != null ? indices.get(i) : i;
+            var score = scores.get(i);
+            var document = documents != null ? documents.get(i) : null;
+            rankedDocs.add(new RankedDocsResults.RankedDoc(index, score, document));
+        }
+
+        return new RankedDocsResults(rankedDocs);
+    }
+
+    private List<Float> extractScores(Map<String, Object> map) {
+        try {
+            var result = MapPathExtractor.extract(map, relevanceScorePath);
+            return convertToListOfFloats(result.extractedObject(), result.getArrayFieldName(0));
+        } catch (Exception e) {
+            throw new IllegalStateException(Strings.format("Failed to parse rerank scores, error: %s", e.getMessage()), e);
+        }
+    }
+
+    private List<Integer> extractIndices(Map<String, Object> map) {
+        if (rerankIndexPath != null) {
+            try {
+                var indexResult = MapPathExtractor.extract(map, rerankIndexPath);
+                return convertToListOfIntegers(indexResult.extractedObject(), indexResult.getArrayFieldName(0));
+            } catch (Exception e) {
+                throw new IllegalStateException(Strings.format("Failed to parse rerank indices, error: %s", e.getMessage()), e);
+            }
+        }
+
+        return null;
+    }
+
+    private List<String> extractDocuments(Map<String, Object> map) {
+        try {
+            if (documentTextPath != null) {
+                var documentResult = MapPathExtractor.extract(map, documentTextPath);
+                var documentFieldName = documentResult.getArrayFieldName(0);
+                return castList(
+                    validateList(documentResult.extractedObject(), documentFieldName),
+                    (obj, fieldName) -> toType(obj, String.class, fieldName),
+                    documentFieldName
+                );
+            }
+        } catch (Exception e) {
+            throw new IllegalStateException(Strings.format("Failed to parse rerank documents, error: %s", e.getMessage()), e);
+        }
+
+        return null;
+    }
+}

+ 165 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java

@@ -0,0 +1,165 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.common.MapPathExtractor;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
+
+public class SparseEmbeddingResponseParser extends BaseCustomResponseParser<SparseEmbeddingResults> {
+
+    public static final String NAME = "sparse_embedding_response_parser";
+    public static final String SPARSE_EMBEDDING_TOKEN_PATH = "token_path";
+    public static final String SPARSE_EMBEDDING_WEIGHT_PATH = "weight_path";
+
+    private final String tokenPath;
+    private final String weightPath;
+
+    public static SparseEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
+        var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, JSON_PARSER, validationException);
+
+        var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, JSON_PARSER, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new SparseEmbeddingResponseParser(tokenPath, weightPath);
+    }
+
+    public SparseEmbeddingResponseParser(String tokenPath, String weightPath) {
+        this.tokenPath = Objects.requireNonNull(tokenPath);
+        this.weightPath = Objects.requireNonNull(weightPath);
+    }
+
+    public SparseEmbeddingResponseParser(StreamInput in) throws IOException {
+        this.tokenPath = in.readString();
+        this.weightPath = in.readString();
+    }
+
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(tokenPath);
+        out.writeString(weightPath);
+    }
+
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(JSON_PARSER);
+        {
+            builder.field(SPARSE_EMBEDDING_TOKEN_PATH, tokenPath);
+            builder.field(SPARSE_EMBEDDING_WEIGHT_PATH, weightPath);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SparseEmbeddingResponseParser that = (SparseEmbeddingResponseParser) o;
+        return Objects.equals(tokenPath, that.tokenPath) && Objects.equals(weightPath, that.weightPath);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(tokenPath, weightPath);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    protected SparseEmbeddingResults transform(Map<String, Object> map) {
+        // These will be List<List<T>>
+        var tokenResult = MapPathExtractor.extract(map, tokenPath);
+        var tokens = validateList(tokenResult.extractedObject(), tokenResult.getArrayFieldName(0));
+
+        // These will be List<List<T>>
+        var weightResult = MapPathExtractor.extract(map, weightPath);
+        var weights = validateList(weightResult.extractedObject(), weightResult.getArrayFieldName(0));
+
+        validateListsSize(tokens, weights);
+
+        var tokenEntryFieldName = tokenResult.getArrayFieldName(1);
+        var weightEntryFieldName = weightResult.getArrayFieldName(1);
+        var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
+        for (int responseCounter = 0; responseCounter < tokens.size(); responseCounter++) {
+            try {
+                var tokenEntryList = validateList(tokens.get(responseCounter), tokenEntryFieldName);
+                var weightEntryList = validateList(weights.get(responseCounter), weightEntryFieldName);
+
+                validateListsSize(tokenEntryList, weightEntryList);
+
+                embeddings.add(createEmbedding(tokenEntryList, weightEntryList, weightEntryFieldName));
+            } catch (Exception e) {
+                throw new IllegalStateException(
+                    Strings.format("Failed to parse sparse embedding entry [%d], error: %s", responseCounter, e.getMessage()),
+                    e
+                );
+            }
+        }
+
+        return new SparseEmbeddingResults(Collections.unmodifiableList(embeddings));
+    }
+
+    private static void validateListsSize(List<?> tokens, List<?> weights) {
+        if (tokens.size() != weights.size()) {
+            throw new IllegalStateException(
+                Strings.format(
+                    "The extracted tokens list is size [%d] but the weights list is size [%d]. The list sizes must be equal.",
+                    tokens.size(),
+                    weights.size()
+                )
+            );
+        }
+    }
+
+    private static SparseEmbeddingResults.Embedding createEmbedding(
+        List<?> tokenEntryList,
+        List<?> weightEntryList,
+        String weightFieldName
+    ) {
+        var weightedTokens = new ArrayList<WeightedToken>();
+
+        for (int embeddingCounter = 0; embeddingCounter < tokenEntryList.size(); embeddingCounter++) {
+            var token = tokenEntryList.get(embeddingCounter);
+            var weight = weightEntryList.get(embeddingCounter);
+
+            // Alibaba can return a token id which is an integer and needs to be converted to a string
+            var tokenIdAsString = token.toString();
+            try {
+                var weightAsFloat = toFloat(weight, weightFieldName);
+                weightedTokens.add(new WeightedToken(tokenIdAsString, weightAsFloat));
+            } catch (IllegalArgumentException e) {
+                throw new IllegalArgumentException(
+                    Strings.format("Failed to parse weight item: [%d] of array, error: %s", embeddingCounter, e.getMessage()),
+                    e
+                );
+            }
+        }
+
+        return new SparseEmbeddingResults.Embedding(weightedTokens, false);
+    }
+}

+ 104 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java

@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.common.MapPathExtractor;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
+
+public class TextEmbeddingResponseParser extends BaseCustomResponseParser<TextEmbeddingFloatResults> {
+
+    public static final String NAME = "text_embedding_response_parser";
+    public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings";
+
+    private final String textEmbeddingsPath;
+
+    public static TextEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
+        var path = extractRequiredString(responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, JSON_PARSER, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new TextEmbeddingResponseParser(path);
+    }
+
+    public TextEmbeddingResponseParser(String textEmbeddingsPath) {
+        this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath);
+    }
+
+    public TextEmbeddingResponseParser(StreamInput in) throws IOException {
+        this.textEmbeddingsPath = in.readString();
+    }
+
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(textEmbeddingsPath);
+    }
+
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(JSON_PARSER);
+        {
+            builder.field(TEXT_EMBEDDING_PARSER_EMBEDDINGS, textEmbeddingsPath);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o;
+        return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(textEmbeddingsPath);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    protected TextEmbeddingFloatResults transform(Map<String, Object> map) {
+        var extractedResult = MapPathExtractor.extract(map, textEmbeddingsPath);
+        var mapResultsList = validateList(extractedResult.extractedObject(), extractedResult.getArrayFieldName(0));
+
+        var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>(mapResultsList.size());
+
+        for (int i = 0; i < mapResultsList.size(); i++) {
+            try {
+                var entry = mapResultsList.get(i);
+                var embeddingsAsListFloats = convertToListOfFloats(entry, extractedResult.getArrayFieldName(1));
+                embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats));
+            } catch (Exception e) {
+                throw new IllegalArgumentException(
+                    Strings.format("Failed to parse text embedding entry [%d], error: %s", i, e.getMessage()),
+                    e
+                );
+            }
+        }
+
+        return new TextEmbeddingFloatResults(embeddings);
+    }
+}

+ 51 - 13
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java

@@ -21,7 +21,15 @@ public class MapPathExtractorTests extends ESTestCase {
             Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
         );
 
-        assertThat(MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"), is(List.of(List.of(1, 2), List.of(3, 4))));
+        assertThat(
+            MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"),
+            is(
+                new MapPathExtractor.Result(
+                    List.of(List.of(1, 2), List.of(3, 4)),
+                    List.of("result.embeddings", "result.embeddings.embedding")
+                )
+            )
+        );
     }
 
     public void testExtract_IteratesListOfMapsToListOfStrings() {
@@ -32,7 +40,29 @@ public class MapPathExtractorTests extends ESTestCase {
 
         assertThat(
             MapPathExtractor.extract(input, "$.result[*].key[*]"),
-            is(List.of(List.of("value1", "value2"), List.of("value3", "value4")))
+            is(
+                new MapPathExtractor.Result(
+                    List.of(List.of("value1", "value2"), List.of("value3", "value4")),
+                    List.of("result", "result.key")
+                )
+            )
+        );
+    }
+
+    public void testExtract_IteratesListOfMapsToListOfStrings_WithoutFinalArraySyntax() {
+        Map<String, Object> input = Map.of(
+            "result",
+            List.of(Map.of("key", List.of("value1", "value2")), Map.of("key", List.of("value3", "value4")))
+        );
+
+        assertThat(
+            MapPathExtractor.extract(input, "$.result[*].key"),
+            is(
+                new MapPathExtractor.Result(
+                    List.of(List.of("value1", "value2"), List.of("value3", "value4")),
+                    List.of("result", "result.key")
+                )
+            )
         );
     }
 
@@ -45,7 +75,15 @@ public class MapPathExtractorTests extends ESTestCase {
             )
         );
 
-        assertThat(MapPathExtractor.extract(input, "$.result[*].key[*].a"), is(List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d))));
+        assertThat(
+            MapPathExtractor.extract(input, "$.result[*].key[*].a"),
+            is(
+                new MapPathExtractor.Result(
+                    List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d)),
+                    List.of("result", "result.key", "result.key.a")
+                )
+            )
+        );
     }
 
     public void testExtract_ReturnsNullForEmptyList() {
@@ -128,36 +166,36 @@ public class MapPathExtractorTests extends ESTestCase {
     public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty() {
         Map<String, Object> input = Map.of("result", List.of());
 
-        assertThat(MapPathExtractor.extract(input, "$.result"), is(List.of()));
+        assertThat(MapPathExtractor.extract(input, "$.result"), is(new MapPathExtractor.Result(List.of(), List.of("result"))));
     }
 
     public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty_PathIncludesArray() {
         Map<String, Object> input = Map.of("result", List.of());
 
-        assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(List.of()));
+        assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(new MapPathExtractor.Result(List.of(), List.of("result"))));
     }
 
     public void testDotFieldPattern() {
         {
-            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc.123");
+            var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc.123");
             assertTrue(matcher.matches());
             assertThat(matcher.group(1), is("abc"));
             assertThat(matcher.group(2), is(".123"));
         }
         {
-            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[*].123");
+            var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc[*].123");
             assertTrue(matcher.matches());
             assertThat(matcher.group(1), is("abc"));
             assertThat(matcher.group(2), is("[*].123"));
         }
         {
-            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[.123");
+            var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc[.123");
             assertTrue(matcher.matches());
             assertThat(matcher.group(1), is("abc"));
             assertThat(matcher.group(2), is("[.123"));
         }
         {
-            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc");
+            var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc");
             assertTrue(matcher.matches());
             assertThat(matcher.group(1), is("abc"));
             assertThat(matcher.group(2), is(""));
@@ -166,21 +204,21 @@ public class MapPathExtractorTests extends ESTestCase {
 
     public void testArrayWildcardPattern() {
         {
-            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*].abc.123");
+            var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[*].abc.123");
             assertTrue(matcher.matches());
             assertThat(matcher.group(1), is(".abc.123"));
         }
         {
-            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*]");
+            var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[*]");
             assertTrue(matcher.matches());
             assertThat(matcher.group(1), is(""));
         }
         {
-            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[1].abc");
+            var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[1].abc");
             assertFalse(matcher.matches());
         }
         {
-            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[].abc");
+            var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[].abc");
             assertFalse(matcher.matches());
         }
     }

+ 111 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParserTests.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.castList;
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.convertToListOfFloats;
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toFloat;
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType;
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.validateList;
+import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.validateMap;
+import static org.hamcrest.Matchers.is;
+
+public class BaseCustomResponseParserTests extends ESTestCase {
+    public void testValidateNonNull_ThrowsException_WhenPassedNull() {
+        var exception = expectThrows(NullPointerException.class, () -> BaseCustomResponseParser.validateNonNull(null, "field"));
+        assertThat(exception.getMessage(), is("Failed to parse field [field], extracted field was null"));
+    }
+
+    public void testValidateList_ThrowsException_WhenPassedAnObjectThatIsNotAList() {
+        var exception = expectThrows(IllegalArgumentException.class, () -> validateList(new Object(), "field"));
+        assertThat(exception.getMessage(), is("Extracted field [field] is an invalid type, expected a list but received [Object]"));
+    }
+
+    public void testValidateList_ReturnsList() {
+        Object obj = List.of("abc", "123");
+        assertThat(validateList(obj, "field"), is(List.of("abc", "123")));
+    }
+
+    public void testConvertToListOfFloats_ThrowsException_WhenAnItemInTheListIsNotANumber() {
+        var list = List.of(1, "hello");
+
+        var exception = expectThrows(IllegalStateException.class, () -> convertToListOfFloats(list, "field"));
+        assertThat(
+            exception.getMessage(),
+            is("Failed to parse list entry [1], error: Unable to convert field [field] of type [String] to Number")
+        );
+    }
+
+    public void testConvertToListOfFloats_ReturnsList() {
+        var list = List.of(1, 1.1f, -2.0d, new AtomicInteger(1));
+
+        assertThat(convertToListOfFloats(list, "field"), is(List.of(1f, 1.1f, -2f, 1f)));
+    }
+
+    public void testCastList() {
+        var list = List.of("abc", "123", 1, 2.2d);
+
+        assertThat(castList(list, (obj, fieldName) -> obj.toString(), "field"), is(List.of("abc", "123", "1", "2.2")));
+    }
+
+    public void testCastList_ThrowsException() {
+        var list = List.of("abc");
+
+        var exception = expectThrows(IllegalStateException.class, () -> castList(list, (obj, fieldName) -> {
+            throw new IllegalArgumentException("failed");
+        }, "field"));
+
+        assertThat(exception.getMessage(), is("Failed to parse list entry [0], error: failed"));
+    }
+
+    public void testValidateMap() {
+        assertThat(validateMap(Map.of("abc", 123), "field"), is(Map.of("abc", 123)));
+    }
+
+    public void testValidateMap_ThrowsException_WhenObjectIsNotAMap() {
+        var exception = expectThrows(IllegalArgumentException.class, () -> validateMap("hello", "field"));
+        assertThat(exception.getMessage(), is("Extracted field [field] is an invalid type, expected a map but received [String]"));
+    }
+
+    public void testValidateMap_ThrowsException_WhenKeysAreNotStrings() {
+        var exception = expectThrows(IllegalStateException.class, () -> validateMap(Map.of("key", "value", 1, "abc"), "field"));
+        assertThat(
+            exception.getMessage(),
+            is("Extracted field [field] map has an invalid key type. Expected a string but received [Integer]")
+        );
+    }
+
+    public void testToFloat() {
+        assertThat(toFloat(1, "field"), is(1f));
+    }
+
+    public void testToFloat_AtomicLong() {
+        assertThat(toFloat(new AtomicLong(100), "field"), is(100f));
+    }
+
+    public void testToFloat_Double() {
+        assertThat(toFloat(1.123d, "field"), is(1.123f));
+    }
+
+    public void testToType() {
+        Object obj = "hello";
+        assertThat(toType(obj, String.class, "field"), is("hello"));
+    }
+
+    public void testToType_List() {
+        Object obj = List.of(123, 456);
+        assertThat(toType(obj, List.class, "field"), is(List.of(123, 456)));
+    }
+}

+ 304 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java

@@ -0,0 +1,304 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser.COMPLETION_PARSER_RESULT;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class CompletionResponseParserTests extends AbstractBWCWireSerializationTestCase<CompletionResponseParser> {
+
+    public static CompletionResponseParser createRandom() {
+        return new CompletionResponseParser("$." + randomAlphaOfLength(5));
+    }
+
+    public void testFromMap() {
+        var validation = new ValidationException();
+        var parser = CompletionResponseParser.fromMap(new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), validation);
+
+        assertThat(parser, is(new CompletionResponseParser("$.result[*].text")));
+    }
+
+    public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), validation)
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is("Validation Failed: 1: [json_parser] does not contain the required setting [completion_result];")
+        );
+    }
+
+    public void testToXContent() throws IOException {
+        var entity = new CompletionResponseParser("$.result[*].text");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            builder.startObject();
+            entity.toXContent(builder, null);
+            builder.endObject();
+        }
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "json_parser": {
+                    "completion_result": "$.result[*].text"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testParse() throws IOException {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "result": [
+                  {
+                    "text":"completion results"
+                  }
+              ],
+              "usage": {
+                  "output_tokens": 6320,
+                  "input_tokens": 35,
+                  "total_tokens": 6355
+              }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.result[*].text");
+        ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(parsedResults, is(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("completion results")))));
+    }
+
+    public void testParse_String() throws IOException {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "result": {
+                "text":"completion results"
+              },
+              "usage": {
+                  "output_tokens": 6320,
+                  "input_tokens": 35,
+                  "total_tokens": 6355
+              }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.result.text");
+        ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(parsedResults, is(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("completion results")))));
+    }
+
+    public void testParse_MultipleResults() throws IOException {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "result": [
+                  {
+                    "text":"completion results"
+                  },
+                  {
+                    "text":"completion results2"
+                  }
+              ],
+              "usage": {
+                  "output_tokens": 6320,
+                  "input_tokens": 35,
+                  "total_tokens": 6355
+              }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.result[*].text");
+        ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(
+                new ChatCompletionResults(
+                    List.of(new ChatCompletionResults.Result("completion results"), new ChatCompletionResults.Result("completion results2"))
+                )
+            )
+        );
+    }
+
+    public void testParse_AnthropicFormat() throws IOException {
+        String responseJson = """
+            {
+                "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb",
+                "type": "message",
+                "role": "assistant",
+                "model": "claude-3-opus-20240229",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "result"
+                    },
+                    {
+                        "type": "text",
+                        "text": "result2"
+                    }
+                ],
+                "stop_reason": "end_turn",
+                "stop_sequence": null,
+                "usage": {
+                    "input_tokens": 16,
+                    "output_tokens": 326
+                }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.content[*].text");
+        ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("result"), new ChatCompletionResults.Result("result2"))))
+        );
+    }
+
+    public void testParse_ThrowsException_WhenExtractedField_IsNotAList() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "result": "invalid_field",
+              "usage": {
+                  "output_tokens": 6320,
+                  "input_tokens": 35,
+                  "total_tokens": 6355
+              }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.result[*].text");
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Current path [[*].text] matched the array field pattern "
+                    + "but the current object is not a list, found invalid type [String] instead."
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenExtractedField_IsNotListOfStrings() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "result": ["string", true],
+              "usage": {
+                  "output_tokens": 6320,
+                  "input_tokens": 35,
+                  "total_tokens": 6355
+              }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.result");
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is("Failed to parse list entry [1], error: Unable to convert field [$.result] of type [Boolean] to [String]")
+        );
+    }
+
+    public void testParse_ThrowsException_WhenExtractedField_IsNotAListOrString() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "result": 123,
+              "usage": {
+                  "output_tokens": 6320,
+                  "input_tokens": 35,
+                  "total_tokens": 6355
+              }
+            }
+            """;
+
+        var parser = new CompletionResponseParser("$.result");
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is("Extracted field [result] from path [$.result] is an invalid type, expected a list or a string but received [Integer]")
+        );
+    }
+
+    @Override
+    protected CompletionResponseParser mutateInstanceForVersion(CompletionResponseParser instance, TransportVersion version) {
+        return instance;
+    }
+
+    @Override
+    protected Writeable.Reader<CompletionResponseParser> instanceReader() {
+        return CompletionResponseParser::new;
+    }
+
+    @Override
+    protected CompletionResponseParser createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected CompletionResponseParser mutateInstance(CompletionResponseParser instance) throws IOException {
+        return randomValueOtherThan(instance, CompletionResponseParserTests::createRandom);
+    }
+}

+ 145 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java

@@ -0,0 +1,145 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser.MESSAGE_PATH;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
+import static org.mockito.Mockito.mock;
+
+public class ErrorResponseParserTests extends ESTestCase {
+
+    public static ErrorResponseParser createRandom() {
+        return new ErrorResponseParser("$." + randomAlphaOfLength(5));
+    }
+
+    public void testFromMap() {
+        var validation = new ValidationException();
+        var parser = ErrorResponseParser.fromMap(new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), validation);
+
+        assertThat(parser, is(new ErrorResponseParser("$.error.message")));
+    }
+
+    public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), validation)
+        );
+
+        assertThat(exception.getMessage(), is("Validation Failed: 1: [error_parser] does not contain the required setting [path];"));
+    }
+
+    public void testToXContent() throws IOException {
+        var entity = new ErrorResponseParser("$.error.message");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            builder.startObject();
+            entity.toXContent(builder, null);
+            builder.endObject();
+        }
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "error_parser": {
+                    "path": "$.error.message"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testErrorResponse_ExtractsError() throws IOException {
+        var result = getMockResult("""
+            {
+                "error": {
+                    "message": "test_error_message"
+                }
+            }""");
+
+        var parser = new ErrorResponseParser("$.error.message");
+        var error = parser.apply(result);
+        assertThat(error, is(new ErrorResponse("test_error_message")));
+    }
+
+    public void testFromResponse_WithOtherFieldsPresent() throws IOException {
+        String responseJson = """
+            {
+                "error": {
+                    "message": "You didn't provide an API key",
+                    "type": "invalid_request_error",
+                    "param": null,
+                    "code": null
+                }
+            }
+            """;
+
+        var parser = new ErrorResponseParser("$.error.message");
+        var error = parser.apply(getMockResult(responseJson));
+
+        assertThat(error, is(new ErrorResponse("You didn't provide an API key")));
+    }
+
+    public void testFromResponse_noMessage() throws IOException {
+        String responseJson = """
+            {
+              "error": {
+                "type": "not_found_error"
+              }
+            }
+            """;
+
+        var parser = new ErrorResponseParser("$.error.message");
+        var error = parser.apply(getMockResult(responseJson));
+
+        assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR));
+        assertThat(error.getErrorMessage(), is(""));
+        assertFalse(error.errorStructureFound());
+    }
+
+    public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOException {
+        var mockResult = getMockResult("""
+            {"noerror":true}""");
+
+        var parser = new ErrorResponseParser("$.error.message");
+        var error = parser.apply(mockResult);
+
+        assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR));
+    }
+
+    public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() {
+        var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string"));
+
+        var parser = new ErrorResponseParser("$.error.message");
+        var error = parser.apply(result);
+        assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR));
+    }
+
+    private static HttpResult getMockResult(String jsonString) throws IOException {
+        var response = mock(HttpResponse.class);
+        return new HttpResult(response, Strings.toUTF8Bytes(XContentHelper.stripWhitespace(jsonString)));
+    }
+}

+ 456 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java

@@ -0,0 +1,456 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT;
+import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_INDEX;
+import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class RerankResponseParserTests extends AbstractBWCWireSerializationTestCase<RerankResponseParser> {
+
+    public static RerankResponseParser createRandom() {
+        var indexPath = randomBoolean() ? "$." + randomAlphaOfLength(5) : null;
+        var documentTextPath = randomBoolean() ? "$." + randomAlphaOfLength(5) : null;
+        return new RerankResponseParser("$." + randomAlphaOfLength(5), indexPath, documentTextPath);
+    }
+
+    public void testFromMap() {
+        var validation = new ValidationException();
+        var parser = RerankResponseParser.fromMap(
+            new HashMap<>(
+                Map.of(
+                    RERANK_PARSER_SCORE,
+                    "$.result.scores[*].score",
+                    RERANK_PARSER_INDEX,
+                    "$.result.scores[*].index",
+                    RERANK_PARSER_DOCUMENT_TEXT,
+                    "$.result.scores[*].document_text"
+                )
+            ),
+            validation
+        );
+
+        assertThat(
+            parser,
+            is(new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", "$.result.scores[*].document_text"))
+        );
+    }
+
+    public void testFromMap_WithoutOptionalFields() {
+        var validation = new ValidationException();
+        var parser = RerankResponseParser.fromMap(new HashMap<>(Map.of(RERANK_PARSER_SCORE, "$.result.scores[*].score")), validation);
+
+        assertThat(parser, is(new RerankResponseParser("$.result.scores[*].score", null, null)));
+    }
+
+    public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> RerankResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation)
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is("Validation Failed: 1: [json_parser] does not contain the required setting [relevance_score];")
+        );
+    }
+
+    public void testToXContent() throws IOException {
+        var entity = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", "$.result.scores[*].document_text");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            builder.startObject();
+            entity.toXContent(builder, null);
+            builder.endObject();
+        }
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "json_parser": {
+                    "relevance_score": "$.result.scores[*].score",
+                    "reranked_index": "$.result.scores[*].index",
+                    "document_text": "$.result.scores[*].document_text"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testToXContent_WithoutOptionalFields() throws IOException {
+        var entity = new RerankResponseParser("$.result.scores[*].score");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            builder.startObject();
+            entity.toXContent(builder, null);
+            builder.endObject();
+        }
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "json_parser": {
+                    "relevance_score": "$.result.scores[*].score"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testParse() throws IOException {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+               "scores":[
+                 {
+                   "index":1,
+                   "score": 1.37
+                 },
+                 {
+                   "index":0,
+                   "score": -0.3
+                 }
+               ]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null);
+        RankedDocsResults parsedResults = (RankedDocsResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(
+                new RankedDocsResults(
+                    List.of(new RankedDocsResults.RankedDoc(1, 1.37f, null), new RankedDocsResults.RankedDoc(0, -0.3f, null))
+                )
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenIndex_IsInvalid() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+               "scores":[
+                 {
+                   "index":"abc",
+                   "score": 1.37
+                 },
+                 {
+                   "index":0,
+                   "score": -0.3
+                 }
+               ]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null);
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse rerank indices, error: Failed to parse list entry [0], "
+                    + "error: Unable to convert field [result.scores] of type [String] to Number"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenScore_IsInvalid() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+               "scores":[
+                 {
+                   "index":1,
+                   "score": true
+                 },
+                 {
+                   "index":0,
+                   "score": -0.3
+                 }
+               ]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null);
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse rerank scores, error: Failed to parse list entry [0], "
+                    + "error: Unable to convert field [result.scores] of type [Boolean] to Number"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenDocument_IsInvalid() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+               "scores":[
+                 {
+                   "index":1,
+                   "score": 0.2,
+                   "document": 1
+                 },
+                 {
+                   "index":0,
+                   "score": -0.3,
+                   "document": "a document"
+                 }
+               ]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", "$.result.scores[*].document");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse rerank documents, error: Failed to parse list entry [0], error: "
+                    + "Unable to convert field [result.scores] of type [Integer] to [String]"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenIndices_ListSizeDoesNotMatchScores() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+                "indices": [1],
+                "scores": [0.2, 0.3],
+                "documents": ["a", "b"]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores", "$.result.indices", "$.result.documents");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(exception.getMessage(), is("The number of index fields [1] was not the same as the number of scores [2]"));
+    }
+
+    public void testParse_ThrowsException_WhenDocuments_ListSizeDoesNotMatchScores() {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+                "indices": [1, 0],
+                "scores": [0.2, 0.3],
+                "documents": ["a"]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores", "$.result.indices", "$.result.documents");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(exception.getMessage(), is("The number of document fields [1] was not the same as the number of scores [2]"));
+    }
+
+    public void testParse_WithoutIndex() throws IOException {
+        String responseJson = """
+            {
+              "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+              "latency": 564.903929,
+              "usage": {
+                "doc_count": 2
+              },
+              "result": {
+               "scores":[
+                 {
+                   "score": 1.37
+                 },
+                 {
+                   "score": -0.3
+                 }
+               ]
+              }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.result.scores[*].score", null, null);
+        RankedDocsResults parsedResults = (RankedDocsResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(
+                new RankedDocsResults(
+                    List.of(new RankedDocsResults.RankedDoc(0, 1.37f, null), new RankedDocsResults.RankedDoc(1, -0.3f, null))
+                )
+            )
+        );
+    }
+
+    public void testParse_CohereResponseFormat() throws IOException {
+        String responseJson = """
+            {
+                "index": "44873262-1315-4c06-8433-fdc90c9790d0",
+                "results": [
+                    {
+                        "document": {
+                            "text": "Washington, D.C.."
+                        },
+                        "index": 2,
+                        "relevance_score": 0.98005307
+                    },
+                    {
+                        "document": {
+                            "text": "Capital punishment has existed in the United States since beforethe United States was a country. "
+                        },
+                        "index": 3,
+                        "relevance_score": 0.27904198
+                    },
+                    {
+                        "document": {
+                            "text": "Carson City is the capital city of the American state of Nevada."
+                        },
+                        "index": 0,
+                        "relevance_score": 0.10194652
+                    }
+                ],
+                "meta": {
+                    "api_version": {
+                        "version": "1"
+                    },
+                    "billed_units": {
+                        "search_units": 1
+                    }
+                }
+            }
+            """;
+
+        var parser = new RerankResponseParser("$.results[*].relevance_score", "$.results[*].index", "$.results[*].document.text");
+        RankedDocsResults parsedResults = (RankedDocsResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(
+                new RankedDocsResults(
+                    List.of(
+                        new RankedDocsResults.RankedDoc(2, 0.98005307f, "Washington, D.C.."),
+                        new RankedDocsResults.RankedDoc(
+                            3,
+                            0.27904198f,
+                            "Capital punishment has existed in the United States since beforethe United States was a country. "
+                        ),
+                        new RankedDocsResults.RankedDoc(0, 0.10194652f, "Carson City is the capital city of the American state of Nevada.")
+                    )
+                )
+            )
+        );
+    }
+
+    @Override
+    protected RerankResponseParser mutateInstanceForVersion(RerankResponseParser instance, TransportVersion version) {
+        return instance;
+    }
+
+    @Override
+    protected Writeable.Reader<RerankResponseParser> instanceReader() {
+        return RerankResponseParser::new;
+    }
+
+    @Override
+    protected RerankResponseParser createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected RerankResponseParser mutateInstance(RerankResponseParser instance) throws IOException {
+        return randomValueOtherThan(instance, RerankResponseParserTests::createRandom);
+    }
+}

+ 349 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java

@@ -0,0 +1,349 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH;
+import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class SparseEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase<SparseEmbeddingResponseParser> {
+
+    public static SparseEmbeddingResponseParser createRandom() {
+        return new SparseEmbeddingResponseParser(randomAlphaOfLength(5), randomAlphaOfLength(5));
+    }
+
+    public void testFromMap() {
+        var validation = new ValidationException();
+        var parser = SparseEmbeddingResponseParser.fromMap(
+            new HashMap<>(
+                Map.of(
+                    SPARSE_EMBEDDING_TOKEN_PATH,
+                    "$.result[*].embeddings[*].token",
+                    SPARSE_EMBEDDING_WEIGHT_PATH,
+                    "$.result[*].embeddings[*].weight"
+                )
+            ),
+            validation
+        );
+
+        assertThat(parser, is(new SparseEmbeddingResponseParser("$.result[*].embeddings[*].token", "$.result[*].embeddings[*].weight")));
+    }
+
+    public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> SparseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation)
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: [json_parser] does not contain the required setting [token_path];"
+                    + "2: [json_parser] does not contain the required setting [weight_path];"
+            )
+        );
+    }
+
+    public void testToXContent() throws IOException {
+        var entity = new SparseEmbeddingResponseParser("$.result.path.token", "$.result.path.weight");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            builder.startObject();
+            entity.toXContent(builder, null);
+            builder.endObject();
+        }
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "json_parser": {
+                    "token_path": "$.result.path.token",
+                    "weight_path": "$.result.path.weight"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testParse() throws IOException {
+        String responseJson = """
+            {
+                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+                "latency": 22,
+                "usage": {
+                    "token_count": 11
+                },
+                "result": {
+                    "sparse_embeddings": [
+                        {
+                            "index": 0,
+                            "embedding": [
+                                {
+                                    "tokenId": 6,
+                                    "weight": 0.101
+                                },
+                                {
+                                    "tokenId": 163040,
+                                    "weight": 0.28417
+                                }
+                            ]
+                        }
+                    ]
+                }
+            }
+            """;
+
+        var parser = new SparseEmbeddingResponseParser(
+            "$.result.sparse_embeddings[*].embedding[*].tokenId",
+            "$.result.sparse_embeddings[*].embedding[*].weight"
+        );
+        SparseEmbeddingResults parsedResults = (SparseEmbeddingResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(
+                new SparseEmbeddingResults(
+                    List.of(
+                        new SparseEmbeddingResults.Embedding(
+                            List.of(new WeightedToken("6", 0.101f), new WeightedToken("163040", 0.28417f)),
+                            false
+                        )
+                    )
+                )
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenTheTokenField_IsNotAnArray() {
+        String responseJson = """
+            {
+                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+                "latency": 22,
+                "usage": {
+                    "token_count": 11
+                },
+                "result": {
+                    "sparse_embeddings": [
+                        {
+                            "index": 0,
+                            "tokenId": 6,
+                            "weight": [0.101]
+                        }
+                    ]
+                }
+            }
+            """;
+
+        var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse sparse embedding entry [0], error: Extracted field [result.sparse_embeddings.tokenId] "
+                    + "is an invalid type, expected a list but received [Integer]"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenTheTokenArraySize_AndWeightArraySize_AreDifferent() {
+        String responseJson = """
+            {
+                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+                "latency": 22,
+                "usage": {
+                    "token_count": 11
+                },
+                "result": {
+                    "sparse_embeddings": [
+                        {
+                            "index": 0,
+                            "tokenId": [6, 7],
+                            "weight": [0.101]
+                        }
+                    ]
+                }
+            }
+            """;
+
+        var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse sparse embedding entry [0], error: The extracted tokens list is size [2] "
+                    + "but the weights list is size [1]. The list sizes must be equal."
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenTheWeightValue_IsNotAFloat() {
+        String responseJson = """
+            {
+                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+                "latency": 22,
+                "usage": {
+                    "token_count": 11
+                },
+                "result": {
+                    "sparse_embeddings": [
+                        {
+                            "index": 0,
+                            "tokenId": [6],
+                            "weight": [true]
+                        }
+                    ]
+                }
+            }
+            """;
+
+        var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse sparse embedding entry [0], error: Failed to parse weight item: "
+                    + "[0] of array, error: Unable to convert field [result.sparse_embeddings.weight] of type [Boolean] to Number"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenTheWeightField_IsNotAnArray() {
+        String responseJson = """
+            {
+                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+                "latency": 22,
+                "usage": {
+                    "token_count": 11
+                },
+                "result": {
+                    "sparse_embeddings": [
+                        {
+                            "index": 0,
+                            "tokenId": [6],
+                            "weight": 0.101
+                        }
+                    ]
+                }
+            }
+            """;
+
+        var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight");
+
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse sparse embedding entry [0], error: Extracted field [result.sparse_embeddings.weight] "
+                    + "is an invalid type, expected a list but received [Double]"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenExtractedField_IsNotFormattedCorrectly() {
+        String responseJson = """
+            {
+                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+                "latency": 22,
+                "usage": {
+                    "token_count": 11
+                },
+                "result": {
+                    "sparse_embeddings": [
+                        {
+                            "index": 0,
+                            "embedding": [
+                                {
+                                    "6": 0.101
+                                },
+                                {
+                                    "163040": 0.28417
+                                }
+                            ]
+                        }
+                    ]
+                }
+            }
+            """;
+
+        var parser = new SparseEmbeddingResponseParser(
+            "$.result.sparse_embeddings[*].embedding[*].tokenId",
+            "$.result.sparse_embeddings[*].embedding[*].weight"
+        );
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(exception.getMessage(), is("Unable to find field [tokenId] in map"));
+    }
+
+    @Override
+    protected SparseEmbeddingResponseParser mutateInstanceForVersion(SparseEmbeddingResponseParser instance, TransportVersion version) {
+        return instance;
+    }
+
+    @Override
+    protected Writeable.Reader<SparseEmbeddingResponseParser> instanceReader() {
+        return SparseEmbeddingResponseParser::new;
+    }
+
+    @Override
+    protected SparseEmbeddingResponseParser createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected SparseEmbeddingResponseParser mutateInstance(SparseEmbeddingResponseParser instance) throws IOException {
+        return randomValueOtherThan(instance, SparseEmbeddingResponseParserTests::createRandom);
+    }
+}

+ 263 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java

@@ -0,0 +1,263 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase<TextEmbeddingResponseParser> {
+
+    public static TextEmbeddingResponseParser createRandom() {
+        return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5));
+    }
+
+    public void testFromMap() {
+        var validation = new ValidationException();
+        var parser = TextEmbeddingResponseParser.fromMap(
+            new HashMap<>(Map.of(TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result[*].embeddings")),
+            validation
+        );
+
+        assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings")));
+    }
+
+    public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), validation)
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is("Validation Failed: 1: [json_parser] does not contain " + "the required setting [text_embeddings];")
+        );
+    }
+
+    public void testToXContent() throws IOException {
+        var entity = new TextEmbeddingResponseParser("$.result.path");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            builder.startObject();
+            entity.toXContent(builder, null);
+            builder.endObject();
+        }
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "json_parser": {
+                    "text_embeddings": "$.result.path"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testParse() throws IOException {
+        String responseJson = """
+            {
+              "object": "list",
+              "data": [
+                  {
+                      "object": "embedding",
+                      "index": 0,
+                      "embedding": [
+                          0.014539449,
+                          -0.015288644
+                      ]
+                  }
+              ],
+              "model": "text-embedding-ada-002-v2",
+              "usage": {
+                  "prompt_tokens": 8,
+                  "total_tokens": 8
+              }
+            }
+            """;
+
+        var parser = new TextEmbeddingResponseParser("$.data[*].embedding");
+        TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))))
+        );
+    }
+
+    public void testParse_MultipleEmbeddings() throws IOException {
+        String responseJson = """
+            {
+              "object": "list",
+              "data": [
+                  {
+                      "object": "embedding",
+                      "index": 0,
+                      "embedding": [
+                          0.014539449,
+                          -0.015288644
+                      ]
+                  },
+                  {
+                      "object": "embedding",
+                      "index": 1,
+                      "embedding": [
+                          1,
+                          -2
+                      ]
+                  }
+              ],
+              "model": "text-embedding-ada-002-v2",
+              "usage": {
+                  "prompt_tokens": 8,
+                  "total_tokens": 8
+              }
+            }
+            """;
+
+        var parser = new TextEmbeddingResponseParser("$.data[*].embedding");
+        TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults,
+            is(
+                new TextEmbeddingFloatResults(
+                    List.of(
+                        new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+                        new TextEmbeddingFloatResults.Embedding(new float[] { 1F, -2F })
+                    )
+                )
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() {
+        String responseJson = """
+            {
+              "object": "list",
+              "data": [
+                  {
+                      "object": "embedding",
+                      "index": 0,
+                      "embedding": [
+                          1,
+                          -0.015288644
+                      ]
+                  },
+                  {
+                      "object": "embedding",
+                      "index": 0,
+                      "embedding": [
+                          true,
+                          -0.015288644
+                      ]
+                  }
+              ],
+              "model": "text-embedding-ada-002-v2",
+              "usage": {
+                  "prompt_tokens": 8,
+                  "total_tokens": 8
+              }
+            }
+            """;
+
+        var parser = new TextEmbeddingResponseParser("$.data[*].embedding");
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse text embedding entry [1], error: Failed to parse list entry [0], error:"
+                    + " Unable to convert field [data.embedding] of type [Boolean] to Number"
+            )
+        );
+    }
+
+    public void testParse_ThrowsException_WhenExtractedField_IsNotAList() {
+        String responseJson = """
+            {
+              "object": "list",
+              "data": [
+                  {
+                      "object": "embedding",
+                      "index": 0,
+                      "embedding": 1
+                  }
+              ],
+              "model": "text-embedding-ada-002-v2",
+              "usage": {
+                  "prompt_tokens": 8,
+                  "total_tokens": 8
+              }
+            }
+            """;
+
+        var parser = new TextEmbeddingResponseParser("$.data[*].embedding");
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Failed to parse text embedding entry [0], error: Extracted field [data.embedding] "
+                    + "is an invalid type, expected a list but received [Integer]"
+            )
+        );
+    }
+
+    @Override
+    protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) {
+        return instance;
+    }
+
+    @Override
+    protected Writeable.Reader<TextEmbeddingResponseParser> instanceReader() {
+        return TextEmbeddingResponseParser::new;
+    }
+
+    @Override
+    protected TextEmbeddingResponseParser createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected TextEmbeddingResponseParser mutateInstance(TextEmbeddingResponseParser instance) throws IOException {
+        return randomValueOtherThan(instance, TextEmbeddingResponseParserTests::createRandom);
+    }
+}