|
@@ -6,8 +6,11 @@
|
|
|
|
|
|
|
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
|
|
|
|
|
|
|
|
|
+import org.apache.logging.log4j.LogManager;
|
|
|
|
|
+import org.apache.logging.log4j.Logger;
|
|
|
import org.apache.lucene.util.RamUsageEstimator;
|
|
import org.apache.lucene.util.RamUsageEstimator;
|
|
|
import org.elasticsearch.common.Nullable;
|
|
import org.elasticsearch.common.Nullable;
|
|
|
|
|
+import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.collect.Tuple;
|
|
import org.elasticsearch.common.collect.Tuple;
|
|
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
@@ -25,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Leniently
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
|
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
import java.util.HashMap;
|
|
|
import java.util.LinkedHashSet;
|
|
import java.util.LinkedHashSet;
|
|
@@ -49,6 +53,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
|
|
|
public class EnsembleInferenceModel implements InferenceModel {
|
|
public class EnsembleInferenceModel implements InferenceModel {
|
|
|
|
|
|
|
|
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
|
|
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
|
|
|
|
|
+ private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
|
|
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
@SuppressWarnings("unchecked")
|
|
|
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
|
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
|
@@ -136,6 +141,8 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
if (preparedForInference == false) {
|
|
if (preparedForInference == false) {
|
|
|
throw ExceptionsHelper.serverError("model is not prepared for inference");
|
|
throw ExceptionsHelper.serverError("model is not prepared for inference");
|
|
|
}
|
|
}
|
|
|
|
|
+ LOGGER.debug("Inference called with feature names [{}]",
|
|
|
|
|
+ featureNames == null ? "<null>" : Strings.arrayToCommaDelimitedString(featureNames));
|
|
|
assert featureNames != null && featureNames.length > 0;
|
|
assert featureNames != null && featureNames.length > 0;
|
|
|
double[][] inferenceResults = new double[this.models.size()][];
|
|
double[][] inferenceResults = new double[this.models.size()][];
|
|
|
double[][] featureInfluence = new double[features.length][];
|
|
double[][] featureInfluence = new double[features.length][];
|
|
@@ -237,12 +244,14 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
|
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
|
|
|
|
+ LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
|
|
|
if (preparedForInference) {
|
|
if (preparedForInference) {
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
preparedForInference = true;
|
|
preparedForInference = true;
|
|
|
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
|
|
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
|
|
|
Set<String> referencedFeatures = subModelFeatures();
|
|
Set<String> referencedFeatures = subModelFeatures();
|
|
|
|
|
+ LOGGER.debug("detected submodel feature names {}", referencedFeatures);
|
|
|
int newFeatureIndex = 0;
|
|
int newFeatureIndex = 0;
|
|
|
newFeatureIndexMapping = new HashMap<>();
|
|
newFeatureIndexMapping = new HashMap<>();
|
|
|
this.featureNames = new String[referencedFeatures.size()];
|
|
this.featureNames = new String[referencedFeatures.size()];
|
|
@@ -301,4 +310,16 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|
|
return classificationWeights;
|
|
return classificationWeights;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public String toString() {
|
|
|
|
|
+ return "EnsembleInferenceModel{" +
|
|
|
|
|
+ "featureNames=" + Arrays.toString(featureNames) +
|
|
|
|
|
+ ", models=" + models +
|
|
|
|
|
+ ", outputAggregator=" + outputAggregator +
|
|
|
|
|
+ ", targetType=" + targetType +
|
|
|
|
|
+ ", classificationLabels=" + classificationLabels +
|
|
|
|
|
+ ", classificationWeights=" + Arrays.toString(classificationWeights) +
|
|
|
|
|
+ ", preparedForInference=" + preparedForInference +
|
|
|
|
|
+ '}';
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|