Răsfoiți Sursa

[ML] Speed up PutTrainedModelActionResponseTests (#90876)

David Kyle 3 ani în urmă
părinte
comite
0671e08bc4

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java

@@ -22,7 +22,7 @@ public class PutTrainedModelActionResponseTests extends AbstractWireSerializingT
         String modelId = randomAlphaOfLength(10);
         return new Response(
             TrainedModelConfigTests.createTestInstance(modelId, randomBoolean())
-                .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+                .setParsedDefinition(TrainedModelDefinitionTests.createSmallRandomBuilder())
                 .build()
         );
     }

+ 30 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java

@@ -73,10 +73,40 @@ public class TrainedModelDefinitionTests extends AbstractXContentSerializingTest
         ).setTrainedModel(randomFrom(TreeTests.createRandom(targetType), EnsembleTests.createRandom(targetType)));
     }
 
+    public static TrainedModelDefinition.Builder createRandomBuilder(
+        TargetType targetType,
+        int numberOfProcessors,
+        int numberOfFeatures,
+        int numberOfModels,
+        int treeDepth
+    ) {
+        return new TrainedModelDefinition.Builder().setPreProcessors(
+            randomBoolean()
+                ? null
+                : Stream.generate(
+                    () -> randomFrom(
+                        FrequencyEncodingTests.createRandom(),
+                        OneHotEncodingTests.createRandom(),
+                        TargetMeanEncodingTests.createRandom()
+                    )
+                ).limit(numberOfProcessors).collect(Collectors.toList())
+        )
+            .setTrainedModel(
+                randomFrom(
+                    TreeTests.createRandom(targetType, numberOfFeatures, treeDepth),
+                    EnsembleTests.createRandom(targetType, numberOfFeatures, numberOfModels, treeDepth)
+                )
+            );
+    }
+
     public static TrainedModelDefinition.Builder createRandomBuilder() {
         return createRandomBuilder(randomFrom(TargetType.values()));
     }
 
+    public static TrainedModelDefinition.Builder createSmallRandomBuilder() {
+        return createRandomBuilder(randomFrom(TargetType.values()), 2, 3, 2, 3);
+    }
+
     public static final String ENSEMBLE_MODEL = """
         {
           "preprocessors": [

+ 13 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -62,16 +62,25 @@ public class EnsembleTests extends AbstractXContentSerializingTestCase<Ensemble>
 
     public static Ensemble createRandom(TargetType targetType) {
         int numberOfFeatures = randomIntBetween(1, 10);
-        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
-        return createRandom(targetType, featureNames);
+        int numberOfModels = randomIntBetween(1, 10);
+        return createRandom(targetType, numberOfFeatures, numberOfModels, 6);
     }
 
-    public static Ensemble createRandom(TargetType targetType, List<String> featureNames) {
+    public static Ensemble createRandom(TargetType targetType, int numberOfFeatures) {
         int numberOfModels = randomIntBetween(1, 10);
+        return createRandom(targetType, numberOfFeatures, numberOfModels, 6);
+    }
+
+    public static Ensemble createRandom(TargetType targetType, int numberOfFeatures, int numberOfModels, int treeDepth) {
+        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
+        return createRandom(targetType, featureNames, numberOfModels, treeDepth);
+    }
+
+    public static Ensemble createRandom(TargetType targetType, List<String> featureNames, int numberOfModels, int treeDepth) {
         List<String> treeFeatureNames = featureNames.isEmpty()
             ? Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList())
             : featureNames;
-        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, 6))
+        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, treeDepth))
             .limit(numberOfModels)
             .collect(Collectors.toList());
         double[] weights = randomBoolean()

+ 2 - 11
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java

@@ -36,7 +36,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
 import static org.hamcrest.Matchers.closeTo;
@@ -66,21 +65,13 @@ public class EnsembleInferenceModelTests extends ESTestCase {
     public void testSerializationFromEnsemble() throws Exception {
         for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
             int numberOfFeatures = randomIntBetween(1, 10);
-            Ensemble ensemble = EnsembleTests.createRandom(
-                randomFrom(TargetType.values()),
-                randomBoolean()
-                    ? Collections.emptyList()
-                    : Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList())
-            );
+            Ensemble ensemble = EnsembleTests.createRandom(randomFrom(TargetType.values()), randomBoolean() ? 0 : numberOfFeatures);
             assertThat(serializeFromTrainedModel(ensemble), is(not(nullValue())));
         }
     }
 
     public void testInferenceWithoutPreparing() throws IOException {
-        Ensemble ensemble = EnsembleTests.createRandom(
-            TargetType.REGRESSION,
-            Stream.generate(() -> randomAlphaOfLength(10)).limit(4).collect(Collectors.toList())
-        );
+        Ensemble ensemble = EnsembleTests.createRandom(TargetType.REGRESSION, 4);
 
         EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble, xContentRegistry(), EnsembleInferenceModel::fromXContent);
         expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null));

+ 9 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

@@ -53,17 +53,21 @@ public class TreeTests extends AbstractXContentSerializingTestCase<Tree> {
         return createRandom();
     }
 
+    public static Tree createRandom() {
+        return createRandom(randomFrom(TargetType.values()));
+    }
+
     public static Tree createRandom(TargetType targetType) {
         int numberOfFeatures = randomIntBetween(1, 10);
+        return createRandom(targetType, numberOfFeatures, 6);
+    }
+
+    public static Tree createRandom(TargetType targetType, int numberOfFeatures, int depth) {
         List<String> featureNames = new ArrayList<>();
         for (int i = 0; i < numberOfFeatures; i++) {
             featureNames.add(randomAlphaOfLength(10));
         }
-        return buildRandomTree(targetType, featureNames, 6);
-    }
-
-    public static Tree createRandom() {
-        return createRandom(randomFrom(TargetType.values()));
+        return buildRandomTree(targetType, featureNames, depth);
     }
 
     public static Tree buildRandomTree(TargetType targetType, List<String> featureNames, int depth) {