Browse Source

[ML] calculate cache misses for inference and return in stats (#58252)

When a local model is constructed, the cache hit miss count is incremented.

When a user calls _stats, we will include the sum cache hit miss count across ALL nodes. This statistic is important to in comparing against the inference_count. If the cache hit miss count is near the inference_count it indicates that the cache is overburdened, or inappropriately configured.
Benjamin Trent 5 years ago
parent
commit
a43ff95f2d

+ 20 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelStats.java

@@ -18,6 +18,7 @@
  */
 package org.elasticsearch.client.ml.inference;
 
+import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -38,32 +39,36 @@ public class TrainedModelStats implements ToXContentObject {
     public static final ParseField MODEL_ID = new ParseField("model_id");
     public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
     public static final ParseField INGEST_STATS = new ParseField("ingest");
+    public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
 
     private final String modelId;
     private final Map<String, Object> ingestStats;
     private final int pipelineCount;
+    private final InferenceStats inferenceStats;
 
     @SuppressWarnings("unchecked")
     static final ConstructingObjectParser<TrainedModelStats, Void> PARSER =
         new ConstructingObjectParser<>(
             "trained_model_stats",
             true,
-            args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2]));
+            args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2], (InferenceStats) args[3]));
 
     static {
         PARSER.declareString(constructorArg(), MODEL_ID);
         PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), INGEST_STATS);
         PARSER.declareInt(constructorArg(), PIPELINE_COUNT);
+        PARSER.declareObject(optionalConstructorArg(), InferenceStats.PARSER, INFERENCE_STATS);
     }
 
     public static TrainedModelStats fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount) {
+    public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount, InferenceStats inferenceStats) {
         this.modelId = modelId;
         this.ingestStats = ingestStats;
         this.pipelineCount = pipelineCount;
+        this.inferenceStats = inferenceStats;
     }
 
     /**
@@ -89,6 +94,13 @@ public class TrainedModelStats implements ToXContentObject {
         return pipelineCount;
     }
 
+    /**
+     * Inference statistics
+     */
+    public InferenceStats getInferenceStats() {
+        return inferenceStats;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -97,13 +109,16 @@ public class TrainedModelStats implements ToXContentObject {
         if (ingestStats != null) {
             builder.field(INGEST_STATS.getPreferredName(), ingestStats);
         }
+        if (inferenceStats != null) {
+            builder.field(INFERENCE_STATS.getPreferredName(), inferenceStats);
+        }
         builder.endObject();
         return builder;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(modelId, ingestStats, pipelineCount);
+        return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
     }
 
     @Override
@@ -117,7 +132,8 @@ public class TrainedModelStats implements ToXContentObject {
         TrainedModelStats other = (TrainedModelStats) obj;
         return Objects.equals(this.modelId, other.modelId)
             && Objects.equals(this.ingestStats, other.ingestStats)
-            && Objects.equals(this.pipelineCount, other.pipelineCount);
+            && Objects.equals(this.pipelineCount, other.pipelineCount)
+            && Objects.equals(this.inferenceStats, other.inferenceStats);
     }
 
 }

+ 170 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceStats.java

@@ -0,0 +1,170 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.trainedmodel;
+
+import org.elasticsearch.client.common.TimeUtil;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.Objects;
+
+public class InferenceStats implements ToXContentObject {
+
+    public static final String NAME = "inference_stats";
+    public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
+    public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
+    public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
+    public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
+    public static final ParseField TIMESTAMP = new ParseField("timestamp");
+
+    public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        true,
+        a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (Instant)a[4])
+    );
+    static {
+        PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
+        PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
+        PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
+        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
+        PARSER.declareField(ConstructingObjectParser.constructorArg(),
+            p -> TimeUtil.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
+            TIMESTAMP,
+            ObjectParser.ValueType.VALUE);
+    }
+
+    private final long missingAllFieldsCount;
+    private final long inferenceCount;
+    private final long failureCount;
+    private final long cacheMissCount;
+    private final Instant timeStamp;
+
+    private InferenceStats(Long missingAllFieldsCount,
+                           Long inferenceCount,
+                           Long failureCount,
+                           Long cacheMissCount,
+                           Instant instant) {
+        this(unboxOrZero(missingAllFieldsCount),
+            unboxOrZero(inferenceCount),
+            unboxOrZero(failureCount),
+            unboxOrZero(cacheMissCount),
+            instant);
+    }
+
+    public InferenceStats(long missingAllFieldsCount,
+                          long inferenceCount,
+                          long failureCount,
+                          long cacheMissCount,
+                          Instant timeStamp) {
+        this.missingAllFieldsCount = missingAllFieldsCount;
+        this.inferenceCount = inferenceCount;
+        this.failureCount = failureCount;
+        this.cacheMissCount = cacheMissCount;
+        this.timeStamp = timeStamp == null ?
+            Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
+            Instant.ofEpochMilli(timeStamp.toEpochMilli());
+    }
+
+    /**
+     * How many times this model attempted to infer with all its fields missing
+     */
+    public long getMissingAllFieldsCount() {
+        return missingAllFieldsCount;
+    }
+
+    /**
+     * How many inference calls were made against this model
+     */
+    public long getInferenceCount() {
+        return inferenceCount;
+    }
+
+    /**
+     * How many inference failures occurred.
+     */
+    public long getFailureCount() {
+        return failureCount;
+    }
+
+    /**
+     * How many cache misses occurred when inferring this model
+     */
+    public long getCacheMissCount() {
+        return cacheMissCount;
+    }
+
+    /**
+     * The timestamp of these statistics.
+     */
+    public Instant getTimeStamp() {
+        return timeStamp;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
+        builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
+        builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
+        builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
+        builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        InferenceStats that = (InferenceStats) o;
+        return missingAllFieldsCount == that.missingAllFieldsCount
+            && inferenceCount == that.inferenceCount
+            && failureCount == that.failureCount
+            && cacheMissCount == that.cacheMissCount
+            && Objects.equals(timeStamp, that.timeStamp);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, timeStamp);
+    }
+
+    @Override
+    public String toString() {
+        return "InferenceStats{" +
+            "missingAllFieldsCount=" + missingAllFieldsCount +
+            ", inferenceCount=" + inferenceCount +
+            ", failureCount=" + failureCount +
+            ", cacheMissCount=" + cacheMissCount +
+            ", timeStamp=" + timeStamp +
+            '}';
+    }
+
+    private static long unboxOrZero(@Nullable Long value) {
+        return value == null ? 0L : value;
+    }
+
+}

+ 3 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelStatsTests.java

@@ -18,6 +18,7 @@
  */
 package org.elasticsearch.client.ml.inference;
 
+import org.elasticsearch.client.ml.inference.trainedmodel.InferenceStatsTests;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -58,7 +59,8 @@ public class TrainedModelStatsTests extends AbstractXContentTestCase<TrainedMode
         return new TrainedModelStats(
             randomAlphaOfLength(10),
             randomBoolean() ? null : randomIngestStats(),
-            randomInt());
+            randomInt(),
+            randomBoolean() ? null : InferenceStatsTests.randomInstance());
     }
 
     private Map<String, Object> randomIngestStats() {

+ 54 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceStatsTests.java

@@ -0,0 +1,54 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.time.Instant;
+
+public class InferenceStatsTests extends AbstractXContentTestCase<InferenceStats> {
+
+    public static InferenceStats randomInstance() {
+        return new InferenceStats(randomNonNegativeLong(),
+            randomNonNegativeLong(),
+            randomNonNegativeLong(),
+            randomNonNegativeLong(),
+            Instant.now()
+            );
+    }
+
+    @Override
+    protected InferenceStats doParseInstance(XContentParser parser) throws IOException {
+        return InferenceStats.PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected InferenceStats createTestInstance() {
+        return randomInstance();
+    }
+
+}

+ 84 - 3
docs/reference/ml/df-analytics/apis/get-inference-trained-model-stats.asciidoc

@@ -67,6 +67,74 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from]
 (Optional, integer) 
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size]
 
+[role="child_attributes"]
+[[ml-get-inference-stats-results]]
+==== {api-response-body-title}
+
+`count`::
+(integer)
+The total number of trained model statistics that matched the requested ID patterns.
+Could be higher than the number of items in the `trained_model_stats` array as the
+size of the array is restricted by the supplied `size` parameter.
+
+`trained_model_stats`::
+(array)
+An array of trained model statistics, which are sorted by the `model_id` value in
+ascending order.
++
+.Properties of trained model stats
+[%collapsible%open]
+====
+`model_id`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+
+`pipeline_count`:::
+(integer)
+The number of ingest pipelines that currently refer to the model.
+
+`inference_stats`:::
+(object)
+A collection of inference stats fields.
++
+.Properties of inference stats
+[%collapsible%open]
+=====
+
+`missing_all_fields_count`:::
+(integer)
+The number of inference calls where all the training features for the model
+were missing.
+
+`inference_count`:::
+(integer)
+The total number of times the model has been called for inference.
+This is across all inference contexts, including all pipelines.
+
+`cache_miss_count`:::
+(integer)
+The number of times the model was loaded for inference and was not retrieved from the
+cache. If this number is close to the `inference_count`, then the cache
+is not being appropriately used. This can be remedied by increasing the cache's size
+or its time-to-live (TTL). See <<general-ml-settings>> for the
+appropriate settings.
+
+`failure_count`:::
+(integer)
+The number of failures when using the model for inference.
+
+`timestamp`:::
+(<<time-units,time units>>)
+The time when the statistics were last updated.
+=====
+
+`ingest`:::
+(object)
+A collection of ingest stats for the model across all nodes. The values are
+summations of the individual node statistics. The format matches the `ingest`
+section in <<cluster-nodes-stats>>.
+
+====
 
 [[ml-get-inference-stats-response-codes]]
 ==== {api-response-codes-title}
@@ -74,7 +142,6 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size]
 `404` (Missing resources)::
   If `allow_no_match` is `false`, this code indicates that there are no
   resources that match the request or only partial matches for the request.
-  
 
 [[ml-get-inference-stats-example]]
 ==== {api-examples-title}
@@ -97,11 +164,25 @@ The API returns the following results:
   "trained_model_stats": [
     {
       "model_id": "flight-delay-prediction-1574775339910",
-      "pipeline_count": 0
+      "pipeline_count": 0,
+      "inference_stats": {
+        "failure_count": 0,
+        "inference_count": 4,
+        "cache_miss_count": 3,
+        "missing_all_fields_count": 0,
+        "timestamp": 1592399986979
+      }
     },
     {
       "model_id": "regression-job-one-1574775307356",
       "pipeline_count": 1,
+      "inference_stats": {
+        "failure_count": 0,
+        "inference_count": 178,
+        "cache_miss_count": 3,
+        "missing_all_fields_count": 0,
+        "timestamp": 1592399986979
+      },
       "ingest": {
         "total": {
           "count": 178,
@@ -135,4 +216,4 @@ The API returns the following results:
   ]
 }
 ----
-// NOTCONSOLE
+// NOTCONSOLE

+ 72 - 69
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -20,15 +21,13 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import java.io.IOException;
 import java.time.Instant;
 import java.util.Objects;
-import java.util.concurrent.atomic.LongAdder;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 public class InferenceStats implements ToXContentObject, Writeable {
 
     public static final String NAME = "inference_stats";
     public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
     public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
+    public static final ParseField CACHE_MISS_COUNT = new ParseField("cache_miss_count");
     public static final ParseField MODEL_ID = new ParseField("model_id");
     public static final ParseField NODE_ID = new ParseField("node_id");
     public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
@@ -38,12 +37,13 @@ public class InferenceStats implements ToXContentObject, Writeable {
     public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
         NAME,
         true,
-        a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (String)a[3], (String)a[4], (Instant)a[5])
+        a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (Long)a[3], (String)a[4], (String)a[5], (Instant)a[6])
     );
     static {
         PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
         PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
         PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
+        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), CACHE_MISS_COUNT);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), NODE_ID);
         PARSER.declareField(ConstructingObjectParser.constructorArg(),
@@ -51,9 +51,6 @@ public class InferenceStats implements ToXContentObject, Writeable {
             TIMESTAMP,
             ObjectParser.ValueType.VALUE);
     }
-    public static InferenceStats emptyStats(String modelId, String nodeId) {
-        return new InferenceStats(0L, 0L, 0L, modelId, nodeId, Instant.now());
-    }
 
     public static String docId(String modelId, String nodeId) {
         return NAME + "-" + modelId + "-" + nodeId;
@@ -62,6 +59,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
     private final long missingAllFieldsCount;
     private final long inferenceCount;
     private final long failureCount;
+    private final long cacheMissCount;
     private final String modelId;
     private final String nodeId;
     private final Instant timeStamp;
@@ -69,12 +67,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
     private InferenceStats(Long missingAllFieldsCount,
                            Long inferenceCount,
                            Long failureCount,
+                           Long cacheMissCount,
                            String modelId,
                            String nodeId,
                            Instant instant) {
-        this(unbox(missingAllFieldsCount),
-            unbox(inferenceCount),
-            unbox(failureCount),
+        this(unboxOrZero(missingAllFieldsCount),
+            unboxOrZero(inferenceCount),
+            unboxOrZero(failureCount),
+            unboxOrZero(cacheMissCount),
             modelId,
             nodeId,
             instant);
@@ -83,12 +83,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
     public InferenceStats(long missingAllFieldsCount,
                           long inferenceCount,
                           long failureCount,
+                          long cacheMissCount,
                           String modelId,
                           String nodeId,
                           Instant timeStamp) {
         this.missingAllFieldsCount = missingAllFieldsCount;
         this.inferenceCount = inferenceCount;
         this.failureCount = failureCount;
+        this.cacheMissCount = cacheMissCount;
         this.modelId = modelId;
         this.nodeId = nodeId;
         this.timeStamp = timeStamp == null ?
@@ -100,6 +102,11 @@ public class InferenceStats implements ToXContentObject, Writeable {
         this.missingAllFieldsCount = in.readVLong();
         this.inferenceCount = in.readVLong();
         this.failureCount = in.readVLong();
+        if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
+            this.cacheMissCount = in.readVLong();
+        } else {
+            this.cacheMissCount = 0L;
+        }
         this.modelId = in.readOptionalString();
         this.nodeId = in.readOptionalString();
         this.timeStamp = in.readInstant();
@@ -117,6 +124,10 @@ public class InferenceStats implements ToXContentObject, Writeable {
         return failureCount;
     }
 
+    public long getCacheMissCount() {
+        return cacheMissCount;
+    }
+
     public String getModelId() {
         return modelId;
     }
@@ -130,7 +141,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
     }
 
     public boolean hasStats() {
-        return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0;
+        return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0 || cacheMissCount > 0;
     }
 
     @Override
@@ -145,6 +156,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
         }
         builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
         builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
+        builder.field(CACHE_MISS_COUNT.getPreferredName(), cacheMissCount);
         builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
         builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
         builder.endObject();
@@ -159,6 +171,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
         return missingAllFieldsCount == that.missingAllFieldsCount
             && inferenceCount == that.inferenceCount
             && failureCount == that.failureCount
+            && cacheMissCount == that.cacheMissCount
             && Objects.equals(modelId, that.modelId)
             && Objects.equals(nodeId, that.nodeId)
             && Objects.equals(timeStamp, that.timeStamp);
@@ -166,7 +179,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
 
     @Override
     public int hashCode() {
-        return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, modelId, nodeId, timeStamp);
+        return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, cacheMissCount, modelId, nodeId, timeStamp);
     }
 
     @Override
@@ -175,13 +188,14 @@ public class InferenceStats implements ToXContentObject, Writeable {
             "missingAllFieldsCount=" + missingAllFieldsCount +
             ", inferenceCount=" + inferenceCount +
             ", failureCount=" + failureCount +
+            ", cacheMissCount=" + cacheMissCount +
             ", modelId='" + modelId + '\'' +
             ", nodeId='" + nodeId + '\'' +
             ", timeStamp=" + timeStamp +
             '}';
     }
 
-    private static long unbox(@Nullable Long value) {
+    private static long unboxOrZero(@Nullable Long value) {
         return value == null ? 0L : value;
     }
 
@@ -194,6 +208,9 @@ public class InferenceStats implements ToXContentObject, Writeable {
         out.writeVLong(this.missingAllFieldsCount);
         out.writeVLong(this.inferenceCount);
         out.writeVLong(this.failureCount);
+        if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
+            out.writeVLong(this.cacheMissCount);
+        }
         out.writeOptionalString(this.modelId);
         out.writeOptionalString(this.nodeId);
         out.writeInstant(timeStamp);
@@ -201,66 +218,55 @@ public class InferenceStats implements ToXContentObject, Writeable {
 
     public static class Accumulator {
 
-        private final LongAdder missingFieldsAccumulator = new LongAdder();
-        private final LongAdder inferenceAccumulator = new LongAdder();
-        private final LongAdder failureCountAccumulator = new LongAdder();
+        private long missingFieldsAccumulator = 0L;
+        private long inferenceAccumulator = 0L;
+        private long failureCountAccumulator = 0L;
+        private long cacheMissAccumulator = 0L;
         private final String modelId;
         private final String nodeId;
-        // curious reader
-        // you may be wondering why the lock set to the fair.
-        // When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute.
-        // If a ReadWriteLock is unfair, there are no such guarantees.
-        // A call for the `writelock::lock` could pause indefinitely.
-        private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
-
-        public Accumulator(String modelId, String nodeId) {
+
+        public Accumulator(String modelId, String nodeId, long cacheMisses) {
             this.modelId = modelId;
             this.nodeId = nodeId;
+            this.cacheMissAccumulator = cacheMisses;
         }
 
-        public Accumulator(InferenceStats previousStats) {
+        Accumulator(InferenceStats previousStats) {
             this.modelId = previousStats.modelId;
             this.nodeId = previousStats.nodeId;
-            this.missingFieldsAccumulator.add(previousStats.missingAllFieldsCount);
-            this.inferenceAccumulator.add(previousStats.inferenceCount);
-            this.failureCountAccumulator.add(previousStats.failureCount);
+            this.missingFieldsAccumulator += previousStats.missingAllFieldsCount;
+            this.inferenceAccumulator += previousStats.inferenceCount;
+            this.failureCountAccumulator += previousStats.failureCount;
+            this.cacheMissAccumulator += previousStats.cacheMissCount;
         }
 
+        /**
+         * NOT Thread Safe
+         *
+         * @param otherStats the other stats with which to increment the current stats
+         * @return Updated accumulator
+         */
         public Accumulator merge(InferenceStats otherStats) {
-            this.missingFieldsAccumulator.add(otherStats.missingAllFieldsCount);
-            this.inferenceAccumulator.add(otherStats.inferenceCount);
-            this.failureCountAccumulator.add(otherStats.failureCount);
+            this.missingFieldsAccumulator += otherStats.missingAllFieldsCount;
+            this.inferenceAccumulator += otherStats.inferenceCount;
+            this.failureCountAccumulator += otherStats.failureCount;
+            this.cacheMissAccumulator += otherStats.cacheMissCount;
             return this;
         }
 
-        public Accumulator incMissingFields() {
-            readWriteLock.readLock().lock();
-            try {
-                this.missingFieldsAccumulator.increment();
-                return this;
-            } finally {
-                readWriteLock.readLock().unlock();
-            }
+        public synchronized Accumulator incMissingFields() {
+            this.missingFieldsAccumulator++;
+            return this;
         }
 
-        public Accumulator incInference() {
-            readWriteLock.readLock().lock();
-            try {
-                this.inferenceAccumulator.increment();
-                return this;
-            } finally {
-                readWriteLock.readLock().unlock();
-            }
+        public synchronized Accumulator incInference() {
+            this.inferenceAccumulator++;
+            return this;
         }
 
-        public Accumulator incFailure() {
-            readWriteLock.readLock().lock();
-            try {
-                this.failureCountAccumulator.increment();
-                return this;
-            } finally {
-                readWriteLock.readLock().unlock();
-            }
+        public synchronized Accumulator incFailure() {
+            this.failureCountAccumulator++;
+            return this;
         }
 
         /**
@@ -269,23 +275,20 @@ public class InferenceStats implements ToXContentObject, Writeable {
          * Returns the current stats and resets the values of all the counters.
          * @return The current stats
          */
-        public InferenceStats currentStatsAndReset() {
-            readWriteLock.writeLock().lock();
-            try {
-                InferenceStats stats = currentStats(Instant.now());
-                this.missingFieldsAccumulator.reset();
-                this.inferenceAccumulator.reset();
-                this.failureCountAccumulator.reset();
-                return stats;
-            } finally {
-                readWriteLock.writeLock().unlock();
-            }
+        public synchronized InferenceStats currentStatsAndReset() {
+            InferenceStats stats = currentStats(Instant.now());
+            this.missingFieldsAccumulator = 0L;
+            this.inferenceAccumulator = 0L;
+            this.failureCountAccumulator = 0L;
+            this.cacheMissAccumulator = 0L;
+            return stats;
         }
 
         public InferenceStats currentStats(Instant timeStamp) {
-            return new InferenceStats(missingFieldsAccumulator.longValue(),
-                inferenceAccumulator.longValue(),
-                failureCountAccumulator.longValue(),
+            return new InferenceStats(missingFieldsAccumulator,
+                inferenceAccumulator,
+                failureCountAccumulator,
+                cacheMissAccumulator,
                 modelId,
                 nodeId,
                 timeStamp);

+ 3 - 0
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json

@@ -97,6 +97,9 @@
       "failure_count": {
         "type": "long"
       },
+      "cache_miss_count": {
+        "type": "long"
+      },
       "missing_all_fields_count": {
         "type": "long"
       },

+ 1 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStatsTests.java

@@ -21,6 +21,7 @@ public class InferenceStatsTests extends AbstractSerializingTestCase<InferenceSt
 
     public static InferenceStats createTestInstance(String modelId, @Nullable String nodeId) {
         return new InferenceStats(randomNonNegativeLong(),
+            randomNonNegativeLong(),
             randomNonNegativeLong(),
             randomNonNegativeLong(),
             modelId,

+ 12 - 4
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@@ -117,9 +117,13 @@ public class InferenceIngestIT extends ESRestTestCase {
             try {
                 Response statsResponse = client().performRequest(new Request("GET",
                     "_ml/inference/" + classificationModelId + "/_stats"));
-                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+                String response = EntityUtils.toString(statsResponse.getEntity());
+                assertThat(response, containsString("\"inference_count\":10"));
+                assertThat(response, containsString("\"cache_miss_count\":30"));
                 statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
-                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+                response = EntityUtils.toString(statsResponse.getEntity());
+                assertThat(response, containsString("\"inference_count\":10"));
+                assertThat(response, containsString("\"cache_miss_count\":30"));
             } catch (ResponseException ex) {
                 //this could just mean shard failures.
                 fail(ex.getMessage());
@@ -169,9 +173,13 @@ public class InferenceIngestIT extends ESRestTestCase {
             try {
                 Response statsResponse = client().performRequest(new Request("GET",
                     "_ml/inference/" + classificationModelId + "/_stats"));
-                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
+                String response = EntityUtils.toString(statsResponse.getEntity());
+                assertThat(response, containsString("\"inference_count\":10"));
+                assertThat(response, containsString("\"cache_miss_count\":3"));
                 statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
-                assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
+                response = EntityUtils.toString(statsResponse.getEntity());
+                assertThat(response, containsString("\"inference_count\":15"));
+                assertThat(response, containsString("\"cache_miss_count\":3"));
                 // can get both
                 statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
                 String entityString = EntityUtils.toString(statsResponse.getEntity());

+ 4 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java

@@ -58,12 +58,14 @@ public class TrainedModelStatsService {
         "    ctx._source.{0} += params.{0};\n" +
         "    ctx._source.{1} += params.{1};\n" +
         "    ctx._source.{2} += params.{2};\n" +
-        "    ctx._source.{3} = params.{3};";
+        "    ctx._source.{3} += params.{3};\n" +
+        "    ctx._source.{4} = params.{4};";
     // Script to only update if stats have increased since last persistence
     private static final String STATS_UPDATE_SCRIPT = Messages.getMessage(STATS_UPDATE_SCRIPT_TEMPLATE,
         InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(),
         InferenceStats.INFERENCE_COUNT.getPreferredName(),
         InferenceStats.FAILURE_COUNT.getPreferredName(),
+        InferenceStats.CACHE_MISS_COUNT.getPreferredName(),
         InferenceStats.TIMESTAMP.getPreferredName());
     private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
         new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
@@ -224,6 +226,7 @@ public class TrainedModelStatsService {
             params.put(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(), stats.getMissingAllFieldsCount());
             params.put(InferenceStats.TIMESTAMP.getPreferredName(), stats.getTimeStamp().toEpochMilli());
             params.put(InferenceStats.INFERENCE_COUNT.getPreferredName(), stats.getInferenceCount());
+            params.put(InferenceStats.CACHE_MISS_COUNT.getPreferredName(), stats.getCacheMissCount());
             stats.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
             UpdateRequest updateRequest = new UpdateRequest();
             updateRequest.upsert(builder)

+ 4 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

@@ -46,11 +46,13 @@ public class LocalModel implements Model {
                       TrainedModelInput input,
                       Map<String, String> defaultFieldMap,
                       InferenceConfig modelInferenceConfig,
-                      TrainedModelStatsService trainedModelStatsService ) {
+                      TrainedModelStatsService trainedModelStatsService) {
         this.trainedModelDefinition = trainedModelDefinition;
         this.modelId = modelId;
         this.fieldNames = new HashSet<>(input.getFieldNames());
-        this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
+        // the ctor being called means a new instance was created.
+        // Consequently, it was not loaded from cache and on stats persist we should increment accordingly.
+        this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId, 1L);
         this.trainedModelStatsService = trainedModelStatsService;
         this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
         this.currentInferenceCount = new LongAdder();

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -633,6 +633,8 @@ public class TrainedModelProvider {
                     .field(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName()))
                 .aggregation(AggregationBuilders.sum(InferenceStats.INFERENCE_COUNT.getPreferredName())
                     .field(InferenceStats.INFERENCE_COUNT.getPreferredName()))
+                .aggregation(AggregationBuilders.sum(InferenceStats.CACHE_MISS_COUNT.getPreferredName())
+                    .field(InferenceStats.CACHE_MISS_COUNT.getPreferredName()))
                 .aggregation(AggregationBuilders.max(InferenceStats.TIMESTAMP.getPreferredName())
                     .field(InferenceStats.TIMESTAMP.getPreferredName()))
                 .query(queryBuilder));
@@ -645,12 +647,14 @@ public class TrainedModelProvider {
         }
         Sum failures = response.getAggregations().get(InferenceStats.FAILURE_COUNT.getPreferredName());
         Sum missing = response.getAggregations().get(InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName());
+        Sum cacheMiss = response.getAggregations().get(InferenceStats.CACHE_MISS_COUNT.getPreferredName());
         Sum count = response.getAggregations().get(InferenceStats.INFERENCE_COUNT.getPreferredName());
         Max timeStamp = response.getAggregations().get(InferenceStats.TIMESTAMP.getPreferredName());
         return new InferenceStats(
             missing == null ? 0L : Double.valueOf(missing.getValue()).longValue(),
             count == null ? 0L : Double.valueOf(count.getValue()).longValue(),
             failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
+            cacheMiss == null ? 0L : Double.valueOf(cacheMiss.getValue()).longValue(),
             modelId,
             null,
             timeStamp == null || (Numbers.isValidDouble(timeStamp.getValue()) == false) ?