|
@@ -16,11 +16,12 @@ import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.Map;
|
|
|
|
|
|
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests.randomClassificationConfig;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
|
|
|
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
|
|
|
|
|
|
- public static ClassificationConfigUpdate randomClassificationConfig() {
|
|
|
+ public static ClassificationConfigUpdate randomClassificationConfigUpdate() {
|
|
|
return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10),
|
|
|
randomBoolean() ? null : randomAlphaOfLength(10),
|
|
|
randomBoolean() ? null : randomAlphaOfLength(10),
|
|
@@ -49,9 +50,33 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
|
|
|
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
|
|
}
|
|
|
|
|
|
+ public void testApply() {
|
|
|
+ ClassificationConfig originalConfig = randomClassificationConfig();
|
|
|
+
|
|
|
+ assertThat(originalConfig, equalTo(ClassificationConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
|
|
|
+
|
|
|
+ assertThat(new ClassificationConfig.Builder(originalConfig).setNumTopClasses(5).build(),
|
|
|
+ equalTo(new ClassificationConfigUpdate.Builder().setNumTopClasses(5).build().apply(originalConfig)));
|
|
|
+ assertThat(new ClassificationConfig.Builder()
|
|
|
+ .setNumTopClasses(5)
|
|
|
+ .setNumTopFeatureImportanceValues(1)
|
|
|
+ .setPredictionFieldType(PredictionFieldType.BOOLEAN)
|
|
|
+ .setResultsField("foo")
|
|
|
+ .setTopClassesResultsField("bar").build(),
|
|
|
+ equalTo(new ClassificationConfigUpdate.Builder()
|
|
|
+ .setNumTopClasses(5)
|
|
|
+ .setNumTopFeatureImportanceValues(1)
|
|
|
+ .setPredictionFieldType(PredictionFieldType.BOOLEAN)
|
|
|
+ .setResultsField("foo")
|
|
|
+ .setTopClassesResultsField("bar")
|
|
|
+ .build()
|
|
|
+ .apply(originalConfig)
|
|
|
+ ));
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
protected ClassificationConfigUpdate createTestInstance() {
|
|
|
- return randomClassificationConfig();
|
|
|
+ return randomClassificationConfigUpdate();
|
|
|
}
|
|
|
|
|
|
@Override
|