|
@@ -20,11 +20,11 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
|
|
-import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
|
|
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
|
|
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
|
|
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
|
|
|
|
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
import java.util.ArrayList;
|
|
import java.util.ArrayList;
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
@@ -45,16 +45,16 @@ import java.util.stream.Collectors;
|
|
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
|
|
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
|
|
|
|
|
|
public static class StringLengthAndEmbedding {
|
|
public static class StringLengthAndEmbedding {
|
|
- final int stringLen;
|
|
|
|
|
|
+ final int utf8StringLen;
|
|
final double[] embedding;
|
|
final double[] embedding;
|
|
|
|
|
|
- public StringLengthAndEmbedding(int stringLen, double[] embedding) {
|
|
|
|
- this.stringLen = stringLen;
|
|
|
|
|
|
+ public StringLengthAndEmbedding(int utf8StringLen, double[] embedding) {
|
|
|
|
+ this.utf8StringLen = utf8StringLen;
|
|
this.embedding = embedding;
|
|
this.embedding = embedding;
|
|
}
|
|
}
|
|
|
|
|
|
- public int getStringLen() {
|
|
|
|
- return stringLen;
|
|
|
|
|
|
+ public int getUtf8StringLen() {
|
|
|
|
+ return utf8StringLen;
|
|
}
|
|
}
|
|
|
|
|
|
public double[] getEmbedding() {
|
|
public double[] getEmbedding() {
|
|
@@ -258,7 +258,7 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|
if (i >= codePoints.length) {
|
|
if (i >= codePoints.length) {
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
- ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
|
|
|
|
|
|
+ Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]);
|
|
int j = i + 1;
|
|
int j = i + 1;
|
|
for (; j < codePoints.length; j++) {
|
|
for (; j < codePoints.length; j++) {
|
|
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
|
|
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
|
|
@@ -267,11 +267,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|
if (j >= codePoints.length) {
|
|
if (j >= codePoints.length) {
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
- ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
|
|
|
|
- if (j1 != currentCode && j1 != ScriptCode.Inherited) {
|
|
|
|
|
|
+ Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]);
|
|
|
|
+ if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) {
|
|
if (j < codePoints.length - 1) {
|
|
if (j < codePoints.length - 1) {
|
|
- ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
|
|
|
|
- if (j2 != ScriptCode.Common && j2 != currentCode) {
|
|
|
|
|
|
+ Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]);
|
|
|
|
+ if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) {
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -290,7 +290,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|
embeddings.add(
|
|
embeddings.add(
|
|
new StringLengthAndEmbedding(
|
|
new StringLengthAndEmbedding(
|
|
// Don't count white spaces as bytes for the prediction
|
|
// Don't count white spaces as bytes for the prediction
|
|
- str.trim().length(),
|
|
|
|
|
|
+ // We use utf-8 length here as
|
|
|
|
+ // * The original C++ implementation does this when measuring string length
|
|
|
|
+ // * Languages with complex characters (like zh) convey more information per utf-16 character and
|
|
|
|
+ // using utf-8 length captures that.
|
|
|
|
+ str.trim().getBytes(StandardCharsets.UTF_8).length,
|
|
concatEmbeddings(
|
|
concatEmbeddings(
|
|
FEATURE_EXTRACTORS.stream()
|
|
FEATURE_EXTRACTORS.stream()
|
|
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
|
|
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
|