|
@@ -27,12 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
|
|
|
import org.hamcrest.Matchers;
|
|
|
import org.junit.After;
|
|
|
|
|
@@ -540,7 +536,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/93228")
|
|
|
public void testAliasFields() throws Exception {
|
|
|
// The goal of this test is to assert alias fields are included in the analytics job.
|
|
|
// We have a simple dataset with two integer fields: field_1 and field_2.
|
|
@@ -585,10 +580,32 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
// Very infrequently this test may fail as the algorithm underestimates the
|
|
|
// required number of trees for this simple problem. This failure is irrelevant
|
|
|
// for non-trivial real-world problem and improving estimation of the number of trees
|
|
|
- // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed.
|
|
|
+ // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed
|
|
|
+ // and use the hyperparameters that are known to work.
|
|
|
long seed = 1000L; // fix seed
|
|
|
|
|
|
- Regression regression = new Regression("field_2", BoostedTreeParams.builder().build(), null, 90.0, seed, null, null, null, null);
|
|
|
+ Regression regression = new Regression(
|
|
|
+ "field_2",
|
|
|
+ BoostedTreeParams.builder()
|
|
|
+ .setDownsampleFactor(0.7520841625652861)
|
|
|
+ .setAlpha(547.9095715556235)
|
|
|
+ .setLambda(3.3008189603590044)
|
|
|
+ .setGamma(1.6082763366825203)
|
|
|
+ .setSoftTreeDepthLimit(4.733224114945455)
|
|
|
+ .setSoftTreeDepthTolerance(0.15)
|
|
|
+ .setEta(0.12371209659057758)
|
|
|
+ .setEtaGrowthRatePerTree(1.0618560482952888)
|
|
|
+ .setMaxTrees(30)
|
|
|
+ .setFeatureBagFraction(0.8)
|
|
|
+ .build(),
|
|
|
+ null,
|
|
|
+ 90.0,
|
|
|
+ seed,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ null
|
|
|
+ );
|
|
|
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder().setId(jobId)
|
|
|
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null, Collections.emptyMap()))
|
|
|
.setDest(new DataFrameAnalyticsDest(destIndex, null))
|
|
@@ -604,19 +621,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
waitUntilAnalyticsIsStopped(jobId);
|
|
|
|
|
|
- // obtain addition information for investigation of #90599
|
|
|
- String modelId = getModelId(jobId);
|
|
|
- TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
|
|
|
- assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
|
|
|
- StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90599
|
|
|
- for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
|
|
|
- hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
|
|
|
- }
|
|
|
- TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
|
|
|
- Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
|
|
|
- int numberTrees = ensemble.getModels().size();
|
|
|
-
|
|
|
- StringBuilder targetsPredictions = new StringBuilder(); // used to investigate #90599
|
|
|
assertResponse(prepareSearch(sourceIndex).setSize(totalDocCount), sourceData -> {
|
|
|
double predictionErrorSum = 0.0;
|
|
|
for (SearchHit hit : sourceData.getHits()) {
|
|
@@ -629,19 +633,12 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
int featureValue = (int) destDoc.get("field_1");
|
|
|
double predictionValue = (double) resultsObject.get(predictionField);
|
|
|
predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);
|
|
|
-
|
|
|
- // collect the log of targets and predictions for debugging #90599
|
|
|
- targetsPredictions.append(2 * featureValue).append(", ").append(predictionValue).append("\n");
|
|
|
}
|
|
|
// We assert on the mean prediction error in order to reduce the probability
|
|
|
// the test fails compared to asserting on the prediction of each individual doc.
|
|
|
double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
|
|
|
String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n";
|
|
|
- assertThat(
|
|
|
- Strings.format(str, seed, modelId, numberTrees) + targetsPredictions + hyperparameters,
|
|
|
- meanPredictionError,
|
|
|
- lessThanOrEqualTo(3.0)
|
|
|
- );
|
|
|
+ assertThat(meanPredictionError, lessThanOrEqualTo(3.0));
|
|
|
});
|
|
|
|
|
|
assertProgressComplete(jobId);
|