|
|
@@ -1,101 +0,0 @@
|
|
|
-/*
|
|
|
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
|
- * or more contributor license agreements. Licensed under the Elastic License
|
|
|
- * 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
- * 2.0.
|
|
|
- */
|
|
|
-
|
|
|
-package org.elasticsearch.xpack.ml.inference.nlp;
|
|
|
-
|
|
|
-import org.elasticsearch.ElasticsearchStatusException;
|
|
|
-import org.elasticsearch.common.bytes.BytesReference;
|
|
|
-import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
-import org.elasticsearch.common.xcontent.XContentType;
|
|
|
-import org.elasticsearch.test.ESTestCase;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
|
|
|
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
|
|
|
-
|
|
|
-import java.io.IOException;
|
|
|
-import java.util.Arrays;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
-
|
|
|
-import static org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilderTests.nthListItemFromMap;
|
|
|
-import static org.hamcrest.Matchers.containsString;
|
|
|
-import static org.hamcrest.Matchers.hasSize;
|
|
|
-
|
|
|
-public class DistilBertRequestBuilderTests extends ESTestCase {
|
|
|
-
|
|
|
- public void testBuildRequest() throws IOException {
|
|
|
- BertTokenizer tokenizer = BertTokenizer.builder(
|
|
|
- Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
|
|
|
- new DistilBertTokenization(null, null, 512)
|
|
|
- ).build();
|
|
|
-
|
|
|
- DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
|
|
|
- BytesReference bytesReference = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1").processInput;
|
|
|
-
|
|
|
- Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
|
|
|
-
|
|
|
- assertThat(jsonDocAsMap.keySet(), hasSize(3));
|
|
|
- assertEquals("request1", jsonDocAsMap.get("request_id"));
|
|
|
- assertEquals(Arrays.asList(3, 0, 1, 2, 4), nthListItemFromMap("tokens", 0, jsonDocAsMap));
|
|
|
- assertEquals(Arrays.asList(1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
|
|
|
- }
|
|
|
-
|
|
|
- public void testInputTooLarge() throws IOException {
|
|
|
- BertTokenizer tokenizer = BertTokenizer.builder(
|
|
|
- Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
|
|
|
- new DistilBertTokenization(null, null, 5)
|
|
|
- ).build();
|
|
|
- {
|
|
|
- DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
|
|
|
- ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> requestBuilder.buildRequest(List.of("Elasticsearch fun Elasticsearch fun Elasticsearch fun"), "request1"));
|
|
|
-
|
|
|
- assertThat(e.getMessage(),
|
|
|
- containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"));
|
|
|
- }
|
|
|
- {
|
|
|
- DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
|
|
|
- // input will become 3 tokens + the Class and Separator token = 5 which is
|
|
|
- // our max sequence length
|
|
|
- requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1");
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- public void testBatchWithPadding() throws IOException {
|
|
|
- BertTokenizer tokenizer = BertTokenizer.builder(
|
|
|
- Arrays.asList(BertTokenizer.PAD_TOKEN, BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN,
|
|
|
- "Elastic", "##search", "fun",
|
|
|
- "Pancake", "day",
|
|
|
- "my", "little", "red", "car",
|
|
|
- "God", "##zilla"
|
|
|
- ),
|
|
|
- new BertTokenization(null, null, 512)
|
|
|
- ).build();
|
|
|
-
|
|
|
- DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
|
|
|
- NlpTask.Request request = requestBuilder.buildRequest(
|
|
|
- List.of("Elasticsearch",
|
|
|
- "my little red car",
|
|
|
- "Godzilla day"), "request1");
|
|
|
- Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
|
|
|
-
|
|
|
- assertEquals("request1", jsonDocAsMap.get("request_id"));
|
|
|
- assertThat(jsonDocAsMap.keySet(), hasSize(3));
|
|
|
- assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));
|
|
|
- assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_1"), hasSize(3));
|
|
|
-
|
|
|
- assertEquals(Arrays.asList(1, 3, 4, 2, 0, 0), nthListItemFromMap("tokens", 0, jsonDocAsMap));
|
|
|
- assertEquals(Arrays.asList(1, 1, 1, 1, 0, 0), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
|
|
|
-
|
|
|
- assertEquals(Arrays.asList(1, 8, 9, 10, 11, 2), nthListItemFromMap("tokens", 1, jsonDocAsMap));
|
|
|
- assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 1, jsonDocAsMap));
|
|
|
-
|
|
|
- assertEquals(Arrays.asList(1, 12, 13, 7, 2, 0), nthListItemFromMap("tokens", 2, jsonDocAsMap));
|
|
|
- assertEquals(Arrays.asList(1, 1, 1, 1, 1, 0), nthListItemFromMap("arg_1", 2, jsonDocAsMap));
|
|
|
- }
|
|
|
-}
|