|
@@ -13,8 +13,11 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
|
|
@@ -31,7 +34,6 @@ import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
-import java.util.concurrent.ExecutionException;
|
|
|
|
|
|
import static org.hamcrest.CoreMatchers.is;
|
|
|
import static org.hamcrest.Matchers.closeTo;
|
|
@@ -48,22 +50,23 @@ public class LocalModelTests extends ESTestCase {
|
|
|
.setTrainedModel(buildClassification(false))
|
|
|
.build();
|
|
|
|
|
|
- Model model = new LocalModel(modelId,
|
|
|
+ Model<ClassificationConfig> model = new LocalModel<>(modelId,
|
|
|
definition,
|
|
|
new TrainedModelInput(inputFields),
|
|
|
- Collections.singletonMap("field.foo", "field.foo.keyword"));
|
|
|
+ Collections.singletonMap("field.foo", "field.foo.keyword"),
|
|
|
+ ClassificationConfig.EMPTY_PARAMS);
|
|
|
Map<String, Object> fields = new HashMap<>() {{
|
|
|
put("field.foo", 1.0);
|
|
|
put("field.bar", 0.5);
|
|
|
put("categorical", "dog");
|
|
|
}};
|
|
|
|
|
|
- SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0));
|
|
|
+ SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
|
|
|
assertThat(result.value(), equalTo(0.0));
|
|
|
assertThat(result.valueAsString(), is("0"));
|
|
|
|
|
|
ClassificationInferenceResults classificationResult =
|
|
|
- (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
|
|
|
+ (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null));
|
|
|
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
|
|
|
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));
|
|
|
|
|
@@ -72,22 +75,29 @@ public class LocalModelTests extends ESTestCase {
|
|
|
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
|
|
.setTrainedModel(buildClassification(true))
|
|
|
.build();
|
|
|
- model = new LocalModel(modelId,
|
|
|
+ model = new LocalModel<>(modelId,
|
|
|
definition,
|
|
|
new TrainedModelInput(inputFields),
|
|
|
- Collections.singletonMap("field.foo", "field.foo.keyword"));
|
|
|
- result = getSingleValue(model, fields, new ClassificationConfig(0));
|
|
|
+ Collections.singletonMap("field.foo", "field.foo.keyword"),
|
|
|
+ ClassificationConfig.EMPTY_PARAMS);
|
|
|
+ result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
|
|
|
assertThat(result.value(), equalTo(0.0));
|
|
|
assertThat(result.valueAsString(), equalTo("not_to_be"));
|
|
|
|
|
|
- classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
|
|
|
+ classificationResult = (ClassificationInferenceResults)getSingleValue(model,
|
|
|
+ fields,
|
|
|
+ new ClassificationConfigUpdate(1, null, null, null));
|
|
|
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
|
|
|
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
|
|
|
|
|
|
- classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2));
|
|
|
+ classificationResult = (ClassificationInferenceResults)getSingleValue(model,
|
|
|
+ fields,
|
|
|
+ new ClassificationConfigUpdate(2, null, null, null));
|
|
|
assertThat(classificationResult.getTopClasses(), hasSize(2));
|
|
|
|
|
|
- classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1));
|
|
|
+ classificationResult = (ClassificationInferenceResults)getSingleValue(model,
|
|
|
+ fields,
|
|
|
+ new ClassificationConfigUpdate(-1, null, null, null));
|
|
|
assertThat(classificationResult.getTopClasses(), hasSize(2));
|
|
|
}
|
|
|
|
|
@@ -97,10 +107,11 @@ public class LocalModelTests extends ESTestCase {
|
|
|
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
|
|
.setTrainedModel(buildRegression())
|
|
|
.build();
|
|
|
- Model model = new LocalModel("regression_model",
|
|
|
+ Model<RegressionConfig> model = new LocalModel<>("regression_model",
|
|
|
trainedModelDefinition,
|
|
|
new TrainedModelInput(inputFields),
|
|
|
- Collections.singletonMap("bar", "bar.keyword"));
|
|
|
+ Collections.singletonMap("bar", "bar.keyword"),
|
|
|
+ RegressionConfig.EMPTY_PARAMS);
|
|
|
|
|
|
Map<String, Object> fields = new HashMap<>() {{
|
|
|
put("foo", 1.0);
|
|
@@ -108,14 +119,8 @@ public class LocalModelTests extends ESTestCase {
|
|
|
put("categorical", "dog");
|
|
|
}};
|
|
|
|
|
|
- SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfig.EMPTY_PARAMS);
|
|
|
+ SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfigUpdate.EMPTY_PARAMS);
|
|
|
assertThat(results.value(), equalTo(1.3));
|
|
|
-
|
|
|
- PlainActionFuture<InferenceResults> failedFuture = new PlainActionFuture<>();
|
|
|
- model.infer(fields, new ClassificationConfig(2), failedFuture);
|
|
|
- ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get);
|
|
|
- assertThat(ex.getCause().getMessage(),
|
|
|
- equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]"));
|
|
|
}
|
|
|
|
|
|
public void testAllFieldsMissing() throws Exception {
|
|
@@ -124,7 +129,12 @@ public class LocalModelTests extends ESTestCase {
|
|
|
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
|
|
.setTrainedModel(buildRegression())
|
|
|
.build();
|
|
|
- Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), null);
|
|
|
+ Model<RegressionConfig> model = new LocalModel<>(
|
|
|
+ "regression_model",
|
|
|
+ trainedModelDefinition,
|
|
|
+ new TrainedModelInput(inputFields),
|
|
|
+ null,
|
|
|
+ RegressionConfig.EMPTY_PARAMS);
|
|
|
|
|
|
Map<String, Object> fields = new HashMap<>() {{
|
|
|
put("something", 1.0);
|
|
@@ -132,18 +142,21 @@ public class LocalModelTests extends ESTestCase {
|
|
|
put("baz", "dog");
|
|
|
}};
|
|
|
|
|
|
- WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS);
|
|
|
+ WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS);
|
|
|
assertThat(results.getWarning(),
|
|
|
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model")));
|
|
|
}
|
|
|
|
|
|
- private static SingleValueInferenceResults getSingleValue(Model model,
|
|
|
- Map<String, Object> fields,
|
|
|
- InferenceConfig config) throws Exception {
|
|
|
+ private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model<T> model,
|
|
|
+ Map<String, Object> fields,
|
|
|
+ InferenceConfigUpdate<T> config)
|
|
|
+ throws Exception {
|
|
|
return (SingleValueInferenceResults)getInferenceResult(model, fields, config);
|
|
|
}
|
|
|
|
|
|
- private static InferenceResults getInferenceResult(Model model, Map<String, Object> fields, InferenceConfig config) throws Exception {
|
|
|
+ private static <T extends InferenceConfig> InferenceResults getInferenceResult(Model<T> model,
|
|
|
+ Map<String, Object> fields,
|
|
|
+ InferenceConfigUpdate<T> config) throws Exception {
|
|
|
PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
|
|
|
model.infer(fields, config, future);
|
|
|
return future.get();
|