瀏覽代碼

[ML] RegressionIT: Fix hyperparameters for regression tests and unmute the test (#135541) (#135769)

This PR fixes the flaky test muted in #93228 by fixing hyperparameters to the values that always work. Since the test is for alias fields and not for the training algorithm, fixing the hyperparameters is not dangerous.

Closes #93228
Valeriy Khakhutskyy 3 周之前
父節點
當前提交
3d07b613bc

+ 25 - 28
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -27,12 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
 import org.hamcrest.Matchers;
 import org.junit.After;
 
@@ -540,7 +536,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         );
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/93228")
     public void testAliasFields() throws Exception {
         // The goal of this test is to assert alias fields are included in the analytics job.
         // We have a simple dataset with two integer fields: field_1 and field_2.
@@ -585,10 +580,32 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         // Very infrequently this test may fail as the algorithm underestimates the
         // required number of trees for this simple problem. This failure is irrelevant
         // for non-trivial real-world problem and improving estimation of the number of trees
-        // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed.
+        // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed
+        // and use the hyperparameters that are known to work.
         long seed = 1000L; // fix seed
 
-        Regression regression = new Regression("field_2", BoostedTreeParams.builder().build(), null, 90.0, seed, null, null, null, null);
+        Regression regression = new Regression(
+            "field_2",
+            BoostedTreeParams.builder()
+                .setDownsampleFactor(0.7520841625652861)
+                .setAlpha(547.9095715556235)
+                .setLambda(3.3008189603590044)
+                .setGamma(1.6082763366825203)
+                .setSoftTreeDepthLimit(4.733224114945455)
+                .setSoftTreeDepthTolerance(0.15)
+                .setEta(0.12371209659057758)
+                .setEtaGrowthRatePerTree(1.0618560482952888)
+                .setMaxTrees(30)
+                .setFeatureBagFraction(0.8)
+                .build(),
+            null,
+            90.0,
+            seed,
+            null,
+            null,
+            null,
+            null
+        );
         DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder().setId(jobId)
             .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null, Collections.emptyMap()))
             .setDest(new DataFrameAnalyticsDest(destIndex, null))
@@ -604,19 +621,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
         waitUntilAnalyticsIsStopped(jobId);
 
-        // obtain addition information for investigation of #90599
-        String modelId = getModelId(jobId);
-        TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
-        assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
-        StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90599
-        for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
-            hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
-        }
-        TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
-        Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
-        int numberTrees = ensemble.getModels().size();
-
-        StringBuilder targetsPredictions = new StringBuilder(); // used to investigate #90599
         assertResponse(prepareSearch(sourceIndex).setSize(totalDocCount), sourceData -> {
             double predictionErrorSum = 0.0;
             for (SearchHit hit : sourceData.getHits()) {
@@ -629,19 +633,12 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 int featureValue = (int) destDoc.get("field_1");
                 double predictionValue = (double) resultsObject.get(predictionField);
                 predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);
-
-                // collect the log of targets and predictions for debugging #90599
-                targetsPredictions.append(2 * featureValue).append(", ").append(predictionValue).append("\n");
             }
             // We assert on the mean prediction error in order to reduce the probability
             // the test fails compared to asserting on the prediction of each individual doc.
             double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
             String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n";
-            assertThat(
-                Strings.format(str, seed, modelId, numberTrees) + targetsPredictions + hyperparameters,
-                meanPredictionError,
-                lessThanOrEqualTo(3.0)
-            );
+            assertThat(meanPredictionError, lessThanOrEqualTo(3.0));
         });
 
         assertProgressComplete(jobId);