Pārlūkot izejas kodu

[ML] add sentence piece pre-compiled normalizer (#87575)

This is one of the many prerequisites for supporting sentence-piece tokenization within NLP.

Sentence piece is a fairly complicated and involved tokenization scheme.

This commit contains the normalization logic that transforms the provided string from its current utf8 bytes into a standard normalized set of utf8 bytes.

The typical storage for this normalizer is a compressed representation of a DARTS array and a null delimited normalization string.
Benjamin Trent 3 gadi atpakaļ
vecāks
revīzija
2d571a04ea

+ 209 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java

@@ -0,0 +1,209 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ *
+ * This Java port DoubleArray Trie Structure, precompiled charmap parsing and sentence piece normalizer was derived from
+ * Huggingface's spm-precompiled.
+ * project at https://github.com/huggingface/spm_precompiled
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import com.ibm.icu.text.BreakIterator;
+
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.apache.lucene.util.UnicodeUtil;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Locale;
+import java.util.Optional;
+import java.util.OptionalInt;
+
+/**
+ * This is custom normalizer logic purpose built to replicate the logic in DoubleArray Trie System (darts)
+ * object and the sentence piece normalizer.
+ *
+ * Links with further explanation of various parts of the algorithm:
+ *  - <a href="https://github.com/huggingface/spm_precompiled/blob/81b911a362adef3ad3cc6d5835d2980690dbb871/src/lib.rs">
+ *      huggingface lib
+ *      </a>
+ *  - <a href="https://github.com/google/sentencepiece/blob/bc53923a9147dc8ffa54034c8ed774de78cc4d39/third_party/darts_clone/darts.h#L469">
+ *      DARTS
+ *      </a>
+ *  - <a href="https://github.com/google/sentencepiece/blob/91809e5c70ed0e6364267a0f0fed66c144482ce4/src/normalizer.cc">SP normalizer</a>
+ */
+public class PrecompiledCharMapNormalizer {
+
+    static PrecompiledCharMapNormalizer fromBase64Str(String s) {
+        int offset = 0;
+        byte[] bytes = Base64.getDecoder().decode(s);
+        int trieSize = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
+        offset += 4;
+        int size = trieSize / 4;
+        int[] offsets = new int[size];
+        for (int i = 0; i < size; i++) {
+            offsets[i] = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
+            offset += 4;
+        }
+        String utf8Str = new String(bytes, offset, bytes.length - offset, StandardCharsets.UTF_8);
+        return new PrecompiledCharMapNormalizer(offsets, utf8Str);
+    }
+
+    // The offsets for each normalization piece. Used in DARTS algorithm to iterate and find appropriate section
+    // in normalizedStrUtf8Bytes
+    private final int[] offsets;
+    // The entire normalized bytes representations delimited by NULL
+    private final byte[] normalizedStrUtf8Bytes;
+    // Continually reused to copy a single char into utf8 bytes
+    private final byte[] reusableCharByteBuffer = new byte[4];
+
+    public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr) {
+        this.offsets = offsets;
+        this.normalizedStrUtf8Bytes = normalizedStr.getBytes(StandardCharsets.UTF_8);
+    }
+
+    boolean hasLeaf(int v) {
+        return ((v >>> 8) & 1) == 1;
+    }
+
+    int label(int v) {
+        return (v & ((1 << 31) | 0xFF));
+    }
+
+    int value(int v) {
+        return (v & ((1 << 31) - 1));
+    }
+
+    int offset(int v) {
+        return (v >>> 10) << ((v & (1 << 9)) >>> 6);
+    }
+
+    OptionalInt commonPrefix(byte[] inputBytes) {
+        return commonPrefix(inputBytes, 0, inputBytes.length);
+    }
+
+    /**
+     * This finds a common prefix position within the normalization byte string.
+     *
+     * Since the normalization string is NULL delimited, start at the returned index and continue until you hit the NULL byte. That is
+     * then the normalized string.
+     *
+     * The prefix search is done according to DoubleArray Trie System (DARTS).
+     *
+     * See:
+     * <a href="https://github.com/google/sentencepiece/blob/bc53923a9147dc8ffa54034c8ed774de78cc4d39/third_party/darts_clone/darts.h#L469">
+     *     DARTS
+     *     </a>
+     * @param inputBytes utf8 bytes to normalize
+     * @param offset offset position to start given the input
+     * @param len the length of bytes to consider
+     * @return The starting position in the normalization string of the normalized bytes, if found.
+     */
+    OptionalInt commonPrefix(byte[] inputBytes, int offset, int len) {
+        int pos = 0;
+        OptionalInt vs = OptionalInt.empty();
+        int v = offsets[pos];
+        pos ^= offset(v);
+        for (int i = offset; i < offset + len; i++) {
+            // bytes can be negative in java, handle it and require unsigned
+            int k = inputBytes[i];
+            if (k < 0) {
+                k += 256;
+            }
+            if (k == 0) {
+                break;
+            }
+            pos ^= k;
+            v = offsets[pos];
+            if (label(v) != k) {
+                return vs;
+            }
+            pos ^= offset(v);
+            if (hasLeaf(v)) {
+                vs = OptionalInt.of(value(offsets[pos]));
+                return vs;
+            }
+        }
+        return vs;
+    }
+
+    Optional<BytesRef> normalizePart(byte[] strBytes, int offset, int len) {
+        OptionalInt index = commonPrefix(strBytes, offset, len);
+        if (index.isEmpty()) {
+            return Optional.empty();
+        }
+        int firstIndex = index.getAsInt();
+        int secondIndex = firstIndex;
+        // Parsed normalized string has normalization sections partitioned by \0 (NULL) byte
+        while (secondIndex < normalizedStrUtf8Bytes.length && normalizedStrUtf8Bytes[secondIndex] != 0) {
+            secondIndex++;
+        }
+        if (secondIndex == firstIndex) {
+            return Optional.empty();
+        }
+        return Optional.of(new BytesRef(normalizedStrUtf8Bytes, firstIndex, secondIndex - firstIndex));
+    }
+
+    String normalize(String str) {
+        // We need to iterate actual Unicode graphemes (this includes surrogate pairs, etc.)
+        // I would much rather translate the entire input string text into utf-8 bytes, and then iterate to the appropriate
+        // break points from there. But, this seemed the easiest way for now
+        //
+        // Keep in mind, these break points aren't necessarily surrogate pairs, but also codepoints that contain a combining mark
+        BreakIterator b = BreakIterator.getCharacterInstance(Locale.ROOT);
+        b.setText(str);
+        int start = b.first();
+        // If we knew the utf-8 length ahead of time (and iterated over the bytes in the appropriate chunks)
+        // we could pre-populate the known length here.
+        BytesRefBuilder strBuilder = new BytesRefBuilder();
+        for (int end = b.next(); end != BreakIterator.DONE; start = end, end = b.next()) {
+            // TODO: It would be awesome if we could translate these starts and ends to byte positions, if we could performance would be
+            // dramatically improved
+            String unicodeStr = str.substring(start, end);
+            byte[] unicode = unicodeStr.getBytes(StandardCharsets.UTF_8);
+            // The trie only go up to a depth of 5 bytes.
+            // So even looking at it for graphemes (with combining, surrogate, etc.) that are 6+ bytes in length is useless.
+            if (unicode.length < 6) {
+                Optional<BytesRef> subStr = normalizePart(unicode, 0, unicode.length);
+                if (subStr.isPresent()) {
+                    strBuilder.append(subStr.get());
+                    continue;
+                }
+            }
+            int charIndex = 0;
+            int charByteIndex = 0;
+            char[] unicodeCharArray = unicodeStr.toCharArray();
+            for (char c : unicodeCharArray) {
+                Optional<BytesRef> subStr = normalizePart(unicode, charByteIndex, numUtf8Bytes(c));
+                if (subStr.isPresent()) {
+                    strBuilder.append(subStr.get());
+                } else {
+                    int numBytes = UnicodeUtil.UTF16toUTF8(unicodeCharArray, charIndex, 1, reusableCharByteBuffer);
+                    strBuilder.append(reusableCharByteBuffer, 0, numBytes);
+                }
+                charByteIndex += numUtf8Bytes(c);
+                ++charIndex;
+            }
+        }
+        return strBuilder.get().utf8ToString();
+    }
+
+    private static int numUtf8Bytes(int c) {
+        if (c < 128) {
+            return 1;
+        }
+        if (c < 2048) {
+            return 2;
+        }
+        if (c < 65536) {
+            return 3;
+        }
+        return 4;
+    }
+
+}

+ 39 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PreCompiledCharMap.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+
+record PreCompiledCharMap(String charMapStr) {
+    static ParseField PRECOMPILED_CHARSMAP = new ParseField("precompiled_charsmap");
+    static ConstructingObjectParser<PreCompiledCharMap, Void> PARSER = new ConstructingObjectParser<>(
+        "precompiled_charsmap_config",
+        true,
+        a -> new PreCompiledCharMap((String) a[0])
+    );
+    static {
+        PARSER.declareString(constructorArg(), PRECOMPILED_CHARSMAP);
+    }
+
+    static PreCompiledCharMap fromResource(String resourcePath) throws IOException {
+        try (
+            XContentParser parser = XContentType.JSON.xContent()
+                .createParser(XContentParserConfiguration.EMPTY, PreCompiledCharMap.class.getResourceAsStream(resourcePath))
+        ) {
+            return PreCompiledCharMap.PARSER.apply(parser, null);
+        }
+    }
+}

+ 55 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java

@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.OptionalInt;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+public class PrecompiledCharMapNormalizerTests extends ESTestCase {
+
+    public void testCommonPrefix() throws IOException {
+        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
+        OptionalInt local = parsed.commonPrefix("\uFB01".getBytes(StandardCharsets.UTF_8));
+        assertThat(local.isPresent(), is(true));
+        assertThat(local.getAsInt(), equalTo(2130));
+        String transformed = parsed.normalize("\uFB01");
+        assertThat(transformed, equalTo("fi"));
+        assertThat(parsed.normalize("𝔾"), equalTo("G"));
+        assertThat(parsed.normalize("\uD835\uDD60"), equalTo("o"));
+        assertThat(parsed.normalize("\u200D"), equalTo(" "));
+        assertThat(parsed.normalize("เขาไม่ได้พูดสักคำ"), equalTo("เขาไม\u0E48ได\u0E49พ\u0E39ดส\u0E31กค\u0E4Dา"));
+    }
+
+    public void testAdverseScenario() throws IOException {
+        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
+        assertThat(parsed.normalize("คำ"), equalTo("ค\u0e4dา"));
+    }
+
+    public void testAdverseScenarioHindi() throws IOException {
+        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
+        assertThat(parsed.normalize("ड़ी दुख"), equalTo("ड\u093cी द\u0941ख"));
+    }
+
+    public void testTwoCharUnicode() throws IOException {
+        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
+        assertThat(parsed.normalize("آ"), equalTo("آ"));
+    }
+
+    private static PrecompiledCharMapNormalizer loadTestCharMap() throws IOException {
+        PreCompiledCharMap map = PreCompiledCharMap.fromResource(
+            "/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json"
+        );
+        return PrecompiledCharMapNormalizer.fromBase64Str(map.charMapStr());
+    }
+}

Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 1 - 0
x-pack/plugin/ml/src/test/resources/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json


Daži faili netika attēloti, jo izmaiņu fails ir pārāk liels