|
@@ -62,16 +62,25 @@ public class EnsembleTests extends AbstractXContentSerializingTestCase<Ensemble>
|
|
|
|
|
|
public static Ensemble createRandom(TargetType targetType) {
|
|
|
int numberOfFeatures = randomIntBetween(1, 10);
|
|
|
- List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
|
|
|
- return createRandom(targetType, featureNames);
|
|
|
+ int numberOfModels = randomIntBetween(1, 10);
|
|
|
+ return createRandom(targetType, numberOfFeatures, numberOfModels, 6);
|
|
|
}
|
|
|
|
|
|
- public static Ensemble createRandom(TargetType targetType, List<String> featureNames) {
|
|
|
+ public static Ensemble createRandom(TargetType targetType, int numberOfFeatures) {
|
|
|
int numberOfModels = randomIntBetween(1, 10);
|
|
|
+ return createRandom(targetType, numberOfFeatures, numberOfModels, 6);
|
|
|
+ }
|
|
|
+
|
|
|
+ public static Ensemble createRandom(TargetType targetType, int numberOfFeatures, int numberOfModels, int treeDepth) {
|
|
|
+ List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
|
|
|
+ return createRandom(targetType, featureNames, numberOfModels, treeDepth);
|
|
|
+ }
|
|
|
+
|
|
|
+ public static Ensemble createRandom(TargetType targetType, List<String> featureNames, int numberOfModels, int treeDepth) {
|
|
|
List<String> treeFeatureNames = featureNames.isEmpty()
|
|
|
? Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList())
|
|
|
: featureNames;
|
|
|
- List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, 6))
|
|
|
+ List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, treeDepth))
|
|
|
.limit(numberOfModels)
|
|
|
.collect(Collectors.toList());
|
|
|
double[] weights = randomBoolean()
|