|
@@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
|
|
|
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
|
|
|
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
|
|
|
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
|
|
@@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
RawInferenceResults inferenceResult = (RawInferenceResults) result;
|
|
|
inferenceResults[i++] = inferenceResult.getValue();
|
|
|
if (config.requestingImportance()) {
|
|
|
- double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
|
|
|
- assert modelFeatureImportance.length == featureInfluence.length;
|
|
|
- for (int j = 0; j < modelFeatureImportance.length; j++) {
|
|
|
- if (featureInfluence[j] == null) {
|
|
|
- featureInfluence[j] = new double[modelFeatureImportance[j].length];
|
|
|
- }
|
|
|
- featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
|
|
|
- }
|
|
|
+ addFeatureImportance(featureInfluence, inferenceResult);
|
|
|
}
|
|
|
}
|
|
|
double[] processed = outputAggregator.processValues(inferenceResults);
|
|
@@ -176,18 +170,22 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
InferenceResults result = model.infer(features, subModelInferenceConfig);
|
|
|
assert result instanceof RawInferenceResults;
|
|
|
RawInferenceResults inferenceResult = (RawInferenceResults) result;
|
|
|
- double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
|
|
|
- assert modelFeatureImportance.length == featureInfluence.length;
|
|
|
- for (int j = 0; j < modelFeatureImportance.length; j++) {
|
|
|
- if (featureInfluence[j] == null) {
|
|
|
- featureInfluence[j] = new double[modelFeatureImportance[j].length];
|
|
|
- }
|
|
|
- featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
|
|
|
- }
|
|
|
+ addFeatureImportance(featureInfluence, inferenceResult);
|
|
|
}
|
|
|
return featureInfluence;
|
|
|
}
|
|
|
|
|
|
+ private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
|
|
|
+ double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
|
|
|
+ assert modelFeatureImportance.length == featureInfluence.length;
|
|
|
+ for (int j = 0; j < modelFeatureImportance.length; j++) {
|
|
|
+ if (featureInfluence[j] == null) {
|
|
|
+ featureInfluence[j] = new double[modelFeatureImportance[j].length];
|
|
|
+ }
|
|
|
+ featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private InferenceResults buildResults(double[] processedInferences,
|
|
|
double[][] featureImportance,
|
|
|
Map<String, String> featureDecoderMap,
|
|
@@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
case REGRESSION:
|
|
|
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
|
|
|
config,
|
|
|
- transformFeatureImportance(decodedFeatureImportance, null));
|
|
|
+ transformFeatureImportanceRegression(decodedFeatureImportance));
|
|
|
case CLASSIFICATION:
|
|
|
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
|
|
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
|
|
@@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
classificationConfig.getNumTopClasses(),
|
|
|
classificationConfig.getPredictionFieldType());
|
|
|
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
|
|
|
- return new ClassificationInferenceResults((double)value.getValue(),
|
|
|
+ return new ClassificationInferenceResults(value.getValue(),
|
|
|
classificationLabel(topClasses.v1().getValue(), classificationLabels),
|
|
|
topClasses.v2(),
|
|
|
- transformFeatureImportance(decodedFeatureImportance, classificationLabels),
|
|
|
+ transformFeatureImportanceClassification(decodedFeatureImportance,
|
|
|
+ value.getValue(),
|
|
|
+ classificationLabels,
|
|
|
+ classificationConfig.getPredictionFieldType()),
|
|
|
config,
|
|
|
value.getProbability(),
|
|
|
value.getScore());
|