Browse Source

[ML] fix LangIdent model when multiple unicode scripts are present (#81876)

LangIdent was recently updated to handle multiple unicode scripts (#80675). But this introduced some bugs fixed with this commit.

1. Sections with the same scripted were weighted by Java string length (utf-16) encoding. This is not accurate as certain languages (like Chinese and Korean) convey much more information with fewer utf-16 characters. FIX weight by utf-8 length.
2. The weighing of different language scores was done via the raw score from the neural network. This caused languages with a high score (but low compared to most likely language) from the network to be inaccurately weighted. FIX We are now instead weighing the probabilities of the sections of the text.
3. To split the input across the multiple scripts, we split on the "paired down" CDL3 script types. Java has superior support for unicode script blocks. FIX split by Java unicode script blocks not by the paired down CDL3 scripts
Benjamin Trent 3 years ago
parent
commit
4b0864d9b3

+ 16 - 12
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java

@@ -20,11 +20,11 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
 import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
 
 
 import java.io.IOException;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
@@ -45,16 +45,16 @@ import java.util.stream.Collectors;
 public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 
 
     public static class StringLengthAndEmbedding {
     public static class StringLengthAndEmbedding {
-        final int stringLen;
+        final int utf8StringLen;
         final double[] embedding;
         final double[] embedding;
 
 
-        public StringLengthAndEmbedding(int stringLen, double[] embedding) {
-            this.stringLen = stringLen;
+        public StringLengthAndEmbedding(int utf8StringLen, double[] embedding) {
+            this.utf8StringLen = utf8StringLen;
             this.embedding = embedding;
             this.embedding = embedding;
         }
         }
 
 
-        public int getStringLen() {
-            return stringLen;
+        public int getUtf8StringLen() {
+            return utf8StringLen;
         }
         }
 
 
         public double[] getEmbedding() {
         public double[] getEmbedding() {
@@ -258,7 +258,7 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
             if (i >= codePoints.length) {
             if (i >= codePoints.length) {
                 break;
                 break;
             }
             }
-            ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
+            Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]);
             int j = i + 1;
             int j = i + 1;
             for (; j < codePoints.length; j++) {
             for (; j < codePoints.length; j++) {
                 while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
                 while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
@@ -267,11 +267,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
                 if (j >= codePoints.length) {
                 if (j >= codePoints.length) {
                     break;
                     break;
                 }
                 }
-                ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
-                if (j1 != currentCode && j1 != ScriptCode.Inherited) {
+                Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]);
+                if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) {
                     if (j < codePoints.length - 1) {
                     if (j < codePoints.length - 1) {
-                        ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
-                        if (j2 != ScriptCode.Common && j2 != currentCode) {
+                        Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]);
+                        if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) {
                             break;
                             break;
                         }
                         }
                     }
                     }
@@ -290,7 +290,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
             embeddings.add(
             embeddings.add(
                 new StringLengthAndEmbedding(
                 new StringLengthAndEmbedding(
                     // Don't count white spaces as bytes for the prediction
                     // Don't count white spaces as bytes for the prediction
-                    str.trim().length(),
+                    // We ues utf-8 length here as
+                    // * The original C++ implementation does this when measuring string length
+                    // * Languages with complex characters (like zh) convey more information per a single utf-16 character and
+                    // using utf-8 length captures that.
+                    str.trim().getBytes(StandardCharsets.UTF_8).length,
                     concatEmbeddings(
                     concatEmbeddings(
                         FEATURE_EXTRACTORS.stream()
                         FEATURE_EXTRACTORS.stream()
                             .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
                             .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))

+ 4 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -229,23 +229,22 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
             );
             );
         }
         }
         List<?> embeddedVector = (List<?>) vector;
         List<?> embeddedVector = (List<?>) vector;
-        double[] scores = new double[LANGUAGE_NAMES.size()];
+        double[] probabilities = new double[LANGUAGE_NAMES.size()];
         int totalLen = 0;
         int totalLen = 0;
         for (Object vec : embeddedVector) {
         for (Object vec : embeddedVector) {
             if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
             if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
                 continue;
                 continue;
             }
             }
             CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
             CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
-            int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
+            int square = stringLengthAndEmbedding.getUtf8StringLen() * stringLengthAndEmbedding.getUtf8StringLen();
             totalLen += square;
             totalLen += square;
             double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
             double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
             double[] score = softmaxLayer.productPlusBias(true, h0);
             double[] score = softmaxLayer.productPlusBias(true, h0);
-            sumDoubleArrays(scores, score, Math.max(square, 1));
+            sumDoubleArrays(probabilities, softMax(score), Math.max(square, 1));
         }
         }
         if (totalLen != 0) {
         if (totalLen != 0) {
-            divMut(scores, totalLen);
+            divMut(probabilities, totalLen);
         }
         }
-        double[] probabilities = softMax(scores);
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
         Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
         Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
             probabilities,
             probabilities,

+ 37 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java

@@ -29,7 +29,6 @@ import java.util.Map;
 
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.closeTo;
-import static org.hamcrest.Matchers.greaterThan;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
 
 
 public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
 public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
@@ -103,6 +102,12 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
         singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
         singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
 
 
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("매트 스미스는 BBC äôs Doctor Who를 그만둔다."),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
         singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
         singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
             inferenceObj(
             inferenceObj(
                 "@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
                 "@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
@@ -112,6 +117,34 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
         );
         );
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
 
 
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(
+                "김걸도혁(金乞都革) 김공소(金公疎) 김교합(金咬哈) 김다롱합(金多弄哈) 김마상개(金麻尙介) 김우리개(金于里介) 김상미(金尙美) 김아도을치(金阿都乙赤) "
+                    + "김아라(金阿喇) 김아랑합(金阿郞哈) 김아을가(金阿乙加) 김역류(金易留) 김우두(金于豆) 김우허내(金右虛乃) 김유리가(金留里加) 김윤적(金允績) "
+                    + "김이랑합(金伊郞哈) 김인을개(金引乙介) 김입성(金入成) 김주창개(金主昌介) 김지하리(金之下里) 김차독(金箚禿) 김지칭가(金只稱哥) 김자라노(金者羅老)."
+            ),
+            classificationConfig
+        );
+        // Half the string is ko the other half is zh
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+        assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.5, 0.1));
+        assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("zh"));
+        assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.5, 0.1));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(
+                "[ Republic of Korea ],\n"
+                    + "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+                    + "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+                    + "        !대한민국(, 영어: Republic of Korea, KOR)은 동아시아의 한반도 남부에 자리한 민주공화국이다. 서쪽으로 중화인민공화국과 황해를 사이에 두고"
+            ),
+            classificationConfig
+        );
+        // Majority of the text is obviously Thai, but a close second is Korean
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("th"));
+        assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.6, 0.1));
+        assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("ko"));
+        assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.4, 0.1));
     }
     }
 
 
     public void testLangInference() throws Exception {
     public void testLangInference() throws Exception {
@@ -131,7 +164,9 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
             );
             );
 
 
             assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
             assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
-            Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
+            // The stored language example is a mixture of `ja` and other languages, it should not be predicted with 1.0 accuracy as the
+            // cld3 probability indicates.
+            Matcher<Double> matcher = entry.getLanguage().equals("ja") ? closeTo(cld3Probability, 0.11) : closeTo(cld3Probability, .01);
             assertThat(
             assertThat(
                 "mismatch probability for language " + cld3Actual,
                 "mismatch probability for language " + cld3Actual,
                 singleValueInferenceResults.getTopClasses().get(0).getProbability(),
                 singleValueInferenceResults.getTopClasses().get(0).getProbability(),