Quellcode durchsuchen

[ML][Inference][HLRC] Add necessary lang ident classes (#50705)

This adds the necessary named XContent classes to the HLRC for the lang ident model. This is so the HLRC can call `GET _ml/inference/lang_ident_model_1?include_definition=true` without XContent parsing errors.

The constructors are package private since these classes are used exclusively within the pre-packaged model (and require the specific weights, etc. to be of any use).
Benjamin Trent vor 5 Jahren
Ursprung
Commit
5a3939ae44

+ 7 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

@@ -18,12 +18,14 @@
  */
 package org.elasticsearch.client.ml.inference;
 
+import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
+import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
 import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@@ -49,10 +51,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             TargetMeanEncoding::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME),
             FrequencyEncoding::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME),
+            CustomWordEmbedding::fromXContent));
 
         // Model
         namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class,
+            new ParseField(LangIdentNeuralNetwork.NAME),
+            LangIdentNeuralNetwork::fromXContent));
 
         // Aggregating output
         namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,

+ 166 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/CustomWordEmbedding.java

@@ -0,0 +1,166 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.preprocessing;
+
+import org.elasticsearch.common.CheckedFunction;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * This is a pre-processor that embeds text into a numerical vector.
+ *
+ * It calculates a set of features based on script type, ngram hashes, and most common script values.
+ *
+ * The features are then concatenated with specific quantization scales and weights into a vector of length 80.
+ *
+ * This is a fork and a port of: https://github.com/google/cld3/blob/06f695f1c8ee530104416aab5dcf2d6a1414a56a/src/embedding_network.cc
+ */
public class CustomWordEmbedding implements PreProcessor {

    public static final String NAME = "custom_word_embedding";
    static final ParseField FIELD = new ParseField("field");
    static final ParseField DEST_FIELD = new ParseField("dest_field");
    static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
    static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");

    // Lenient parser (second argument true): unknown fields from newer servers are ignored.
    // Constructor args arrive in declaration order below:
    // a[0] = embedding_quant_scales, a[1] = embedding_weights, a[2] = field, a[3] = dest_field.
    public static final ConstructingObjectParser<CustomWordEmbedding, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new CustomWordEmbedding((short[][])a[0], (byte[][])a[1], (String)a[2], (String)a[3]));
    static {
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            (p, c) -> {
                // Parse a 2D array of shorts, then unbox List<List<Short>> into a primitive short[][].
                List<List<Short>> listOfListOfShorts = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(),
                    XContentParser::shortValue,
                    p);
                short[][] primitiveShorts = new short[listOfListOfShorts.size()][];
                int i = 0;
                for (List<Short> shorts : listOfListOfShorts) {
                    short[] innerShorts = new short[shorts.size()];
                    for (int j = 0; j < shorts.size(); j++) {
                        innerShorts[j] = shorts.get(j);
                    }
                    primitiveShorts[i++] = innerShorts;
                }
                return primitiveShorts;
            },
            EMBEDDING_QUANT_SCALES,
            ObjectParser.ValueType.VALUE_ARRAY);
        PARSER.declareField(ConstructingObjectParser.constructorArg(),
            (p, c) -> {
                // Each array element is read via binaryValue() — one opaque byte row per entry
                // (presumably base64-encoded in JSON; confirm against the server serialization).
                List<byte[]> values = new ArrayList<>();
                while(p.nextToken() != XContentParser.Token.END_ARRAY) {
                    values.add(p.binaryValue());
                }
                byte[][] primitiveBytes = new byte[values.size()][];
                int i = 0;
                for (byte[] bytes : values) {
                    primitiveBytes[i++] = bytes;
                }
                return primitiveBytes;
            },
            EMBEDDING_WEIGHTS,
            ObjectParser.ValueType.VALUE_ARRAY);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), DEST_FIELD);
    }

    /**
     * Reads a two-dimensional array field: an outer START_ARRAY containing inner arrays of
     * scalar values, each converted with {@code fromParser}.
     *
     * @param fieldName  field name used only for error messages
     * @param fromParser converts the current scalar token to an element value
     * @param p          parser positioned on the outer START_ARRAY token
     * @return the parsed rows, in document order
     * @throws IOException              on underlying parse failures
     * @throws IllegalArgumentException if the outer or an inner token is not an array start
     * @throws IllegalStateException    if a non-value token is found inside an inner array
     */
    private static <T> List<List<T>> parseArrays(String fieldName,
                                                 CheckedFunction<XContentParser, T, IOException> fromParser,
                                                 XContentParser p) throws IOException {
        if (p.currentToken() != XContentParser.Token.START_ARRAY) {
            throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
        }
        List<List<T>> values = new ArrayList<>();
        while(p.nextToken() != XContentParser.Token.END_ARRAY) {
            if (p.currentToken() != XContentParser.Token.START_ARRAY) {
                throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
            }
            List<T> innerList = new ArrayList<>();
            while(p.nextToken() != XContentParser.Token.END_ARRAY) {
                if(p.currentToken().isValue() == false) {
                    throw new IllegalStateException("expected non-null value but got [" + p.currentToken() + "] " +
                        "for [" + fieldName + "]");
                }
                innerList.add(fromParser.apply(p));
            }
            values.add(innerList);
        }
        return values;
    }

    public static CustomWordEmbedding fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    // Quantization scales and weight rows; arrays are stored as-is (not defensively copied),
    // consistent with this class only being built by the parser from pre-packaged model data.
    private final short[][] embeddingsQuantScales;
    private final byte[][] embeddingsWeights;
    private final String fieldName;
    private final String destField;

    // Package private: instances come exclusively from the pre-packaged lang-ident model.
    CustomWordEmbedding(short[][] embeddingsQuantScales, byte[][] embeddingsWeights, String fieldName, String destField) {
        this.embeddingsQuantScales = embeddingsQuantScales;
        this.embeddingsWeights = embeddingsWeights;
        this.fieldName = fieldName;
        this.destField = destField;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), fieldName);
        builder.field(DEST_FIELD.getPreferredName(), destField);
        builder.field(EMBEDDING_QUANT_SCALES.getPreferredName(), embeddingsQuantScales);
        builder.field(EMBEDDING_WEIGHTS.getPreferredName(), embeddingsWeights);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        CustomWordEmbedding that = (CustomWordEmbedding) o;
        // deepEquals: elements are themselves arrays, so Arrays.equals would compare references.
        return Objects.equals(fieldName, that.fieldName)
            && Objects.equals(destField, that.destField)
            && Arrays.deepEquals(embeddingsWeights, that.embeddingsWeights)
            && Arrays.deepEquals(embeddingsQuantScales, that.embeddingsQuantScales);
    }

    @Override
    public int hashCode() {
        return Objects.hash(fieldName, destField, Arrays.deepHashCode(embeddingsQuantScales), Arrays.deepHashCode(embeddingsWeights));
    }

}

+ 108 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -0,0 +1,108 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.langident;
+
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * Shallow, fully connected, feed forward NN modeled after and ported from https://github.com/google/cld3
+ */
public class LangIdentNeuralNetwork implements TrainedModel {

    public static final String NAME = "lang_ident_neural_network";
    public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
    public static final ParseField HIDDEN_LAYER = new ParseField("hidden_layer");
    public static final ParseField SOFTMAX_LAYER = new ParseField("softmax_layer");
    // Lenient parser (second argument true): unknown fields from newer servers are ignored.
    // Constructor args arrive in declaration order below:
    // a[0] = embedded_vector_feature_name, a[1] = hidden_layer, a[2] = softmax_layer.
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new LangIdentNeuralNetwork((String) a[0],
            (LangNetLayer) a[1],
            (LangNetLayer) a[2]));

    static {
        PARSER.declareString(constructorArg(), EMBEDDED_VECTOR_FEATURE_NAME);
        PARSER.declareObject(constructorArg(), LangNetLayer.PARSER::apply, HIDDEN_LAYER);
        PARSER.declareObject(constructorArg(), LangNetLayer.PARSER::apply, SOFTMAX_LAYER);
    }

    public static LangIdentNeuralNetwork fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final LangNetLayer hiddenLayer;
    private final LangNetLayer softmaxLayer;
    // Name of the single input feature holding the embedded text vector.
    private final String embeddedVectorFeatureName;

    // Package private: only constructed via the parser; this model ships pre-packaged
    // with specific weights and is not meant to be built by client code.
    LangIdentNeuralNetwork(String embeddedVectorFeatureName,
                                  LangNetLayer hiddenLayer,
                                  LangNetLayer softmaxLayer) {
        this.embeddedVectorFeatureName = embeddedVectorFeatureName;
        this.hiddenLayer = hiddenLayer;
        this.softmaxLayer = softmaxLayer;
    }

    @Override
    public List<String> getFeatureNames() {
        // The network consumes exactly one feature: the embedded vector.
        return Collections.singletonList(embeddedVectorFeatureName);
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(EMBEDDED_VECTOR_FEATURE_NAME.getPreferredName(), embeddedVectorFeatureName);
        builder.field(HIDDEN_LAYER.getPreferredName(), hiddenLayer);
        builder.field(SOFTMAX_LAYER.getPreferredName(), softmaxLayer);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        LangIdentNeuralNetwork that = (LangIdentNeuralNetwork) o;
        return Objects.equals(embeddedVectorFeatureName, that.embeddedVectorFeatureName)
            && Objects.equals(hiddenLayer, that.hiddenLayer)
            && Objects.equals(softmaxLayer, that.softmaxLayer);
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddedVectorFeatureName, hiddenLayer, softmaxLayer);
    }

}

+ 123 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangNetLayer.java

@@ -0,0 +1,123 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.langident;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * Represents a single layer in the compressed Lang Net
+ */
+public class LangNetLayer implements ToXContentObject {
+
+    public static final ParseField NAME = new ParseField("lang_net_layer");
+
+    private static final ParseField NUM_ROWS = new ParseField("num_rows");
+    private static final ParseField NUM_COLS = new ParseField("num_cols");
+    private static final ParseField WEIGHTS = new ParseField("weights");
+    private static final ParseField BIAS = new ParseField("bias");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<LangNetLayer, Void> PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(),
+        true,
+        a -> new LangNetLayer(
+            (List<Double>) a[0],
+            (int) a[1],
+            (int) a[2],
+            (List<Double>) a[3]));
+
+    static {
+        PARSER.declareDoubleArray(constructorArg(), WEIGHTS);
+        PARSER.declareInt(constructorArg(), NUM_COLS);
+        PARSER.declareInt(constructorArg(), NUM_ROWS);
+        PARSER.declareDoubleArray(constructorArg(), BIAS);
+    }
+
+    private final double[] weights;
+    private final int weightRows;
+    private final int weightCols;
+    private final double[] bias;
+
+    private LangNetLayer(List<Double> weights, int numCols, int numRows, List<Double> bias) {
+        this(weights.stream().mapToDouble(Double::doubleValue).toArray(),
+            numCols,
+            numRows,
+            bias.stream().mapToDouble(Double::doubleValue).toArray());
+    }
+
+    LangNetLayer(double[] weights, int numCols, int numRows, double[] bias) {
+        this.weights = weights;
+        this.weightCols = numCols;
+        this.weightRows = numRows;
+        this.bias = bias;
+    }
+
+    double[] getWeights() {
+        return weights;
+    }
+
+    int getWeightRows() {
+        return weightRows;
+    }
+
+    int getWeightCols() {
+        return weightCols;
+    }
+
+    double[] getBias() {
+        return bias;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(NUM_COLS.getPreferredName(), weightCols);
+        builder.field(NUM_ROWS.getPreferredName(), weightRows);
+        builder.field(WEIGHTS.getPreferredName(), weights);
+        builder.field(BIAS.getPreferredName(), bias);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        LangNetLayer that = (LangNetLayer) o;
+        return Arrays.equals(weights, that.weights)
+            && Arrays.equals(bias, that.bias)
+            && Objects.equals(weightCols, that.weightCols)
+            && Objects.equals(weightRows, that.weightRows);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(bias), weightCols, weightRows);
+    }
+}

+ 17 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -151,6 +151,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
 import org.elasticsearch.client.ml.inference.TrainedModelStats;
 import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
 import org.elasticsearch.client.ml.job.config.AnalysisConfig;
 import org.elasticsearch.client.ml.job.config.DataDescription;
 import org.elasticsearch.client.ml.job.config.Detector;
@@ -201,6 +202,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.hasItems;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.not;
@@ -2278,6 +2280,21 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(0));
     }
 
    public void testGetPrepackagedModels() throws Exception {
        // Fetches the pre-packaged language-identification model with its full definition,
        // verifying the HLRC can parse the lang-ident named XContent objects without errors.
        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();

        GetTrainedModelsResponse getTrainedModelsResponse = execute(
            new GetTrainedModelsRequest("lang_ident_model_1").setIncludeDefinition(true),
            machineLearningClient::getTrainedModels,
            machineLearningClient::getTrainedModelsAsync);

        assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
        assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
        assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo("lang_ident_model_1"));
        // The parsed definition must round-trip into the client-side LangIdentNeuralNetwork class.
        assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition().getTrainedModel(),
            instanceOf(LangIdentNeuralNetwork.class));
    }
+
     public void testPutFilter() throws Exception {
         String filterId = "filter-job-test";
         MlFilter mlFilter = MlFilter.builder(filterId)

+ 7 - 5
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -68,6 +68,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
+import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
@@ -75,6 +76,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
+import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
 import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.client.transform.transforms.SyncConfig;
 import org.elasticsearch.client.transform.transforms.TimeSyncConfig;
@@ -688,7 +690,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(55, namedXContents.size());
+        assertEquals(57, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -760,10 +762,10 @@ public class RestHighLevelClientTests extends ESTestCase {
                 registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
                 registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
                 registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
-        assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
-        assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
-        assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
-        assertThat(names, hasItems(Tree.NAME, Ensemble.NAME));
+        assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
+        assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
+        assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
+        assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
         assertEquals(Integer.valueOf(3),
             categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
         assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));

+ 64 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/CustomWordEmbeddingTests.java

@@ -0,0 +1,64 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.preprocessing;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+
+public class CustomWordEmbeddingTests extends AbstractXContentTestCase<CustomWordEmbedding> {
+
+    @Override
+    protected CustomWordEmbedding doParseInstance(XContentParser parser) throws IOException {
+        return CustomWordEmbedding.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected CustomWordEmbedding createTestInstance() {
+        return createRandom();
+    }
+
+    public static CustomWordEmbedding createRandom() {
+        int quantileSize = randomIntBetween(1, 10);
+        int internalQuantSize = randomIntBetween(1, 10);
+        short[][] quantiles = new short[quantileSize][internalQuantSize];
+        for (int i = 0; i < quantileSize; i++) {
+            for (int j = 0; j < internalQuantSize; j++) {
+                quantiles[i][j] = randomShort();
+            }
+        }
+        int weightsSize = randomIntBetween(1, 10);
+        int internalWeightsSize = randomIntBetween(1, 10);
+        byte[][] weights = new byte[weightsSize][internalWeightsSize];
+        for (int i = 0; i < weightsSize; i++) {
+            for (int j = 0; j < internalWeightsSize; j++) {
+                weights[i][j] = randomByte();
+            }
+        }
+        return new CustomWordEmbedding(quantiles, weights, randomAlphaOfLength(10), randomAlphaOfLength(10));
+    }
+
+}

+ 57 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangIdentNeuralNetworkTests.java

@@ -0,0 +1,57 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
package org.elasticsearch.client.ml.inference.trainedmodel.langident;

import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;


public class LangIdentNeuralNetworkTests extends AbstractXContentTestCase<LangIdentNeuralNetwork> {

    @Override
    protected LangIdentNeuralNetwork doParseInstance(XContentParser parser) throws IOException {
        return LangIdentNeuralNetwork.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected LangIdentNeuralNetwork createTestInstance() {
        return createRandom();
    }

    /** Builds a network with a random feature name and two random layers. */
    public static LangIdentNeuralNetwork createRandom() {
        return new LangIdentNeuralNetwork(randomAlphaOfLength(10),
            LangNetLayerTests.createRandom(),
            LangNetLayerTests.createRandom());
    }

    /**
     * Registers the HLRC ML inference named XContent entries (the client-side
     * {@link MlInferenceNamedXContentProvider}, not the x-pack core one — this module
     * must round-trip the client classes and should not depend on server internals).
     */
    @Override
    protected NamedXContentRegistry xContentRegistry() {
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    }

}

+ 55 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangNetLayerTests.java

@@ -0,0 +1,55 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.langident;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.stream.Stream;
+
+
+public class LangNetLayerTests extends AbstractXContentTestCase<LangNetLayer> {
+
+    @Override
+    protected LangNetLayer doParseInstance(XContentParser parser) throws IOException {
+        return LangNetLayer.PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected LangNetLayer createTestInstance() {
+        return createRandom();
+    }
+
+    public static LangNetLayer createRandom() {
+        int numWeights = randomIntBetween(1, 1000);
+        return new LangNetLayer(
+            Stream.generate(ESTestCase::randomDouble).limit(numWeights).mapToDouble(Double::doubleValue).toArray(),
+            numWeights,
+            1,
+            Stream.generate(ESTestCase::randomDouble).limit(numWeights).mapToDouble(Double::doubleValue).toArray());
+    }
+
+}

+ 10 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -198,6 +198,16 @@ public class TrainedModelIT extends ESRestTestCase {
         assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
     }
 
    public void testGetPrePackagedModels() throws IOException {
        // Hits the REST API directly for the pre-packaged lang-ident model, asking for the
        // full model definition in human-readable form.
        Response getModel = client().performRequest(new Request("GET",
            MachineLearning.BASE_PATH + "inference/lang_ident_model_1?human=true&include_model_definition=true"));

        assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
        String response = EntityUtils.toString(getModel.getEntity());
        // Loose string checks: the model id is present and the definition was actually included.
        assertThat(response, containsString("lang_ident_model_1"));
        assertThat(response, containsString("\"definition\""));
    }
+
     private static String buildRegressionModel(String modelId) throws IOException {
         try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
             TrainedModelConfig.builder()