
[ML] fixing zero_shot_classification config override (#78415)

A bug in zero_shot_classification meant the provided override labels were not taken into account.
Benjamin Trent 4 years ago
commit bceb38a6dc

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -69,7 +69,7 @@ public class ZeroShotClassificationProcessor implements NlpTask.Processor {
         } else {
             labels = this.labels;
         }
-        if (this.labels == null || this.labels.length == 0) {
+        if (labels == null || labels.length == 0) {
             throw ExceptionsHelper.badRequestException("zero_shot_classification requires non-empty [labels]");
         }
         return new RequestBuilder(tokenizer, labels, hypothesisTemplate);
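
For orientation, here is a hedged sketch of the label-resolution logic the fixed check sits in. The lines above this hunk are collapsed in the commit view, so the override branch, the helper name resolveLabels, the parameter defaultLabels, and the use of ZeroShotClassificationConfig.getLabels() are assumptions inferred from the visible context and the new test below, not the actual Elasticsearch source:

    // Sketch only: mirrors the visible hunk, the collapsed context is assumed.
    static String[] resolveLabels(NlpConfig config, String[] defaultLabels) {
        List<String> override = config instanceof ZeroShotClassificationConfig
            ? ((ZeroShotClassificationConfig) config).getLabels()
            : null;
        String[] labels = (override != null && override.isEmpty() == false)
            ? override.toArray(new String[0])   // labels supplied with the request win
            : defaultLabels;                    // otherwise fall back to the model defaults
        // The fix: validate the labels actually being used, not the default field,
        // so a request that only supplies override labels is no longer rejected.
        if (labels == null || labels.length == 0) {
            throw ExceptionsHelper.badRequestException("zero_shot_classification requires non-empty [labels]");
        }
        return labels;
    }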

+ 64 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.hasSize;
+
+public class ZeroShotClassificationProcessorTests extends ESTestCase {
+
+    @SuppressWarnings("unchecked")
+    public void testBuildRequest() throws IOException {
+        NlpTokenizer tokenizer = NlpTokenizer.build(
+            new Vocabulary(
+                Arrays.asList("Elastic", "##search", "fun", "default", "label", "new", "stuff", "This", "example", "is", ".",
+                    BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
+                randomAlphaOfLength(10)
+            ),
+            new BertTokenization(null, true, 512));
+
+        ZeroShotClassificationConfig config = new ZeroShotClassificationConfig(
+            List.of("entailment", "neutral", "contradiction"),
+            new VocabularyConfig("test-index"),
+            null,
+            null,
+            null,
+            null
+        );
+        ZeroShotClassificationProcessor processor = new ZeroShotClassificationProcessor(tokenizer, config);
+
+        NlpTask.Request request = processor.getRequestBuilder(
+            (NlpConfig)new ZeroShotClassificationConfigUpdate.Builder().setLabels(List.of("new", "stuff")).build().apply(config)
+        ).buildRequest(List.of("Elasticsearch fun"), "request1");
+
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(5));
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(11, 0, 1, 2, 12, 7, 8, 9, 5, 10, 12), ((List<List<Integer>>)jsonDocAsMap.get("tokens")).get(0));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), ((List<List<Integer>>)jsonDocAsMap.get("arg_1")).get(0));
+        assertEquals(Arrays.asList(11, 0, 1, 2, 12, 7, 8, 9, 6, 10, 12), ((List<List<Integer>>)jsonDocAsMap.get("tokens")).get(1));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), ((List<List<Integer>>)jsonDocAsMap.get("arg_1")).get(1));
+    }
+
+}
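
The test drives exactly the path the commit message describes: the stored ZeroShotClassificationConfig appears to carry no candidate labels of its own, and the labels arrive only through ZeroShotClassificationConfigUpdate. A condensed view of that override flow, lifted from the calls in the test (which constructor argument maps to which field is inferred from the test, not spelled out in the diff):

    // Condensed from the test above; comments on the arguments are inferences.
    ZeroShotClassificationConfig base = new ZeroShotClassificationConfig(
        List.of("entailment", "neutral", "contradiction"),   // NLI classification labels
        new VocabularyConfig("test-index"),
        null, null, null, null                               // no candidate labels stored on the config
    );
    NlpConfig effective = (NlpConfig) new ZeroShotClassificationConfigUpdate.Builder()
        .setLabels(List.of("new", "stuff"))                  // labels supplied with the inference request
        .build()
        .apply(base);
    // Before this commit, getRequestBuilder(effective) failed with
    // "zero_shot_classification requires non-empty [labels]" because the check
    // inspected the processor's default labels instead of the resolved ones.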