Browse Source

[ML] add new text_embedding NLP task for model inference (#78025)

Adds new text_embedding NLP task for model inference.

Initial support assumes that the output layer of the model is the embedding. Consequently, before the pytorch script is uploaded, it is best that the pooling layer IS the embedding.

The output format is a single-dimension array of doubles. This is because the dense_vector mapped field does not support multiple values (yet), and it is natural for text_embedding outputs to be used in a dense_vector field.
Benjamin Trent 4 years ago
parent
commit
809d6b985e

+ 11 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
@@ -48,6 +49,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassification
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
@@ -179,6 +181,10 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             new ParseField(PassThroughConfig.NAME), PassThroughConfig::fromXContentLenient));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(PassThroughConfig.NAME),
             PassThroughConfig::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
+            new ParseField(TextEmbeddingConfig.NAME), TextEmbeddingConfig::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(TextEmbeddingConfig.NAME),
+            TextEmbeddingConfig::fromXContentStrict));
 
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
             ClassificationConfigUpdate::fromXContentStrict));
@@ -271,6 +277,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
             TextClassificationResults.NAME,
             TextClassificationResults::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
+            TextEmbeddingResults.NAME,
+            TextEmbeddingResults::new));
 
         // Inference Configs
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
@@ -285,6 +294,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             TextClassificationConfig.NAME, TextClassificationConfig::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
             PassThroughConfig.NAME, PassThroughConfig::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
+            TextEmbeddingConfig.NAME, TextEmbeddingConfig::new));
 
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));

+ 79 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java

@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+
+public class TextEmbeddingResults implements InferenceResults { // inference result holding one embedding vector as a double[]
+
+    public static final String NAME = "text_embedding_result"; // value returned by getWriteableName()
+    static final String DEFAULT_RESULTS_FIELD = "results"; // map key used by asMap()
+
+    private static final ParseField INFERENCE = new ParseField("inference"); // field name used by toXContent()
+
+    private final double[] inference;
+    // Stores the caller's array reference as-is; no defensive copy is made.
+    public TextEmbeddingResults(double[] inference) {
+        this.inference = inference;
+    }
+    // Wire constructor: reads the double array that writeTo() emits.
+    public TextEmbeddingResults(StreamInput in) throws IOException {
+        inference = in.readDoubleArray();
+    }
+    // Returns the internal array itself, not a copy.
+    public double[] getInference() {
+        return inference;
+    }
+    // Writes only the "inference" field; the caller is responsible for opening and closing the enclosing object.
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(INFERENCE.getPreferredName(), inference);
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeDoubleArray(inference);
+    }
+    // Single-entry immutable map keyed by "results"; note toXContent() uses "inference" for the same array.
+    @Override
+    public Map<String, Object> asMap() {
+        return Collections.singletonMap(DEFAULT_RESULTS_FIELD, inference);
+    }
+    // Always throws: this result type has no single predicted value.
+    @Override
+    public Object predictedValue() {
+        throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
+    }
+    // Equality is element-wise over the array (Arrays.equals), matching hashCode() below.
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TextEmbeddingResults that = (TextEmbeddingResults) o;
+        return Arrays.equals(inference, that.inference);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(inference);
+    }
+}

+ 141 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfig.java

@@ -0,0 +1,141 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+import java.util.Optional;
+
+public class TextEmbeddingConfig implements NlpConfig { // NlpConfig for the "text_embedding" task: a vocabulary config plus a tokenization
+
+    public static final String NAME = "text_embedding"; // returned by both getWriteableName() and getName()
+    // Parses with the strict parser (ignoreUnknownFields == false); a supplied vocabulary is rejected, see createParser.
+    public static TextEmbeddingConfig fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+    // Parses with the lenient parser (ignoreUnknownFields == true).
+    public static TextEmbeddingConfig fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    private static final ConstructingObjectParser<TextEmbeddingConfig, Void> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<TextEmbeddingConfig, Void> LENIENT_PARSER = createParser(true);
+    // Builds a parser whose constructor args are [0] vocabulary and [1] tokenization; both are optional.
+    private static ConstructingObjectParser<TextEmbeddingConfig, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<TextEmbeddingConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
+            a -> new TextEmbeddingConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
+        parser.declareObject(
+            ConstructingObjectParser.optionalConstructorArg(),
+            (p, c) -> {
+                if (ignoreUnknownFields == false) { // strict parsing: supplying a vocabulary is a bad request
+                    throw ExceptionsHelper.badRequestException(
+                        "illegal setting [{}] on inference model creation",
+                        VOCABULARY.getPreferredName()
+                    );
+                }
+                return VocabularyConfig.fromXContentLenient(p); // lenient parsing: read the vocabulary config
+            },
+            VOCABULARY // not declared in this class; inherited via NlpConfig
+        );
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+            TOKENIZATION
+        );
+        return parser;
+    }
+
+    private final VocabularyConfig vocabularyConfig;
+    private final Tokenization tokenization;
+    // Either argument may be null; a null is replaced with the corresponding default, so both fields end up non-null.
+    public TextEmbeddingConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
+        this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) // default below is built eagerly (orElse, not orElseGet)
+            .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+    }
+    // Wire constructor: reads the fields in the same order writeTo() writes them.
+    public TextEmbeddingConfig(StreamInput in) throws IOException {
+        vocabularyConfig = new VocabularyConfig(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
+    }
+    // Emits a complete object holding the vocabulary field and the tokenization as a named object.
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        vocabularyConfig.writeTo(out);
+        out.writeNamedWriteable(tokenization);
+    }
+    // Always false: no TargetType is reported as supported by this config.
+    @Override
+    public boolean isTargetTypeSupported(TargetType targetType) {
+        return false;
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        return Version.V_8_0_0;
+    }
+    // Always true; NOTE(review): presumably means the model can only be used through a deployment/allocation - confirm.
+    @Override
+    public boolean isAllocateOnly() {
+        return true;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+    // Equality and hashCode() cover exactly the two fields: vocabularyConfig and tokenization.
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        TextEmbeddingConfig that = (TextEmbeddingConfig) o;
+        return Objects.equals(vocabularyConfig, that.vocabularyConfig)
+            && Objects.equals(tokenization, that.tokenization);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(vocabularyConfig, tokenization);
+    }
+
+    @Override
+    public VocabularyConfig getVocabularyConfig() {
+        return vocabularyConfig;
+    }
+
+    @Override
+    public Tokenization getTokenization() {
+        return tokenization;
+    }
+}

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java

@@ -21,7 +21,7 @@ public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTe
     protected Request createTestInstance() {
         String modelId = randomAlphaOfLength(10);
         return new Request(
-            TrainedModelConfigTests.createTestInstance(modelId)
+            TrainedModelConfigTests.createTestInstance(modelId, false)
                 .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .build(),
             randomBoolean()

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java

@@ -20,7 +20,7 @@ public class PutTrainedModelActionResponseTests extends AbstractWireSerializingT
     @Override
     protected Response createTestInstance() {
         String modelId = randomAlphaOfLength(10);
-        return new Response(TrainedModelConfigTests.createTestInstance(modelId)
+        return new Response(TrainedModelConfigTests.createTestInstance(modelId, randomBoolean())
             .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
             .build());
     }

+ 28 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java

@@ -24,8 +24,14 @@ import org.elasticsearch.license.License;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocationTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@@ -53,9 +59,26 @@ import static org.hamcrest.Matchers.not;
 public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<TrainedModelConfig> {
 
     private boolean lenient;
-
     public static TrainedModelConfig.Builder createTestInstance(String modelId) {
+        return createTestInstance(modelId, false);
+    }
 
+    public static TrainedModelConfig.Builder createTestInstance(String modelId, boolean lenient) {
+
+        InferenceConfig[] inferenceConfigs = lenient ?
+            // Because of vocab config validations on parse, only test on lenient
+            new InferenceConfig[] {
+                ClassificationConfigTests.randomClassificationConfig(),
+                RegressionConfigTests.randomRegressionConfig(),
+                NerConfigTests.createRandom(),
+                PassThroughConfigTests.createRandom(),
+                TextClassificationConfigTests.createRandom(),
+                FillMaskConfigTests.createRandom(),
+                TextEmbeddingConfigTests.createRandom()
+           } : new InferenceConfig[] {
+                ClassificationConfigTests.randomClassificationConfig(),
+                RegressionConfigTests.randomRegressionConfig()
+          };
         List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
         return TrainedModelConfig.builder()
             .setInput(TrainedModelInputTests.createRandomInput())
@@ -70,8 +93,7 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr
             .setEstimatedOperations(randomNonNegativeLong())
             .setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(),
                 License.OperationMode.BASIC.description()))
-            .setInferenceConfig(randomFrom(ClassificationConfigTests.randomClassificationConfig(),
-                RegressionConfigTests.randomRegressionConfig()))
+            .setInferenceConfig(randomFrom(inferenceConfigs))
             .setTags(tags)
             .setLocation(randomBoolean() ? null : IndexLocationTests.randomInstance());
     }
@@ -98,7 +120,7 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr
 
     @Override
     protected TrainedModelConfig createTestInstance() {
-        return createTestInstance(randomAlphaOfLength(10)).build();
+        return createTestInstance(randomAlphaOfLength(10), lenient).build();
     }
 
     @Override
@@ -278,7 +300,7 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr
             () -> {
             try {
                 BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build());
-                return createTestInstance(randomAlphaOfLength(10))
+                return createTestInstance(randomAlphaOfLength(10), lenient)
                     .setDefinitionFromBytes(bytes)
                     .build();
             } catch (IOException ex) {
@@ -310,7 +332,7 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<Tr
                 try {
                     BytesReference bytes =
                         InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build());
-                    return createTestInstance(randomAlphaOfLength(10))
+                    return createTestInstance(randomAlphaOfLength(10), lenient)
                         .setDefinitionFromBytes(bytes)
                         .build();
                 } catch (IOException ex) {

+ 40 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.util.Map;
+
+import static org.hamcrest.Matchers.hasSize;
+
+public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingResults> { // wire round-trip + asMap() tests
+    @Override
+    protected Writeable.Reader<TextEmbeddingResults> instanceReader() {
+        return TextEmbeddingResults::new;
+    }
+    // Builds a result wrapping a vector of 1 to 10 random doubles.
+    @Override
+    protected TextEmbeddingResults createTestInstance() {
+        int columns = randomIntBetween(1, 10);
+        double[] arr = new double[columns];
+        for (int i=0; i<columns; i++) {
+            arr[i] = randomDouble();
+        }
+
+        return new TextEmbeddingResults(arr);
+    }
+    // asMap() must expose exactly one entry, keyed by DEFAULT_RESULTS_FIELD, whose array matches getInference() within 1e-10.
+    public void testAsMap() {
+        TextEmbeddingResults testInstance = createTestInstance();
+        Map<String, Object> asMap = testInstance.asMap();
+        assertThat(asMap.keySet(), hasSize(1));
+        assertArrayEquals(testInstance.getInference(), (double[]) asMap.get(TextEmbeddingResults.DEFAULT_RESULTS_FIELD), 1e-10);
+    }
+}

+ 58 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+
+import java.io.IOException;
+import java.util.function.Predicate;
+
+public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEmbeddingConfig> { // serialization/parsing tests
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+    // True for every non-empty field path; NOTE(review): presumably limits random unknown-field insertion to the root - confirm.
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> field.isEmpty() == false;
+    }
+    // Uses the lenient parser, consistent with supportsUnknownFields() returning true.
+    @Override
+    protected TextEmbeddingConfig doParseInstance(XContentParser parser) throws IOException {
+        return TextEmbeddingConfig.fromXContentLenient(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<TextEmbeddingConfig> instanceReader() {
+        return TextEmbeddingConfig::new;
+    }
+
+    @Override
+    protected TextEmbeddingConfig createTestInstance() {
+        return createRandom();
+    }
+    // Returns the instance unchanged for every version.
+    @Override
+    protected TextEmbeddingConfig mutateInstanceForVersion(TextEmbeddingConfig instance, Version version) {
+        return instance;
+    }
+    // Each constructor argument is independently either null or a random value; tokenization is random BERT or DistilBERT.
+    public static TextEmbeddingConfig createRandom() {
+        return new TextEmbeddingConfig(
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ?
+                null :
+                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
+            );
+    }
+}

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 
 import java.util.Locale;
@@ -41,6 +42,12 @@ public enum TaskType {
         public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new PassThroughProcessor(tokenizer, (PassThroughConfig) config);
         }
+    },
+    TEXT_EMBEDDING {
+        @Override
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
+            return new TextEmbeddingProcessor(tokenizer, (TextEmbeddingConfig) config);
+        }
     };
 
     public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {

+ 49 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
+import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+
+import java.util.List;
+
+/**
+ * An NLP processor that returns a single double[] output from the model. Assumes that only one tensor is returned via inference
+ **/
+public class TextEmbeddingProcessor implements NlpTask.Processor {
+
+    private final NlpTask.RequestBuilder requestBuilder; // obtained once from the tokenizer in the constructor
+    // Only the tokenizer is used here; the config argument is accepted but not read.
+    TextEmbeddingProcessor(NlpTokenizer tokenizer, TextEmbeddingConfig config) {
+        this.requestBuilder = tokenizer.requestBuilder();
+    }
+
+    @Override
+    public void validateInputs(List<String> inputs) {
+        // nothing to validate
+    }
+
+    @Override
+    public NlpTask.RequestBuilder getRequestBuilder() {
+        return requestBuilder;
+    }
+
+    @Override
+    public NlpTask.ResultProcessor getResultProcessor() {
+        return TextEmbeddingProcessor::processResult;
+    }
+    // Wraps element [0][0] of the model output as the embedding; tokenization is unused and no null/bounds check is done.
+    private static InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+        // TODO - process all results in the batch
+        return new TextEmbeddingResults(pyTorchResult.getInferenceResult()[0][0]);
+    }
+}