|
@@ -23,7 +23,6 @@ import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResu
|
|
|
import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
-import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
@@ -38,34 +37,24 @@ import static org.mockito.Mockito.mock;
|
|
|
|
|
|
public class NerProcessorTests extends ESTestCase {
|
|
|
|
|
|
- public void testBuildIobMap_WithDefault() {
|
|
|
- NerProcessor.IobTag[] map = NerProcessor.buildIobMap(randomBoolean() ? null : Collections.emptyList());
|
|
|
- for (int i = 0; i < map.length; i++) {
|
|
|
- assertEquals(i, map[i].ordinal());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
public void testBuildIobMap_Reordered() {
|
|
|
NerProcessor.IobTag[] tags = new NerProcessor.IobTag[] {
|
|
|
- NerProcessor.IobTag.I_MISC,
|
|
|
- NerProcessor.IobTag.O,
|
|
|
- NerProcessor.IobTag.B_MISC,
|
|
|
- NerProcessor.IobTag.I_PER };
|
|
|
+ NerProcessor.IobTag.fromTag("I_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("O"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_PER") };
|
|
|
|
|
|
List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
|
|
|
NerProcessor.IobTag[] map = NerProcessor.buildIobMap(classLabels);
|
|
|
- for (int i = 0; i < map.length; i++) {
|
|
|
- assertNotEquals(i, map[i].ordinal());
|
|
|
- }
|
|
|
assertArrayEquals(tags, map);
|
|
|
}
|
|
|
|
|
|
public void testValidate_DuplicateLabels() {
|
|
|
NerProcessor.IobTag[] tags = new NerProcessor.IobTag[] {
|
|
|
- NerProcessor.IobTag.I_MISC,
|
|
|
- NerProcessor.IobTag.B_MISC,
|
|
|
- NerProcessor.IobTag.B_MISC,
|
|
|
- NerProcessor.IobTag.O, };
|
|
|
+ NerProcessor.IobTag.fromTag("I_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("O"), };
|
|
|
|
|
|
List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
|
|
|
NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels, null);
|
|
@@ -77,20 +66,8 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- public void testValidate_NotAEntityLabel() {
|
|
|
- List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
|
|
|
- NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels, null);
|
|
|
-
|
|
|
- ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
|
|
|
- assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));
|
|
|
- assertThat(
|
|
|
- ve.getMessage(),
|
|
|
- containsString("Valid entity I-O-B tags are [O, B_MISC, I_MISC, B_PER, I_PER, B_ORG, I_ORG, B_LOC, I_LOC]")
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
public void testProcessResults_GivenNoTokens() {
|
|
|
- NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false);
|
|
|
+ NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.DEFAULT_IOB_TAGS, null, false);
|
|
|
TokenizationResult tokenization = tokenize(List.of(BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN), "");
|
|
|
|
|
|
var e = expectThrows(
|
|
@@ -101,88 +78,124 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
public void testProcessResultsWithSpecialTokens() {
|
|
|
- NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, true);
|
|
|
- BertTokenizer tokenizer = BertTokenizer.builder(
|
|
|
- List.of(
|
|
|
- "el",
|
|
|
- "##astic",
|
|
|
- "##search",
|
|
|
- "many",
|
|
|
- "use",
|
|
|
- "in",
|
|
|
- "london",
|
|
|
- BertTokenizer.PAD_TOKEN,
|
|
|
- BertTokenizer.UNKNOWN_TOKEN,
|
|
|
- BertTokenizer.SEPARATOR_TOKEN,
|
|
|
- BertTokenizer.CLASS_TOKEN
|
|
|
- ),
|
|
|
- new BertTokenization(true, true, null, Tokenization.Truncate.NONE, -1)
|
|
|
- ).build();
|
|
|
- TokenizationResult tokenization = tokenizer.buildTokenizationResult(
|
|
|
- List.of(tokenizer.tokenize("Many use Elasticsearch in London", Tokenization.Truncate.NONE, -1, 1).get(0))
|
|
|
+ NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.DEFAULT_IOB_TAGS, null, true);
|
|
|
+ try (
|
|
|
+ BertTokenizer tokenizer = BertTokenizer.builder(
|
|
|
+ List.of(
|
|
|
+ "el",
|
|
|
+ "##astic",
|
|
|
+ "##search",
|
|
|
+ "many",
|
|
|
+ "use",
|
|
|
+ "in",
|
|
|
+ "london",
|
|
|
+ BertTokenizer.PAD_TOKEN,
|
|
|
+ BertTokenizer.UNKNOWN_TOKEN,
|
|
|
+ BertTokenizer.SEPARATOR_TOKEN,
|
|
|
+ BertTokenizer.CLASS_TOKEN
|
|
|
+ ),
|
|
|
+ new BertTokenization(true, true, null, Tokenization.Truncate.NONE, -1)
|
|
|
+ ).build()
|
|
|
+ ) {
|
|
|
+ TokenizationResult tokenization = tokenizer.buildTokenizationResult(
|
|
|
+ List.of(tokenizer.tokenize("Many use Elasticsearch in London", Tokenization.Truncate.NONE, -1, 1).get(0))
|
|
|
+ );
|
|
|
+
|
|
|
+ double[][][] scores = {
|
|
|
+ {
|
|
|
+ { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // cls
|
|
|
+ { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
|
|
|
+ { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
|
|
|
+ { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
|
|
|
+ { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
|
|
|
+ { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
|
|
|
+ { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
|
|
|
+ { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
|
|
|
+ { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
|
|
|
+ } };
|
|
|
+ NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
|
|
|
+
|
|
|
+ assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
|
|
|
+ assertThat(result.getEntityGroups().size(), equalTo(2));
|
|
|
+ assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
|
|
|
+ assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("ORG"));
|
|
|
+ assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
|
|
|
+ assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testProcessResults() {
|
|
|
+ NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.DEFAULT_IOB_TAGS, null, true);
|
|
|
+ TokenizationResult tokenization = tokenize(
|
|
|
+ Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london", BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN),
|
|
|
+ "Many use Elasticsearch in London"
|
|
|
);
|
|
|
|
|
|
double[][][] scores = {
|
|
|
{
|
|
|
- { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // cls
|
|
|
{ 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
|
|
|
{ 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
|
|
|
{ 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
|
|
|
{ 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
|
|
|
{ 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
|
|
|
{ 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
|
|
|
- { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
|
|
|
- { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
|
|
|
+ { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
|
|
|
} };
|
|
|
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
|
|
|
|
|
|
assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
|
|
|
assertThat(result.getEntityGroups().size(), equalTo(2));
|
|
|
assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
|
|
|
- assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
|
|
|
+ assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("ORG"));
|
|
|
assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
|
|
|
- assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
|
|
|
+ assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
|
|
|
}
|
|
|
|
|
|
- public void testProcessResults() {
|
|
|
- NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, true);
|
|
|
+ public void testProcessResults_withIobMap() {
|
|
|
+
|
|
|
+ NerProcessor.IobTag[] iobMap = new NerProcessor.IobTag[] {
|
|
|
+ NerProcessor.IobTag.fromTag("B_LOC"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_LOC"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_MISC"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_PER"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_PER"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_ORG"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_ORG"),
|
|
|
+ NerProcessor.IobTag.fromTag("O") };
|
|
|
+
|
|
|
+ NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(iobMap, null, true);
|
|
|
TokenizationResult tokenization = tokenize(
|
|
|
- Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london", BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN),
|
|
|
- "Many use Elasticsearch in London"
|
|
|
+ Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london", BertTokenizer.UNKNOWN_TOKEN, BertTokenizer.PAD_TOKEN),
|
|
|
+ "Elasticsearch in London"
|
|
|
);
|
|
|
|
|
|
double[][][] scores = {
|
|
|
{
|
|
|
- { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // many
|
|
|
- { 7, 0, 0, 0, 0, 0, 0, 0, 0 }, // use
|
|
|
- { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0 }, // el
|
|
|
+ { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0 }, // el
|
|
|
{ 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
|
|
|
{ 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
|
|
|
- { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
|
|
|
- { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
|
|
|
+ { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
|
|
|
+ { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
|
|
|
} };
|
|
|
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
|
|
|
|
|
|
- assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
|
|
|
+ assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
|
|
|
assertThat(result.getEntityGroups().size(), equalTo(2));
|
|
|
assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
|
|
|
- assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
|
|
|
+ assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("ORG"));
|
|
|
assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
|
|
|
- assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
|
|
|
+ assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
|
|
|
}
|
|
|
|
|
|
- public void testProcessResults_withIobMap() {
|
|
|
+ public void testProcessResults_withCustomIobMap() {
|
|
|
|
|
|
NerProcessor.IobTag[] iobMap = new NerProcessor.IobTag[] {
|
|
|
- NerProcessor.IobTag.B_LOC,
|
|
|
- NerProcessor.IobTag.I_LOC,
|
|
|
- NerProcessor.IobTag.B_MISC,
|
|
|
- NerProcessor.IobTag.I_MISC,
|
|
|
- NerProcessor.IobTag.B_PER,
|
|
|
- NerProcessor.IobTag.I_PER,
|
|
|
- NerProcessor.IobTag.B_ORG,
|
|
|
- NerProcessor.IobTag.I_ORG,
|
|
|
- NerProcessor.IobTag.O };
|
|
|
+ NerProcessor.IobTag.fromTag("B_LOC"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_LOC"),
|
|
|
+ NerProcessor.IobTag.fromTag("B_SOFTWARE"),
|
|
|
+ NerProcessor.IobTag.fromTag("I_SOFTWARE"),
|
|
|
+ NerProcessor.IobTag.fromTag("O") };
|
|
|
|
|
|
NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(iobMap, null, true);
|
|
|
TokenizationResult tokenization = tokenize(
|
|
@@ -192,20 +205,20 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
|
|
|
double[][][] scores = {
|
|
|
{
|
|
|
- { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0 }, // el
|
|
|
- { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0 }, // ##astic
|
|
|
- { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // ##search
|
|
|
- { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
|
|
|
- { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
|
|
|
+ { 0.01, 0.01, 7, 3, 0 }, // el
|
|
|
+ { 0.01, 0.01, 0, 0, 0 }, // ##astic
|
|
|
+ { 0, 0, 0, 0, 0 }, // ##search
|
|
|
+ { 0, 0, 0, 0, 5 }, // in
|
|
|
+ { 6, 0, 0, 0, 0 } // london
|
|
|
} };
|
|
|
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L));
|
|
|
|
|
|
- assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
|
|
|
+ assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)"));
|
|
|
assertThat(result.getEntityGroups().size(), equalTo(2));
|
|
|
assertThat(result.getEntityGroups().get(0).getEntity(), equalTo("elasticsearch"));
|
|
|
- assertThat(result.getEntityGroups().get(0).getClassName(), equalTo(NerProcessor.Entity.ORG.toString()));
|
|
|
+ assertThat(result.getEntityGroups().get(0).getClassName(), equalTo("SOFTWARE"));
|
|
|
assertThat(result.getEntityGroups().get(1).getEntity(), equalTo("london"));
|
|
|
- assertThat(result.getEntityGroups().get(1).getClassName(), equalTo(NerProcessor.Entity.LOC.toString()));
|
|
|
+ assertThat(result.getEntityGroups().get(1).getClassName(), equalTo("LOC"));
|
|
|
}
|
|
|
|
|
|
public void testGroupTaggedTokens() throws IOException {
|
|
@@ -215,18 +228,18 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
|
|
|
List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
|
|
|
int i = 0;
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_LOC, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.B_ORG, 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_LOC"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("B_ORG"), 1.0));
|
|
|
|
|
|
List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
|
|
|
assertThat(entityGroups, hasSize(3));
|
|
@@ -243,8 +256,8 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
List<DelimitedToken> tokens = basicTokenize(randomBoolean(), randomBoolean(), List.of(), input);
|
|
|
|
|
|
List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(0), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(1), NerProcessor.IobTag.O, 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(0), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(1), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
|
|
|
List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
|
|
|
assertThat(entityGroups, is(empty()));
|
|
@@ -256,13 +269,13 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
|
|
|
List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
|
|
|
int i = 0;
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.O, 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
|
|
|
List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
|
|
|
assertThat(entityGroups, hasSize(3));
|
|
@@ -280,12 +293,12 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
|
|
|
List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
|
|
|
int i = 0;
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.B_ORG, 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("B_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("B_ORG"), 1.0));
|
|
|
|
|
|
List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
|
|
|
assertThat(entityGroups, hasSize(3));
|
|
@@ -302,21 +315,21 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
|
|
|
List<NerProcessor.NerResultProcessor.TaggedToken> taggedTokens = new ArrayList<>();
|
|
|
int i = 0;
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
|
|
|
- taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.O, 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_PER"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_ORG"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_ORG"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.fromTag("I_ORG"), 1.0));
|
|
|
+ taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.fromTag("O"), 1.0));
|
|
|
assertEquals(tokens.size(), taggedTokens.size());
|
|
|
|
|
|
List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
|
|
@@ -351,10 +364,13 @@ public class NerProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
private static TokenizationResult tokenize(List<String> vocab, String input) {
|
|
|
- BertTokenizer tokenizer = BertTokenizer.builder(vocab, new BertTokenization(true, false, null, Tokenization.Truncate.NONE, -1))
|
|
|
- .setDoLowerCase(true)
|
|
|
- .setWithSpecialTokens(false)
|
|
|
- .build();
|
|
|
- return tokenizer.buildTokenizationResult(tokenizer.tokenize(input, Tokenization.Truncate.NONE, -1, 0));
|
|
|
+ try (
|
|
|
+ BertTokenizer tokenizer = BertTokenizer.builder(vocab, new BertTokenization(true, false, null, Tokenization.Truncate.NONE, -1))
|
|
|
+ .setDoLowerCase(true)
|
|
|
+ .setWithSpecialTokens(false)
|
|
|
+ .build()
|
|
|
+ ) {
|
|
|
+ return tokenizer.buildTokenizationResult(tokenizer.tokenize(input, Tokenization.Truncate.NONE, -1, 0));
|
|
|
+ }
|
|
|
}
|
|
|
}
|