
[ML] expand allowed NER labels to be any I-O-B tagged labels (#87091)

Named entity recognition (NER) is a special form of token classification. The specific kind of labelling we support is Inside-Outside-Beginning (IOB) tagging. These labels indicate whether a token is inside an entity (prefixed with `I-` or `I_`), at the beginning of an entity (`B-` or `B_`), or outside any entity (`O`).

Each valid token classification label must start with one of these prefixes or be the outside label `O`.
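
For example, the IOB labels for a short sentence might line up with its tokens as in the following sketch (the sentence and labels here are illustrative only, not taken from the commit):

```
// Illustrative sketch of IOB tagging: "O" marks tokens outside any entity,
// "B_*" marks the first token of an entity, "I_*" marks subsequent tokens.
public class IobTaggingExample {
    public static void main(String[] args) {
        String[] tokens = { "Many", "use", "Elasticsearch", "in", "New", "York" };
        String[] labels = { "O",    "O",   "B_ORG",         "O",  "B_LOC", "I_LOC" };
        for (int i = 0; i < tokens.length; i++) {
            System.out.println(tokens[i] + " -> " + labels[i]);
        }
    }
}
```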

Before this commit, we restricted the labels to a specific set:

```
O(Entity.NONE),      // Outside a named entity
B_MISC(Entity.MISC), // Beginning of a miscellaneous entity right after another miscellaneous entity
I_MISC(Entity.MISC), // Miscellaneous entity
B_PER(Entity.PER),   // Beginning of a person's name right after another person's name
I_PER(Entity.PER),   // Person's name
B_ORG(Entity.ORG),   // Beginning of an organization right after another organization
I_ORG(Entity.ORG),   // Organisation
B_LOC(Entity.LOC),   // Beginning of a location right after another location
I_LOC(Entity.LOC);   // Location
```

Now any entity type is allowed, as long as the label names adhere to the IOB tagging rules.
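
As a sketch of what this enables, a `NerConfig` can now be built with custom entity types. The constructor call below mirrors the one used in the tests in this commit, while the `SOFTWARE` labels and the `CustomLabelsSketch` wrapper class are hypothetical:

```
import java.util.List;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;

class CustomLabelsSketch {
    // Sketch only: custom entity types validate as long as every label is
    // I-O-B prefixed and the outside label "O" is present.
    static NerConfig customLabelConfig() {
        return new NerConfig(
            new VocabularyConfig("test-index"),                          // vocabulary index, as in the tests
            null,                                                        // default tokenization
            List.of("O", "B_SOFTWARE", "I_SOFTWARE", "B_LOC", "I_LOC"),  // hypothetical custom IOB labels
            null                                                         // default results field
        );
    }
    // A label such as "SOFTWARE" (no I-/B- prefix), or a label list missing "O",
    // is rejected with a bad request exception.
}
```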
Benjamin Trent committed 90d93a9309 3 years ago

+ 5 - 0
docs/changelog/87091.yaml

@@ -0,0 +1,5 @@
+pr: 87091
+summary: Expand allowed NER labels to be any I-O-B tagged labels
+area: Machine Learning
+type: enhancement
+issues: []

+ 25 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java

@@ -21,11 +21,20 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Objects;
 import java.util.Optional;
 
 public class NerConfig implements NlpConfig {
 
+    public static boolean validIOBTag(String label) {
+        return label.toUpperCase(Locale.ROOT).startsWith("I-")
+            || label.toUpperCase(Locale.ROOT).startsWith("B-")
+            || label.toUpperCase(Locale.ROOT).startsWith("I_")
+            || label.toUpperCase(Locale.ROOT).startsWith("B_")
+            || label.toUpperCase(Locale.ROOT).startsWith("O");
+    }
+
     public static final String NAME = "ner";
 
     public static NerConfig fromXContentStrict(XContentParser parser) {
@@ -80,6 +89,22 @@ public class NerConfig implements NlpConfig {
             .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
         this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
+        if (this.classificationLabels.isEmpty() == false) {
+            List<String> badLabels = this.classificationLabels.stream().filter(l -> validIOBTag(l) == false).toList();
+            if (badLabels.isEmpty() == false) {
+                throw ExceptionsHelper.badRequestException(
+                    "[{}] only allows IOB tokenization tagging for classification labels; provided {}",
+                    NAME,
+                    badLabels
+                );
+            }
+            if (this.classificationLabels.stream().noneMatch(l -> l.toUpperCase(Locale.ROOT).equals("O"))) {
+                throw ExceptionsHelper.badRequestException(
+                    "[{}] only allows IOB tokenization tagging for classification labels; missing outside label [O]",
+                    NAME
+                );
+            }
+        }
         this.resultsField = resultsField;
         if (this.tokenization.span != -1) {
             throw ExceptionsHelper.badRequestException(

+ 11 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java

@@ -13,7 +13,11 @@ import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Set;
 import java.util.function.Predicate;
+import java.util.stream.Stream;
 
 public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
 
@@ -48,6 +52,12 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
     }
 
     public static NerConfig createRandom() {
+        Set<String> randomClassificationLabels = new HashSet<>(
+            Stream.generate(() -> randomFrom("O", "B_PER", "I_PER", "B_ORG", "I_ORG", "B_LOC", "I_LOC", "B_CUSTOM", "I_CUSTOM"))
+                .limit(10)
+                .toList()
+        );
+        randomClassificationLabels.add("O");
         return new NerConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
             randomBoolean()
@@ -57,7 +67,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
                     MPNetTokenizationTests.createRandom(),
                     RobertaTokenizationTests.createRandom()
                 ),
-            randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : new ArrayList<>(randomClassificationLabels),
             randomBoolean() ? null : randomAlphaOfLength(5)
         );
     }

+ 40 - 64
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -9,8 +9,6 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
@@ -21,64 +19,58 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.EnumSet;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Optional;
+import java.util.Set;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
 public class NerProcessor extends NlpTask.Processor {
 
-    public enum Entity implements Writeable {
-        NONE,
-        MISC,
-        PER,
-        ORG,
-        LOC;
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeEnum(this);
-        }
-
-        @Override
-        public String toString() {
-            return name().toUpperCase(Locale.ROOT);
+    record IobTag(String tag, String entity) {
+        static IobTag fromTag(String tag) {
+            String entity = tag.toUpperCase(Locale.ROOT);
+            if (entity.startsWith("B-") || entity.startsWith("I-") || entity.startsWith("B_") || entity.startsWith("I_")) {
+                entity = entity.substring(2);
+                return new IobTag(tag, entity);
+            } else if (entity.equals("O")) {
+                return new IobTag(tag, entity);
+            } else {
+                throw new IllegalArgumentException("classification label [" + tag + "] is not an entity I-O-B tag.");
+            }
         }
-    }
-
-    // Inside-Outside-Beginning (IOB) tag
-    enum IobTag {
-        O(Entity.NONE),      // Outside a named entity
-        B_MISC(Entity.MISC), // Beginning of a miscellaneous entity right after another miscellaneous entity
-        I_MISC(Entity.MISC), // Miscellaneous entity
-        B_PER(Entity.PER),   // Beginning of a person's name right after another person's name
-        I_PER(Entity.PER),   // Person's name
-        B_ORG(Entity.ORG),   // Beginning of an organisation right after another organisation
-        I_ORG(Entity.ORG),   // Organisation
-        B_LOC(Entity.LOC),   // Beginning of a location right after another location
-        I_LOC(Entity.LOC);   // Location
 
-        private final Entity entity;
-
-        IobTag(Entity entity) {
-            this.entity = entity;
+        boolean isBeginning() {
+            return tag.startsWith("b") || tag.startsWith("B");
         }
 
-        Entity getEntity() {
-            return entity;
+        boolean isNone() {
+            return tag.equals("o") || tag.equals("O");
         }
 
-        boolean isBeginning() {
-            return name().toLowerCase(Locale.ROOT).startsWith("b");
+        @Override
+        public String toString() {
+            return tag;
         }
     }
 
+    static final IobTag[] DEFAULT_IOB_TAGS = new IobTag[] {
+        IobTag.fromTag("O"),       // Outside a named entity
+        IobTag.fromTag("B_MISC"),  // Beginning of a miscellaneous entity right after another miscellaneous entity
+        IobTag.fromTag("I_MISC"),  // Miscellaneous entity
+        IobTag.fromTag("B_PER"),   // Beginning of a person's name right after another person's name
+        IobTag.fromTag("I_PER"),   // Person's name
+        IobTag.fromTag("B_ORG"),   // Beginning of an organisation right after another organisation
+        IobTag.fromTag("I_ORG"),   // Organisation
+        IobTag.fromTag("B_LOC"),   // Beginning of a location right after another location
+        IobTag.fromTag("I_LOC")    // Location
+    };
+
     private final NlpTask.RequestBuilder requestBuilder;
     private final IobTag[] iobMap;
     private final String resultsField;
@@ -102,10 +94,10 @@ public class NerProcessor extends NlpTask.Processor {
         }
 
         ValidationException ve = new ValidationException();
-        EnumSet<IobTag> tags = EnumSet.noneOf(IobTag.class);
+        Set<IobTag> tags = new HashSet<>();
         for (String label : classificationLabels) {
             try {
-                IobTag iobTag = IobTag.valueOf(label);
+                IobTag iobTag = IobTag.fromTag(label);
                 if (tags.contains(iobTag)) {
                     ve.addValidationError("the classification label [" + label + "] is duplicated in the list " + classificationLabels);
                 }
@@ -114,23 +106,20 @@ public class NerProcessor extends NlpTask.Processor {
                 ve.addValidationError("classification label [" + label + "] is not an entity I-O-B tag.");
             }
         }
-
         if (ve.validationErrors().isEmpty() == false) {
-            ve.addValidationError("Valid entity I-O-B tags are " + Arrays.toString(IobTag.values()));
             throw ve;
         }
     }
 
     static IobTag[] buildIobMap(List<String> classificationLabels) {
         if (classificationLabels == null || classificationLabels.isEmpty()) {
-            return IobTag.values();
+            return DEFAULT_IOB_TAGS;
         }
 
         IobTag[] map = new IobTag[classificationLabels.size()];
         for (int i = 0; i < classificationLabels.size(); i++) {
-            map[i] = IobTag.valueOf(classificationLabels.get(i));
+            map[i] = IobTag.fromTag(classificationLabels.get(i));
         }
-
         return map;
     }
 
@@ -281,7 +270,7 @@ public class NerProcessor extends NlpTask.Processor {
             int startTokenIndex = 0;
             while (startTokenIndex < tokens.size()) {
                 TaggedToken token = tokens.get(startTokenIndex);
-                if (token.tag.getEntity() == Entity.NONE) {
+                if (token.tag.isNone()) {
                     startTokenIndex++;
                     continue;
                 }
@@ -289,7 +278,7 @@ public class NerProcessor extends NlpTask.Processor {
                 double scoreSum = token.score;
                 while (endTokenIndex < tokens.size()) {
                     TaggedToken endToken = tokens.get(endTokenIndex);
-                    if (endToken.tag.isBeginning() || endToken.tag.getEntity() != token.tag.getEntity()) {
+                    if (endToken.tag.isBeginning() || endToken.tag.entity().equals(token.tag.entity()) == false) {
                         break;
                     }
                     scoreSum += endToken.score;
@@ -300,13 +289,7 @@ public class NerProcessor extends NlpTask.Processor {
                 int endPos = tokens.get(endTokenIndex - 1).token.endOffset();
                 String entity = inputSeq.substring(startPos, endPos);
                 entities.add(
-                    new NerResults.EntityGroup(
-                        entity,
-                        token.tag.getEntity().toString(),
-                        scoreSum / (endTokenIndex - startTokenIndex),
-                        startPos,
-                        endPos
-                    )
+                    new NerResults.EntityGroup(entity, token.tag.entity(), scoreSum / (endTokenIndex - startTokenIndex), startPos, endPos)
                 );
                 startTokenIndex = endTokenIndex;
             }
@@ -317,14 +300,7 @@ public class NerProcessor extends NlpTask.Processor {
         record TaggedToken(DelimitedToken token, IobTag tag, double score) {
             @Override
             public String toString() {
-                return new StringBuilder("{").append("token:")
-                    .append(token)
-                    .append(", ")
-                    .append(tag)
-                    .append(", ")
-                    .append(score)
-                    .append("}")
-                    .toString();
+                return "{" + "token:" + token + ", " + tag + ", " + score + "}";
             }
         }
     }

+ 149 - 133
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -23,7 +23,6 @@ import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResu
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
 
@@ -38,34 +37,24 @@ import static org.mockito.Mockito.mock;
 
 public class NerProcessorTests extends ESTestCase {
 
-    public void testBuildIobMap_WithDefault() {
-        NerProcessor.IobTag[] map = NerProcessor.buildIobMap(randomBoolean() ? null : Collections.emptyList());
-        for (int i = 0; i < map.length; i++) {
-            assertEquals(i, map[i].ordinal());
-        }
-    }
-
     public void testBuildIobMap_Reordered() {
         NerProcessor.IobTag[] tags = new NerProcessor.IobTag[] {
-            NerProcessor.IobTag.I_MISC,
-            NerProcessor.IobTag.O,
-            NerProcessor.IobTag.B_MISC,
-            NerProcessor.IobTag.I_PER };
+            NerProcessor.IobTag.fromTag("I_MISC"),
+            NerProcessor.IobTag.fromTag("O"),
+            NerProcessor.IobTag.fromTag("B_MISC"),
+            NerProcessor.IobTag.fromTag("I_PER") };
 
         List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
         NerProcessor.IobTag[] map = NerProcessor.buildIobMap(classLabels);
-        for (int i = 0; i < map.length; i++) {
-            assertNotEquals(i, map[i].ordinal());
-        }
         assertArrayEquals(tags, map);
     }
 
     public void testValidate_DuplicateLabels() {
         NerProcessor.IobTag[] tags = new NerProcessor.IobTag[] {
-            NerProcessor.IobTag.I_MISC,
-            NerProcessor.IobTag.B_MISC,
-            NerProcessor.IobTag.B_MISC,
-            NerProcessor.IobTag.O, };
+            NerProcessor.IobTag.fromTag("I_MISC"),
+            NerProcessor.IobTag.fromTag("B_MISC"),
+            NerProcessor.IobTag.fromTag("B_MISC"),
+            NerProcessor.IobTag.fromTag("O"), };
 
         List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
         NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels, null);
@@ -77,20 +66,8 @@ public class NerProcessorTests extends ESTestCase {
         );
     }
 
-    public void testValidate_NotAEntityLabel() {
-        List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
-        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels, null);
-
-        ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
-        assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));
-        assertThat(
-            ve.getMessage(),
-            containsString("Valid entity I-O-B tags are [O, B_MISC, I_MISC, B_PER, I_PER, B_ORG, I_ORG, B_LOC, I_LOC]")
-        );
-    }
-
     public void testProcessResults_GivenNoTokens() {
-        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false);
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.DEFAULT_IOB_TAGS, null, false);
         TokenizationResult tokenization = tokenize(List.of(BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN), "");
 
         var e = expectThrows(
@@ -101,88 +78,124 @@ public class NerProcessorTests extends ESTestCase {
     }
 
     public void testProcessResultsWithSpecialTokens() {
-        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, true);
-        BertTokenizer tokenizer = BertTokenizer.builder(
-            List.of(
-                "el",
-                "##astic",
-                "##search",
-                "many",
-                "use",
-                "in",
-                "london",
-                BertTokenizer.PAD_TOKEN,
-                BertTokenizer.UNKNOWN_TOKEN,
-                BertTokenizer.SEPARATOR_TOKEN,
-                BertTokenizer.CLASS_TOKEN
-            ),
-            new BertTokenization(true, true, null, Tokenization.Truncate.NONE, -1)
-        ).build();
-        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
-            List.of(tokenizer.tokenize("Many use Elasticsearch in London", Tokenization.Truncate.NONE, -1, 1).get(0))
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.DEFAULT_IOB_TAGS, null, true);
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(
+                List.of(
+                    "el",
+                    "##astic",
+                    "##search",
+                    "many",
+                    "use",
+                    "in",
+                    "london",
+                    BertTokenizer.PAD_TOKEN,
+                    BertTokenizer.UNKNOWN_TOKEN,
+                    BertTokenizer.SEPARATOR_TOKEN,
+                    BertTokenizer.CLASS_TOKEN
+                ),
+                new BertTokenization(true, true, null, Tokenization.Truncate.NONE, -1)
+            ).build()
+        ) {
+            TokenizationResult tokenization = tokenizer.buildTokenizationResult(
+                List.of(tokenizer.tokenize("Many use Elasticsearch in London", Tokenization.Truncate.NONE, -1, 1).get(0))
+            );
+
+            double[][][] scores = {
+                {
+                    { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // cls
+                    { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
+                    { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
+                    { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
+                    { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
+                    { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
+                    { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
+                    { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
+                    { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
+                } };
+            NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
+
+            assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
+            assertThat(result.getEntityGroups().size(), equalTo(2));
+            assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
+            assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("ORG"));
+            assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
+            assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
+        }
+    }
+
+    public void testProcessResults() {
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.DEFAULT_IOB_TAGS, null, true);
+        TokenizationResult tokenization = tokenize(
+            Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london", BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN),
+            "Many use Elasticsearch in London"
         );
 
         double[][][] scores = {
             {
-                { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // cls
                 { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
                 { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
                 { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
                 { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
-                { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
-                { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
+                { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
             } };
         NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
 
         assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
         assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
-        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
+        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("ORG"));
         assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
-        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
+        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
     }
 
-    public void testProcessResults() {
-        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, true);
+    public void testProcessResults_withIobMap() {
+
+        NerProcessor.IobTag[] iobMap = new NerProcessor.IobTag[] {
+            NerProcessor.IobTag.fromTag("B_LOC"),
+            NerProcessor.IobTag.fromTag("I_LOC"),
+            NerProcessor.IobTag.fromTag("B_MISC"),
+            NerProcessor.IobTag.fromTag("I_MISC"),
+            NerProcessor.IobTag.fromTag("B_PER"),
+            NerProcessor.IobTag.fromTag("I_PER"),
+            NerProcessor.IobTag.fromTag("B_ORG"),
+            NerProcessor.IobTag.fromTag("I_ORG"),
+            NerProcessor.IobTag.fromTag("O") };
+
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(iobMap, null, true);
         TokenizationResult tokenization = tokenize(
-            Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london", BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN),
-            "Many use Elasticsearch in London"
+            Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london", BertTokenizer.UNKNOWN_TOKEN, BertTokenizer.PAD_TOKEN),
+            "Elasticsearch in London"
         );
 
         double[][][] scores = {
             {
-                { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
-                { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
-                { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
+                { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0 }, // el
                 { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
-                { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
-                { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
+                { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
+                { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
             } };
         NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
 
-        assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
+        assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
         assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
-        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
+        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("ORG"));
         assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
-        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
+        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
     }
 
-    public void testProcessResults_withIobMap() {
+    public void testProcessResults_withCustomIobMap() {
 
         NerProcessor.IobTag[] iobMap = new NerProcessor.IobTag[] {
-            NerProcessor.IobTag.B_LOC,
-            NerProcessor.IobTag.I_LOC,
-            NerProcessor.IobTag.B_MISC,
-            NerProcessor.IobTag.I_MISC,
-            NerProcessor.IobTag.B_PER,
-            NerProcessor.IobTag.I_PER,
-            NerProcessor.IobTag.B_ORG,
-            NerProcessor.IobTag.I_ORG,
-            NerProcessor.IobTag.O };
+            NerProcessor.IobTag.fromTag("B_LOC"),
+            NerProcessor.IobTag.fromTag("I_LOC"),
+            NerProcessor.IobTag.fromTag("B_SOFTWARE"),
+            NerProcessor.IobTag.fromTag("I_SOFTWARE"),
+            NerProcessor.IobTag.fromTag("O") };
 
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(iobMap, null, true);
         TokenizationResult tokenization = tokenize(
@@ -192,20 +205,20 @@ public class NerProcessorTests extends ESTestCase {
 
         double[][][] scores = {
             {
-                { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0 }, // el
-                { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
-                { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
-                { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
-                { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
+                { 0.01, 0.01, 7, 3, 0 }, // el
+                { 0.01, 0.01, 0, 0, 0 }, // ##astic
+                { 0, 0, 0, 0, 0 }, // ##search
+                { 0, 0, 0, 0, 5 }, // in
+                { 6, 0, 0, 0, 0 } // london
             } };
         NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
 
-        assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
+        assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
         assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
-        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
+        assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("SOFTWARE"));
         assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
-        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
+        assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
     }
 
     public void testGroupTaggedTokens() throws IOException {
@@ -215,18 +228,18 @@ public class NerProcessorTests extends ESTestCase {
 
         List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
         int i = 0;
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_LOC, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.B_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_LOC"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("B_ORG"), 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
@@ -243,8 +256,8 @@ public class NerProcessorTests extends ESTestCase {
         List<DelimitedToken> tokens = basicTokenize(randomBoolean(), randomBoolean(), List.of(), input);
 
         List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(0), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(1), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(0), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(1), NerProcessor.IobTag.fromTag("O"), 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, is(empty()));
@@ -256,13 +269,13 @@ public class NerProcessorTests extends ESTestCase {
 
         List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
         int i = 0;
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("O"), 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
@@ -280,12 +293,12 @@ public class NerProcessorTests extends ESTestCase {
 
         List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
         int i = 0;
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.B_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("B_ORG"), 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
@@ -302,21 +315,21 @@ public class NerProcessorTests extends ESTestCase {
 
         List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
         int i = 0;
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_ORG"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_ORG"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_ORG"), 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("O"), 1.0));
         assertEquals(tokens.size(), taggedTokens.size());
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
@@ -351,10 +364,13 @@ public class NerProcessorTests extends ESTestCase {
     }
 
     private static TokenizationResult tokenize(List<String> vocab, String input) {
-        BertTokenizer tokenizer = BertTokenizer.builder(vocab, new BertTokenization(true, false, null, Tokenization.Truncate.NONE, -1))
-            .setDoLowerCase(true)
-            .setWithSpecialTokens(false)
-            .build();
-        return tokenizer.buildTokenizationResult(tokenizer.tokenize(input, Tokenization.Truncate.NONE, -1, 0));
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(vocab, new BertTokenization(true, false, null, Tokenization.Truncate.NONE, -1))
+                .setDoLowerCase(true)
+                .setWithSpecialTokens(false)
+                .build()
+        ) {
+            return tokenizer.buildTokenizationResult(tokenizer.tokenize(input, Tokenization.Truncate.NONE, -1, 0));
+        }
     }
 }