|
@@ -7,6 +7,7 @@
|
|
|
package org.elasticsearch.xpack.ml.inference.ingest;
|
|
|
|
|
|
import org.elasticsearch.ElasticsearchException;
|
|
|
+import org.elasticsearch.ElasticsearchParseException;
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
import org.elasticsearch.cluster.ClusterName;
|
|
@@ -26,7 +27,6 @@ import org.elasticsearch.common.transport.TransportAddress;
|
|
|
import org.elasticsearch.common.util.Maps;
|
|
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
|
|
import org.elasticsearch.core.Tuple;
|
|
|
-import org.elasticsearch.inference.InferenceResults;
|
|
|
import org.elasticsearch.ingest.IngestMetadata;
|
|
|
import org.elasticsearch.ingest.PipelineConfiguration;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
@@ -36,21 +36,33 @@ import org.elasticsearch.xcontent.XContentFactory;
|
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
|
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
|
|
|
import org.elasticsearch.xpack.ml.MachineLearning;
|
|
|
import org.junit.Before;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.net.InetAddress;
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
@@ -59,7 +71,11 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
|
|
|
+import static org.hamcrest.Matchers.containsString;
|
|
|
+import static org.hamcrest.Matchers.empty;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.hasEntry;
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
|
|
@@ -89,7 +105,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
clusterService = new ClusterService(settings, clusterSettings, tp, null);
|
|
|
}
|
|
|
|
|
|
- public void testCreateProcessorWithTooManyExisting() throws Exception {
|
|
|
+ public void testCreateProcessorWithTooManyExisting() {
|
|
|
Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
|
|
|
|
|
|
includeNodeInfoValues.forEach(includeNodeInfo -> {
|
|
@@ -135,7 +151,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> config = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap()));
|
|
|
}
|
|
@@ -158,7 +174,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> config2 = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom"));
|
|
|
}
|
|
@@ -172,7 +188,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> config3 = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap());
|
|
|
}
|
|
@@ -185,7 +201,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException {
|
|
|
+ public void testCreateProcessorWithTooOldMinNodeVersion() {
|
|
|
Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
|
|
|
|
|
|
includeNodeInfoValues.forEach(includeNodeInfo -> {
|
|
@@ -203,7 +219,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> regression = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(
|
|
|
InferenceProcessor.INFERENCE_CONFIG,
|
|
@@ -224,7 +240,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> classification = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(
|
|
|
InferenceProcessor.INFERENCE_CONFIG,
|
|
@@ -315,7 +331,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
|
|
|
Map<String, Object> minimalConfig = new HashMap<>() {
|
|
|
{
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
}
|
|
|
};
|
|
@@ -342,7 +358,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> regression = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(
|
|
|
InferenceProcessor.INFERENCE_CONFIG,
|
|
@@ -351,12 +367,18 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
- processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
|
|
|
+ var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
|
|
|
+ assertEquals(includeNodeInfo, processor.getAuditor().includeNodeInfo());
|
|
|
+ assertFalse(processor.isConfiguredWithInputsFields());
|
|
|
+ assertEquals("my_model", processor.getModelId());
|
|
|
+ assertEquals("result", processor.getTargetField());
|
|
|
+ assertThat(processor.getFieldMap().entrySet(), empty());
|
|
|
+ assertNull(processor.getInputs());
|
|
|
|
|
|
Map<String, Object> classification = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
put(
|
|
|
InferenceProcessor.INFERENCE_CONFIG,
|
|
@@ -368,19 +390,79 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
- processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification);
|
|
|
+ processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification);
|
|
|
+ assertFalse(processor.isConfiguredWithInputsFields());
|
|
|
|
|
|
Map<String, Object> mininmal = new HashMap<>() {
|
|
|
{
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
}
|
|
|
};
|
|
|
|
|
|
- processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, mininmal);
|
|
|
+ processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, mininmal);
|
|
|
+ assertFalse(processor.isConfiguredWithInputsFields());
|
|
|
+ assertEquals("my_model", processor.getModelId());
|
|
|
+ assertEquals("result", processor.getTargetField());
|
|
|
+ assertNull(processor.getInputs());
|
|
|
});
|
|
|
}
|
|
|
|
|
|
+ public void testCreateProcessorWithFieldMap() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, false);
|
|
|
+
|
|
|
+ Map<String, Object> config = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.FIELD_MAP, Collections.singletonMap("source", "dest"));
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
+ put(InferenceProcessor.TARGET_FIELD, "result");
|
|
|
+ put(
|
|
|
+ InferenceProcessor.INFERENCE_CONFIG,
|
|
|
+ Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())
|
|
|
+ );
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, config);
|
|
|
+ assertFalse(processor.isConfiguredWithInputsFields());
|
|
|
+ assertEquals("my_model", processor.getModelId());
|
|
|
+ assertEquals("result", processor.getTargetField());
|
|
|
+ assertNull(processor.getInputs());
|
|
|
+ var fieldMap = processor.getFieldMap();
|
|
|
+ assertThat(fieldMap.entrySet(), hasSize(1));
|
|
|
+ assertThat(fieldMap, hasEntry("source", "dest"));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCreateProcessorWithInputOutputs() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, false);
|
|
|
+
|
|
|
+ Map<String, Object> config = new HashMap<>();
|
|
|
+ config.put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
+
|
|
|
+ Map<String, Object> input1 = new HashMap<>();
|
|
|
+ input1.put(InferenceProcessor.INPUT_FIELD, "in1");
|
|
|
+ input1.put(InferenceProcessor.OUTPUT_FIELD, "out1");
|
|
|
+ Map<String, Object> input2 = new HashMap<>();
|
|
|
+ input2.put(InferenceProcessor.INPUT_FIELD, "in2");
|
|
|
+ input2.put(InferenceProcessor.OUTPUT_FIELD, "out2");
|
|
|
+
|
|
|
+ List<Map<String, Object>> inputOutputs = new ArrayList<>();
|
|
|
+ inputOutputs.add(input1);
|
|
|
+ inputOutputs.add(input2);
|
|
|
+ config.put(InferenceProcessor.INPUT_OUTPUT, inputOutputs);
|
|
|
+
|
|
|
+ var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, config);
|
|
|
+ assertTrue(processor.isConfiguredWithInputsFields());
|
|
|
+ assertEquals("my_model", processor.getModelId());
|
|
|
+ var configuredInputs = processor.getInputs();
|
|
|
+ assertThat(configuredInputs, hasSize(2));
|
|
|
+ assertEquals("in1", configuredInputs.get(0).inputField());
|
|
|
+ assertEquals("out1", configuredInputs.get(0).outputField());
|
|
|
+ assertEquals("in2", configuredInputs.get(1).inputField());
|
|
|
+ assertEquals("out2", configuredInputs.get(1).outputField());
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
public void testCreateProcessorWithDuplicateFields() {
|
|
|
Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
|
|
|
|
|
@@ -395,7 +477,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Map<String, Object> regression = new HashMap<>() {
|
|
|
{
|
|
|
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
put(InferenceProcessor.TARGET_FIELD, "ml");
|
|
|
put(
|
|
|
InferenceProcessor.INFERENCE_CONFIG,
|
|
@@ -415,7 +497,41 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- public void testParseFromMap() {
|
|
|
+ public void testCreateProcessorWithIgnoreMissing() {
|
|
|
+ Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
|
|
|
+
|
|
|
+ includeNodeInfoValues.forEach(includeNodeInfo -> {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ includeNodeInfo
|
|
|
+ );
|
|
|
+
|
|
|
+ Map<String, Object> regression = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
+ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
|
|
+ put("ignore_missing", Boolean.TRUE);
|
|
|
+ put(
|
|
|
+ InferenceProcessor.INFERENCE_CONFIG,
|
|
|
+ Collections.singletonMap(
|
|
|
+ RegressionConfig.NAME.getPreferredName(),
|
|
|
+ Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ Exception ex = expectThrows(
|
|
|
+ Exception.class,
|
|
|
+ () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression)
|
|
|
+ );
|
|
|
+ assertThat(ex.getMessage(), equalTo("Invalid inference config. More than one field is configured as [warning]"));
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParseInferenceConfigFromMap() {
|
|
|
Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
|
|
|
|
|
|
includeNodeInfoValues.forEach(includeNodeInfo -> {
|
|
@@ -433,6 +549,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
Tuple.tuple(PassThroughConfig.NAME, Map.of()),
|
|
|
Tuple.tuple(TextClassificationConfig.NAME, Map.of()),
|
|
|
Tuple.tuple(TextEmbeddingConfig.NAME, Map.of()),
|
|
|
+ Tuple.tuple(TextExpansionConfig.NAME, Map.of()),
|
|
|
Tuple.tuple(ZeroShotClassificationConfig.NAME, Map.of()),
|
|
|
Tuple.tuple(QuestionAnsweringConfig.NAME, Map.of("question", "What is the answer to life, the universe and everything?"))
|
|
|
)) {
|
|
@@ -444,8 +561,231 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- private static ClusterState buildClusterState(Metadata metadata) {
|
|
|
- return ClusterState.builder(new ClusterName("_name")).metadata(metadata).build();
|
|
|
+ public void testCreateProcessorWithIncompatibleTargetFieldSetting() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+
|
|
|
+ Map<String, Object> input = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.INPUT_FIELD, "in");
|
|
|
+ put(InferenceProcessor.OUTPUT_FIELD, "out");
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ Map<String, Object> config = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
+ put(InferenceProcessor.TARGET_FIELD, "ml");
|
|
|
+ put(InferenceProcessor.INPUT_OUTPUT, List.of(input));
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ ElasticsearchParseException ex = expectThrows(
|
|
|
+ ElasticsearchParseException.class,
|
|
|
+ () -> processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config)
|
|
|
+ );
|
|
|
+ assertThat(
|
|
|
+ ex.getMessage(),
|
|
|
+ containsString(
|
|
|
+ "[target_field] option is incompatible with [input_output]. Use the [output_field] option to specify where to write the "
|
|
|
+ + "inference results to."
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCreateProcessorWithIncompatibleResultFieldSetting() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+
|
|
|
+ Map<String, Object> input = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.INPUT_FIELD, "in");
|
|
|
+ put(InferenceProcessor.OUTPUT_FIELD, "out");
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ Map<String, Object> config = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
+ put(InferenceProcessor.INPUT_OUTPUT, List.of(input));
|
|
|
+ put(
|
|
|
+ InferenceProcessor.INFERENCE_CONFIG,
|
|
|
+ Collections.singletonMap(
|
|
|
+ TextExpansionConfig.NAME,
|
|
|
+ Collections.singletonMap(TextExpansionConfig.RESULTS_FIELD.getPreferredName(), "foo")
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ ElasticsearchParseException ex = expectThrows(
|
|
|
+ ElasticsearchParseException.class,
|
|
|
+ () -> processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config)
|
|
|
+ );
|
|
|
+ assertThat(
|
|
|
+ ex.getMessage(),
|
|
|
+ containsString(
|
|
|
+ "The [inference_config.results_field] setting is incompatible with using [input_output]. "
|
|
|
+ + "Prefer to use the [input_output.output_field] option to specify where to write the inference results to."
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCreateProcessorWithInputFields() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+
|
|
|
+ Map<String, Object> inputMap = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.INPUT_FIELD, "in");
|
|
|
+ put(InferenceProcessor.OUTPUT_FIELD, "out");
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ String inferenceConfigType = randomFrom(
|
|
|
+ ClassificationConfigUpdate.NAME.getPreferredName(),
|
|
|
+ RegressionConfigUpdate.NAME.getPreferredName(),
|
|
|
+ FillMaskConfigUpdate.NAME,
|
|
|
+ NerConfigUpdate.NAME,
|
|
|
+ PassThroughConfigUpdate.NAME,
|
|
|
+ QuestionAnsweringConfigUpdate.NAME,
|
|
|
+ TextClassificationConfigUpdate.NAME,
|
|
|
+ TextEmbeddingConfigUpdate.NAME,
|
|
|
+ TextExpansionConfigUpdate.NAME,
|
|
|
+ TextSimilarityConfigUpdate.NAME,
|
|
|
+ ZeroShotClassificationConfigUpdate.NAME
|
|
|
+ );
|
|
|
+
|
|
|
+ Map<String, Object> config = new HashMap<>() {
|
|
|
+ {
|
|
|
+ put(InferenceProcessor.MODEL_ID, "my_model");
|
|
|
+ put(InferenceProcessor.INPUT_OUTPUT, List.of(inputMap));
|
|
|
+ put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap()));
|
|
|
+ }
|
|
|
+ };
|
|
|
+ // create valid inference configs with required fields
|
|
|
+ if (inferenceConfigType.equals(TextSimilarityConfigUpdate.NAME)) {
|
|
|
+ var inferenceConfig = new HashMap<String, String>();
|
|
|
+ inferenceConfig.put(TextSimilarityConfig.TEXT.getPreferredName(), "text to compare");
|
|
|
+ config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, inferenceConfig));
|
|
|
+ } else if (inferenceConfigType.equals(QuestionAnsweringConfigUpdate.NAME)) {
|
|
|
+ var inferenceConfig = new HashMap<String, String>();
|
|
|
+ inferenceConfig.put(QuestionAnsweringConfig.QUESTION.getPreferredName(), "why is the sky blue?");
|
|
|
+ config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, inferenceConfig));
|
|
|
+ } else {
|
|
|
+ config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap()));
|
|
|
+ }
|
|
|
+
|
|
|
+ var inferenceProcessor = processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config);
|
|
|
+ assertEquals("my_model", inferenceProcessor.getModelId());
|
|
|
+ assertTrue(inferenceProcessor.isConfiguredWithInputsFields());
|
|
|
+
|
|
|
+ var inputs = inferenceProcessor.getInputs();
|
|
|
+ assertThat(inputs, hasSize(1));
|
|
|
+ assertEquals(new InferenceProcessor.Factory.InputConfig("in", null, "out", Map.of()), inputs.get(0));
|
|
|
+
|
|
|
+ assertNull(inferenceProcessor.getFieldMap());
|
|
|
+ assertNull(inferenceProcessor.getTargetField());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsingInputFields() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+
|
|
|
+ int numInputs = randomIntBetween(1, 3);
|
|
|
+ List<Map<String, Object>> inputs = new ArrayList<>();
|
|
|
+ for (int i = 0; i < numInputs; i++) {
|
|
|
+ Map<String, Object> inputMap = new HashMap<>();
|
|
|
+ inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i);
|
|
|
+ inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out." + i);
|
|
|
+ inputs.add(inputMap);
|
|
|
+ }
|
|
|
+
|
|
|
+ var parsedInputs = processorFactory.parseInputFields("my_processor", inputs);
|
|
|
+ assertThat(parsedInputs, hasSize(numInputs));
|
|
|
+ for (int i = 0; i < numInputs; i++) {
|
|
|
+ assertEquals(new InferenceProcessor.Factory.InputConfig("in" + i, "out", Integer.toString(i), Map.of()), parsedInputs.get(i));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsingInputFieldsDuplicateFieldNames() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+
|
|
|
+ int numInputs = 2;
|
|
|
+ {
|
|
|
+ List<Map<String, Object>> inputs = new ArrayList<>();
|
|
|
+ for (int i = 0; i < numInputs; i++) {
|
|
|
+ Map<String, Object> inputMap = new HashMap<>();
|
|
|
+ inputMap.put(InferenceProcessor.INPUT_FIELD, "in");
|
|
|
+ inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out" + i);
|
|
|
+ inputs.add(inputMap);
|
|
|
+ }
|
|
|
+
|
|
|
+ var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", inputs));
|
|
|
+ assertThat(e.getMessage(), containsString("[input_field] names must be unique but [in] is repeated"));
|
|
|
+ }
|
|
|
+
|
|
|
+ {
|
|
|
+ List<Map<String, Object>> inputs = new ArrayList<>();
|
|
|
+ for (int i = 0; i < numInputs; i++) {
|
|
|
+ Map<String, Object> inputMap = new HashMap<>();
|
|
|
+ inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i);
|
|
|
+ inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out");
|
|
|
+ inputs.add(inputMap);
|
|
|
+ }
|
|
|
+
|
|
|
+ var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", inputs));
|
|
|
+ assertThat(e.getMessage(), containsString("[output_field] names must be unique but [out] is repeated"));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testExtractBasePathAndFinalElement() {
|
|
|
+ {
|
|
|
+ String path = "foo.bar.result";
|
|
|
+ var extractedPaths = InferenceProcessor.Factory.extractBasePathAndFinalElement(path);
|
|
|
+ assertEquals("foo.bar", extractedPaths.v1());
|
|
|
+ assertEquals("result", extractedPaths.v2());
|
|
|
+ }
|
|
|
+
|
|
|
+ {
|
|
|
+ String path = "result";
|
|
|
+ var extractedPaths = InferenceProcessor.Factory.extractBasePathAndFinalElement(path);
|
|
|
+ assertNull(extractedPaths.v1());
|
|
|
+ assertEquals("result", extractedPaths.v2());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsingInputFieldsGivenNoInputs() {
|
|
|
+ InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
|
|
|
+ client,
|
|
|
+ clusterService,
|
|
|
+ Settings.EMPTY,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+
|
|
|
+ var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", List.of()));
|
|
|
+ assertThat(e.getMessage(), containsString("[input_output] cannot be empty at least one is required"));
|
|
|
}
|
|
|
|
|
|
private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException {
|
|
@@ -513,7 +853,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|
|
private static Map<String, Object> inferenceProcessorForModel(String modelId) {
|
|
|
return Collections.singletonMap(InferenceProcessor.TYPE, new HashMap<>() {
|
|
|
{
|
|
|
- put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId);
|
|
|
+ put(InferenceProcessor.MODEL_ID, modelId);
|
|
|
put(
|
|
|
InferenceProcessor.INFERENCE_CONFIG,
|
|
|
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())
|