Browse Source

[ML] Fix language identification bug when multi-languages are present (#80675)

Language identification works fairly well when only one language and
script type is present, but when multiple are present it can return
unexpected results. Example: "행 레이블 this is english text obviously
and 생성 tom said to test it", which appears to a human to be English text
(Latin script) with some Korean in Hangul script, is erroneously
categorized as Japanese. It should be categorized as English, as that is
the dominant language and script type. This commit fixes the bug by
doing the following:
- Input text is partitioned into common, continuous unicode-script sections
- Each section's individual language scores are gathered
- Each score is then weighted according to the number of characters in its section
- The resulting weighted scores are transformed into probabilities
- The final probabilities are the ones returned to the user.
Benjamin Trent 3 years ago
parent
commit
49517dadd5

+ 88 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java

@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
 
@@ -43,6 +44,24 @@ import java.util.stream.Collectors;
  */
 public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 
+    /**
+     * Pairs the length of a unicode-script text section with the embedding
+     * computed for it, so downstream inference can weight each section's
+     * language scores by how much text the section contains.
+     */
+    public static class StringLengthAndEmbedding {
+        // Non-whitespace character count of the section (the weighting basis downstream)
+        final int stringLen;
+        // Concatenated feature-extractor embedding for the section
+        final double[] embedding;
+
+        public StringLengthAndEmbedding(int stringLen, double[] embedding) {
+            this.stringLen = stringLen;
+            this.embedding = embedding;
+        }
+
+        public int getStringLen() {
+            return stringLen;
+        }
+
+        public double[] getEmbedding() {
+            return embedding;
+        }
+    }
+
     private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
     public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
     public static final ParseField NAME = new ParseField("custom_word_embedding");
@@ -213,11 +232,75 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
         String text = (String) field;
         text = FeatureUtils.cleanAndLowerText(text);
         text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
-        String finalText = text;
-        List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
-            .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
-            .collect(Collectors.toList());
-        fields.put(destField, concatEmbeddings(processedFeatures));
+        final String finalText = text;
+        if (finalText.isEmpty() || finalText.isBlank()) {
+            fields.put(
+                destField,
+                Collections.singletonList(
+                    new StringLengthAndEmbedding(
+                        0,
+                        concatEmbeddings(
+                            FEATURE_EXTRACTORS.stream()
+                                .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
+                                .collect(Collectors.toList())
+                        )
+                    )
+                )
+            );
+            return;
+        }
+        List<StringLengthAndEmbedding> embeddings = new ArrayList<>();
+        int[] codePoints = finalText.codePoints().toArray();
+        for (int i = 0; i < codePoints.length - 1;) {
+            while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
+                i++;
+            }
+            if (i >= codePoints.length) {
+                break;
+            }
+            ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
+            int j = i + 1;
+            for (; j < codePoints.length; j++) {
+                while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
+                    j++;
+                }
+                if (j >= codePoints.length) {
+                    break;
+                }
+                ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
+                if (j1 != currentCode && j1 != ScriptCode.Inherited) {
+                    if (j < codePoints.length - 1) {
+                        ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
+                        if (j2 != ScriptCode.Common && j2 != currentCode) {
+                            break;
+                        }
+                    }
+                }
+            }
+            // Knowing the start and the end of the section is important for feature building, so make sure its wrapped in spaces
+            String str = new String(codePoints, i, j - i);
+            StringBuilder builder = new StringBuilder();
+            if (str.startsWith(" ") == false) {
+                builder.append(" ");
+            }
+            builder.append(str);
+            if (str.endsWith(" ") == false) {
+                builder.append(" ");
+            }
+            embeddings.add(
+                new StringLengthAndEmbedding(
+                    // Don't count white spaces as bytes for the prediction
+                    str.trim().length(),
+                    concatEmbeddings(
+                        FEATURE_EXTRACTORS.stream()
+                            .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
+                            .collect(Collectors.toList())
+                    )
+                )
+            );
+            i = j;
+        }
+        fields.put(destField, embeddings);
     }
 
     @Override

+ 17 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -188,13 +188,29 @@ public final class InferenceHelpers {
     }
 
     public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
+        // Unweighted variant: delegates with weight 1 so there is a single implementation.
+        return sumDoubleArrays(sumTo, inc, 1);
+    }
+
+    /**
+     * In-place weighted accumulation: adds {@code inc[i] * weight} into {@code sumTo[i]}.
+     * Both arrays must be non-null and of equal length (checked only by assert).
+     * Returns {@code sumTo} for chaining.
+     */
+    public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
         assert sumTo != null && inc != null && sumTo.length == inc.length;
         for (int i = 0; i < inc.length; i++) {
-            sumTo[i] += inc[i];
+            sumTo[i] += (inc[i] * weight);
         }
         return sumTo;
     }
 
+    /**
+     * Divides every element of {@code xs} by {@code v} in place.
+     * A zero-length array is a no-op; {@code v == 0} is rejected rather than
+     * silently producing Infinity/NaN values.
+     */
+    public static void divMut(double[] xs, int v) {
+        if (xs.length == 0) {
+            return;
+        }
+        if (v == 0) {
+            throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
+        }
+        for (int i = 0; i < xs.length; i++) {
+            xs[i] /= v;
+        }
+    }
+
     public static class TopClassificationValue {
         private final int value;
         private final double probability;

+ 22 - 14
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -36,6 +37,8 @@ import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
 
 public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
@@ -217,27 +220,32 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
             throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
         }
         Object vector = fields.get(embeddedVectorFeatureName);
-        if (vector instanceof double[] == false) {
+        if (vector instanceof List<?> == false) {
             throw ExceptionsHelper.badRequestException(
-                "[{}] model could not find non-null numerical array named [{}]",
+                "[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. "
+                    + "Please verify that the input is a string.",
                 NAME.getPreferredName(),
                 embeddedVectorFeatureName
             );
         }
-        double[] embeddedVector = (double[]) vector;
-        if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) {
-            throw ExceptionsHelper.badRequestException(
-                "[{}] model is expecting embedding vector of length [{}] but got [{}]",
-                NAME.getPreferredName(),
-                EMBEDDING_VECTOR_LENGTH,
-                embeddedVector.length
-            );
+        List<?> embeddedVector = (List<?>) vector;
+        double[] scores = new double[LANGUAGE_NAMES.size()];
+        int totalLen = 0;
+        for (Object vec : embeddedVector) {
+            if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
+                continue;
+            }
+            CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
+            int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
+            totalLen += square;
+            double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
+            double[] score = softmaxLayer.productPlusBias(true, h0);
+            sumDoubleArrays(scores, score, Math.max(square, 1));
+        }
+        if (totalLen != 0) {
+            divMut(scores, totalLen);
         }
-        double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
-        double[] scores = softmaxLayer.productPlusBias(true, h0);
-
         double[] probabilities = softMax(scores);
-
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
         Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
             probabilities,

+ 108 - 16
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java

@@ -20,30 +20,103 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.Inferenc
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.hamcrest.Matcher;
 
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.mockito.Mockito.mock;
 
 public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
 
-    public void testLangInference() throws Exception {
-        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
-        PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
-        // Should be OK as we don't make any client calls
-        trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
-        TrainedModelConfig config = future.actionGet();
+    /**
+     * Degenerate inputs (empty, blank, punctuation-only, digits-only, single
+     * letters, mis-encoded characters) must not throw and must still produce a
+     * prediction (the asserted values pin the model's fallback behavior).
+     */
+    public void testAdverseScenarios() throws Exception {
+        InferenceDefinition inferenceDefinition = grabModel();
+        ClassificationConfig classificationConfig = new ClassificationConfig(5);
+
+        ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(""),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("     "),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("!@#$%^&*()"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("1234567890"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("-----=-=--=-=+__+_+__==-=-!@#$%^&*()"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("A"), classificationConfig);
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("lb"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("„ÍÎÏ◊˝Ïδ„€‹›fifl‡°·‚∏ØÒÚÒ˘ÚÆ’ÆÚ”∏Ø\uF8FFÔÓ˝Ïδ„‹›fiˇflÁ¨ˆØ"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));
+
+        // Should not throw
+        inferenceDefinition.infer(inferenceObj("행 A A"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행 A성 xx"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행 A성 성x"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행A A성 x성"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행A 성 x"), classificationConfig);
+    }
+
+    /**
+     * Mixed-script inputs: per-section weighting should let the dominant
+     * script/language win (e.g. mostly-English text with some Hangul -> "en",
+     * mostly-Korean text with Latin fragments -> "ko").
+     */
+    public void testMixedLangInference() throws Exception {
+        InferenceDefinition inferenceDefinition = grabModel();
+        ClassificationConfig classificationConfig = new ClassificationConfig(5);
+
+        ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("행 레이블 this is english text obviously and 생성 tom said to test it "),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("en"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("행 레이블 Dashboard ISSUE Qual. Plan Qual. Report Qual. 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(
+                "@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
+                    + " 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"
+            ),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
+    }
+
+    public void testLangInference() throws Exception {
+
+        InferenceDefinition inferenceDefinition = grabModel();
         List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
         ClassificationConfig classificationConfig = new ClassificationConfig(1);
 
@@ -52,23 +125,42 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
             String cld3Actual = entry.getPredictedLanguage();
             double cld3Probability = entry.getProbability();
 
-            Map<String, Object> inferenceFields = new HashMap<>();
-            inferenceFields.put("text", text);
             ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
-                inferenceFields,
+                inferenceObj(text),
                 classificationConfig
             );
 
             assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
-            double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001;
+            Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
             assertThat(
                 "mismatch probability for language " + cld3Actual,
                 singleValueInferenceResults.getTopClasses().get(0).getProbability(),
-                closeTo(cld3Probability, eps)
+                matcher
             );
         }
     }
 
+    /**
+     * Loads the bundled lang_ident_model_1 config, parses its definition, and
+     * wraps it as an InferenceDefinition shared by the tests above.
+     */
+    InferenceDefinition grabModel() throws IOException {
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+        PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
+        // Should be OK as we don't make any client calls
+        trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
+        TrainedModelConfig config = future.actionGet();
+
+        config.ensureParsedDefinition(xContentRegistry());
+        TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
+        return new InferenceDefinition(
+            (LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
+            trainedModelDefinition.getPreProcessors()
+        );
+    }
+
+    /** Builds the single-field ("text") input map the model expects. */
+    private static Map<String, Object> inferenceObj(String text) {
+        Map<String, Object> inferenceFields = new HashMap<>();
+        inferenceFields.put("text", text);
+        return inferenceFields;
+    }
+
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());