
[ML] fix NER token grouping when special tokens are used (#84042)

Bug introduced by #83835.

This switches our token tagging back to taking each token's position in the input into account when reconstituting and tagging tokens for NER.
Benjamin Trent, 3 years ago
commit ed40e1e0c2

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -229,7 +229,7 @@ public class NerProcessor extends NlpTask.Processor {
             int startTokenIndex = 0;
             int numSpecialTokens = 0;
             while (startTokenIndex < tokenization.tokenIds().length) {
-                int inputMapping = tokenization.tokenIds()[startTokenIndex];
+                int inputMapping = tokenization.tokenMap()[startTokenIndex];
                 if (inputMapping < 0) {
                     // This token does not map to a token in the input (special tokens)
                     startTokenIndex++;
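
For illustration, a minimal standalone sketch of why the one-line change above matters (a hypothetical class with made-up vocabulary ids, not the real NerProcessor or TokenizationResult types): tokenIds() holds vocabulary ids, while tokenMap() holds, for each token, the index of the input word it came from, with -1 for special tokens such as [CLS] and [SEP]. Vocabulary ids are never negative, so the old check `inputMapping < 0` could never skip special tokens; only the token map can be used to skip them and regroup word pieces into input words.

    // Hypothetical sketch, not the actual Elasticsearch classes.
    public class TokenMapSketch {
        public static void main(String[] args) {
            // "Many use Elasticsearch in London" as word pieces plus special tokens.
            String[] tokens = { "[CLS]", "many", "use", "el", "##astic", "##search", "in", "london", "[SEP]" };
            // Made-up vocabulary ids: always non-negative, even for special tokens.
            int[] tokenIds = { 101, 7, 8, 3, 4, 5, 9, 10, 102 };
            // Index of the source word for each token; -1 marks special tokens.
            int[] tokenMap = { -1, 0, 1, 2, 2, 2, 3, 4, -1 };

            int startTokenIndex = 0;
            while (startTokenIndex < tokenMap.length) {
                int inputMapping = tokenMap[startTokenIndex]; // the fix: read the map, not the vocabulary ids
                if (inputMapping < 0) {
                    // Special token ([CLS]/[SEP]); it has no word in the input, so skip it.
                    startTokenIndex++;
                    continue;
                }
                int endTokenIndex = startTokenIndex;
                while (endTokenIndex + 1 < tokenMap.length && tokenMap[endTokenIndex + 1] == inputMapping) {
                    endTokenIndex++; // extend over word pieces of the same input word
                }
                System.out.println("word " + inputMapping + " <- tokens " + startTokenIndex + ".." + endTokenIndex
                    + " starting at \"" + tokens[startTokenIndex] + "\"");
                startTokenIndex = endTokenIndex + 1;
            }
            // Reading tokenIds[startTokenIndex] instead (the pre-fix behaviour) would compare a
            // vocabulary id against 0, so special tokens would never be skipped and the grouping
            // of "el ##astic ##search" back into "Elasticsearch" would be misaligned.
        }
    }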

+ 44 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -100,6 +100,50 @@ public class NerProcessorTests extends ESTestCase {
         assertThat(e, instanceOf(ElasticsearchStatusException.class));
     }
 
+    public void testProcessResultsWithSpecialTokens() {
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, true);
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            List.of(
+                "el",
+                "##astic",
+                "##search",
+                "many",
+                "use",
+                "in",
+                "london",
+                BertTokenizer.PAD_TOKEN,
+                BertTokenizer.UNKNOWN_TOKEN,
+                BertTokenizer.SEPARATOR_TOKEN,
+                BertTokenizer.CLASS_TOKEN
+            ),
+            new BertTokenization(true, true, null, Tokenization.Truncate.NONE)
+        ).build();
+        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
+            List.of(tokenizer.tokenize("Many use Elasticsearch in London", Tokenization.Truncate.NONE))
+        );
+
+        double[][][] scores = {
+            {
+                { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // cls
+                { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
+                { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
+                { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
+                { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
+                { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
+                { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
+                { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
+                { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
+            } };
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
+
+        assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
+        assertThat(result.getEntityGroups().size(), equalTo(2));
+        assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
+        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
+        assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
+        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
+    }
+
     public void testProcessResults() {
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, true);
         TokenizationResult tokenization = tokenize(