Browse Source

[ML] Move PyTorch request ID and cache hit indicator to top level (#88901)

This change will facilitate a performance improvement on the C++
side. The request ID and cache hit indicator are the parts that
need to be changed when the C++ process responds to an inference
request. Having them at the top level means we do not need to
parse and manipulate the original response - we can simply cache
the inner object of the response and add the outer fields around
it when serializing it.

Companion to elastic/ml-cpp#2376
David Roberts 3 years ago
parent
commit
8c21d03f7a
21 changed files with 236 additions and 138 deletions
  1. 0 1
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  2. 46 19
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
  3. 37 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java
  4. 2 6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java
  5. 4 34
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java
  6. 33 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java
  7. 2 6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java
  8. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  9. 5 5
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  10. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
  11. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  12. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java
  13. 48 41
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java
  14. 35 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java
  15. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java
  16. 1 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
  17. 14 4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
  18. 1 5
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java
  19. 0 2
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
  20. 0 1
      x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java
  21. 0 1
      x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java

+ 0 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -75,7 +75,6 @@ import static org.hamcrest.Matchers.nullValue;
  * torch.jit.save(traced_model, "simplemodel.pt")
  * ## End Python
  */
-@ESRestTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
 public class PyTorchModelIT extends ESRestTestCase {
 
     private static final String BASIC_AUTH_VALUE_SUPER_USER = UsernamePasswordToken.basicAuthHeaderValue(

+ 46 - 19
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.ml.utils.Intervals;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
@@ -105,10 +106,12 @@ public class PyTorchResultProcessor {
                     threadSettingsConsumer.accept(threadSettings);
                     processThreadSettings(result);
                 }
+                if (result.ackResult() != null) {
+                    processAcknowledgement(result);
+                }
                 if (result.errorResult() != null) {
                     processErrorResult(result);
                 }
-
             }
         } catch (Exception e) {
             // No need to report error as we're stopping
@@ -118,10 +121,13 @@ public class PyTorchResultProcessor {
             pendingResults.forEach(
                 (id, pendingResult) -> pendingResult.listener.onResponse(
                     new PyTorchResult(
+                        id,
+                        null,
+                        null,
+                        null,
                         null,
                         null,
                         new ErrorResult(
-                            id,
                             isStopping
                                 ? "inference canceled as process is stopping"
                                 : "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
@@ -133,7 +139,7 @@ public class PyTorchResultProcessor {
         } finally {
             pendingResults.forEach(
                 (id, pendingResult) -> pendingResult.listener.onResponse(
-                    new PyTorchResult(null, null, new ErrorResult(id, "inference canceled as process is stopping"))
+                    new PyTorchResult(id, false, null, null, null, null, new ErrorResult("inference canceled as process is stopping"))
                 )
             );
             pendingResults.clear();
@@ -144,12 +150,17 @@ public class PyTorchResultProcessor {
     void processInferenceResult(PyTorchResult result) {
         PyTorchInferenceResult inferenceResult = result.inferenceResult();
         assert inferenceResult != null;
+        Long timeMs = result.timeMs();
+        if (timeMs == null) {
+            assert false : "time_ms should be set for an inference result";
+            timeMs = 0L;
+        }
 
-        logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, inferenceResult.getRequestId()));
-        processResult(inferenceResult);
-        PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId());
+        logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId()));
+        processResult(inferenceResult, timeMs, Boolean.TRUE.equals(result.isCacheHit()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
-            logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, inferenceResult.getRequestId()));
+            logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId()));
         } else {
             pendingResult.listener.onResponse(result);
         }
@@ -159,10 +170,23 @@ public class PyTorchResultProcessor {
         ThreadSettings threadSettings = result.threadSettings();
         assert threadSettings != null;
 
-        logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, threadSettings.requestId()));
-        PendingResult pendingResult = pendingResults.remove(threadSettings.requestId());
+        logger.trace(() -> format("[%s] Parsed thread settings result with id [%s]", deploymentId, result.requestId()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
+        if (pendingResult == null) {
+            logger.debug(() -> format("[%s] no pending result for thread settings [%s]", deploymentId, result.requestId()));
+        } else {
+            pendingResult.listener.onResponse(result);
+        }
+    }
+
+    void processAcknowledgement(PyTorchResult result) {
+        AckResult ack = result.ackResult();
+        assert ack != null;
+
+        logger.trace(() -> format("[%s] Parsed ack result with id [%s]", deploymentId, result.requestId()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
-            logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, threadSettings.requestId()));
+            logger.debug(() -> format("[%s] no pending result for ack [%s]", deploymentId, result.requestId()));
         } else {
             pendingResult.listener.onResponse(result);
         }
@@ -172,12 +196,15 @@ public class PyTorchResultProcessor {
         ErrorResult errorResult = result.errorResult();
         assert errorResult != null;
 
-        errorCount++;
+        // Only one result is processed at any time, but we need to stop this happening part way through another thread getting stats
+        synchronized (this) {
+            errorCount++;
+        }
 
-        logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, errorResult.requestId()));
-        PendingResult pendingResult = pendingResults.remove(errorResult.requestId());
+        logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, result.requestId()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
-            logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, errorResult.requestId()));
+            logger.debug(() -> format("[%s] no pending result for error [%s]", deploymentId, result.requestId()));
         } else {
             pendingResult.listener.onResponse(result);
         }
@@ -218,8 +245,8 @@ public class PyTorchResultProcessor {
         );
     }
 
-    private synchronized void processResult(PyTorchInferenceResult result) {
-        timingStats.accept(result.getTimeMs());
+    private synchronized void processResult(PyTorchInferenceResult result, long timeMs, boolean isCacheHit) {
+        timingStats.accept(timeMs);
 
         lastResultTimeMs = currentTimeMsSupplier.getAsLong();
         if (lastResultTimeMs > currentPeriodEndTimeMs) {
@@ -240,15 +267,15 @@ public class PyTorchResultProcessor {
 
             lastPeriodCacheHitCount = 0;
             lastPeriodSummaryStats = new LongSummaryStatistics();
-            lastPeriodSummaryStats.accept(result.getTimeMs());
+            lastPeriodSummaryStats.accept(timeMs);
 
             // set to the end of the current bucket
             currentPeriodEndTimeMs = startTime + Intervals.alignToCeil(lastResultTimeMs - startTime, REPORTING_PERIOD_MS);
         } else {
-            lastPeriodSummaryStats.accept(result.getTimeMs());
+            lastPeriodSummaryStats.accept(timeMs);
         }
 
-        if (result.isCacheHit()) {
+        if (isCacheHit) {
             cacheHitCount++;
             lastPeriodCacheHitCount++;
         }

+ 37 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java

@@ -0,0 +1,37 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+public record AckResult(boolean acknowledged) implements ToXContentObject {
+
+    public static final ParseField ACKNOWLEDGED = new ParseField("acknowledged");
+
+    public static ConstructingObjectParser<AckResult, Void> PARSER = new ConstructingObjectParser<>(
+        "ack",
+        a -> new AckResult((Boolean) a[0])
+    );
+
+    static {
+        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ACKNOWLEDGED);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACKNOWLEDGED.getPreferredName(), acknowledged);
+        builder.endObject();
+        return builder;
+    }
+}

+ 2 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java

@@ -14,26 +14,22 @@ import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
 
-public record ErrorResult(String requestId, String error) implements ToXContentObject {
+public record ErrorResult(String error) implements ToXContentObject {
 
     public static final ParseField ERROR = new ParseField("error");
 
     public static ConstructingObjectParser<ErrorResult, Void> PARSER = new ConstructingObjectParser<>(
         "error",
-        a -> new ErrorResult((String) a[0], (String) a[1])
+        a -> new ErrorResult((String) a[0])
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        if (requestId != null) {
-            builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
-        }
         builder.field(ERROR.getPreferredName(), error);
         builder.endObject();
         return builder;

+ 4 - 34
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java

@@ -18,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.Objects;
 
 /**
  * All results must have a request_id.
@@ -28,62 +27,38 @@ import java.util.Objects;
 public class PyTorchInferenceResult implements ToXContentObject {
 
     private static final ParseField INFERENCE = new ParseField("inference");
-    private static final ParseField TIME_MS = new ParseField("time_ms");
-    private static final ParseField CACHE_HIT = new ParseField("cache_hit");
 
     public static final ConstructingObjectParser<PyTorchInferenceResult, Void> PARSER = new ConstructingObjectParser<>(
         "pytorch_inference_result",
-        a -> new PyTorchInferenceResult((String) a[0], (double[][][]) a[1], (Long) a[2], (Boolean) a[3])
+        a -> new PyTorchInferenceResult((double[][][]) a[0])
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), PyTorchResult.REQUEST_ID);
         PARSER.declareField(
             ConstructingObjectParser.optionalConstructorArg(),
             (p, c) -> MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p),
             INFERENCE,
             ObjectParser.ValueType.VALUE_ARRAY
         );
-        PARSER.declareLong(ConstructingObjectParser.constructorArg(), TIME_MS);
-        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), CACHE_HIT);
     }
 
     public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException {
         return PARSER.parse(parser, null);
     }
 
-    private final String requestId;
     private final double[][][] inference;
-    private final long timeMs;
-    private final boolean cacheHit;
 
-    public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, long timeMs, boolean cacheHit) {
-        this.requestId = Objects.requireNonNull(requestId);
+    public PyTorchInferenceResult(@Nullable double[][][] inference) {
         this.inference = inference;
-        this.timeMs = timeMs;
-        this.cacheHit = cacheHit;
-    }
-
-    public String getRequestId() {
-        return requestId;
     }
 
     public double[][][] getInferenceResult() {
         return inference;
     }
 
-    public long getTimeMs() {
-        return timeMs;
-    }
-
-    public boolean isCacheHit() {
-        return cacheHit;
-    }
-
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
         if (inference != null) {
             builder.startArray(INFERENCE.getPreferredName());
             for (double[][] doubles : inference) {
@@ -95,15 +70,13 @@ public class PyTorchInferenceResult implements ToXContentObject {
             }
             builder.endArray();
         }
-        builder.field(TIME_MS.getPreferredName(), timeMs);
-        builder.field(CACHE_HIT.getPreferredName(), cacheHit);
         builder.endObject();
         return builder;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(requestId, timeMs, Arrays.deepHashCode(inference), cacheHit);
+        return Arrays.deepHashCode(inference);
     }
 
     @Override
@@ -112,9 +85,6 @@ public class PyTorchInferenceResult implements ToXContentObject {
         if (other == null || getClass() != other.getClass()) return false;
 
         PyTorchInferenceResult that = (PyTorchInferenceResult) other;
-        return Objects.equals(requestId, that.requestId)
-            && Arrays.deepEquals(inference, that.inference)
-            && timeMs == that.timeMs
-            && cacheHit == that.cacheHit;
+        return Arrays.deepEquals(inference, that.inference);
     }
 }

+ 33 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java

@@ -19,24 +19,43 @@ import java.io.IOException;
  * The top level object capturing output from the pytorch process.
  */
 public record PyTorchResult(
+    String requestId,
+    Boolean isCacheHit,
+    Long timeMs,
     @Nullable PyTorchInferenceResult inferenceResult,
     @Nullable ThreadSettings threadSettings,
+    @Nullable AckResult ackResult,
     @Nullable ErrorResult errorResult
 ) implements ToXContentObject {
 
-    static final ParseField REQUEST_ID = new ParseField("request_id");
+    private static final ParseField REQUEST_ID = new ParseField("request_id");
+    private static final ParseField CACHE_HIT = new ParseField("cache_hit");
+    private static final ParseField TIME_MS = new ParseField("time_ms");
 
     private static final ParseField RESULT = new ParseField("result");
     private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings");
+    private static final ParseField ACK = new ParseField("ack");
 
     public static ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>(
         "pytorch_result",
-        a -> new PyTorchResult((PyTorchInferenceResult) a[0], (ThreadSettings) a[1], (ErrorResult) a[2])
+        a -> new PyTorchResult(
+            (String) a[0],
+            (Boolean) a[1],
+            (Long) a[2],
+            (PyTorchInferenceResult) a[3],
+            (ThreadSettings) a[4],
+            (AckResult) a[5],
+            (ErrorResult) a[6]
+        )
     );
 
     static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID);
+        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CACHE_HIT);
+        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), TIME_MS);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS);
+        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), AckResult.PARSER, ACK);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ErrorResult.PARSER, ErrorResult.ERROR);
     }
 
@@ -47,12 +66,24 @@ public record PyTorchResult(
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
+        if (requestId != null) {
+            builder.field(REQUEST_ID.getPreferredName(), requestId);
+        }
+        if (isCacheHit != null) {
+            builder.field(CACHE_HIT.getPreferredName(), isCacheHit);
+        }
+        if (timeMs != null) {
+            builder.field(TIME_MS.getPreferredName(), timeMs);
+        }
         if (inferenceResult != null) {
             builder.field(RESULT.getPreferredName(), inferenceResult);
         }
         if (threadSettings != null) {
             builder.field(THREAD_SETTINGS.getPreferredName(), threadSettings);
         }
+        if (ackResult != null) {
+            builder.field(ACK.getPreferredName(), ackResult);
+        }
         if (errorResult != null) {
             builder.field(ErrorResult.ERROR.getPreferredName(), errorResult);
         }

+ 2 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java

@@ -14,20 +14,19 @@ import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
 
-public record ThreadSettings(int numThreadsPerAllocation, int numAllocations, String requestId) implements ToXContentObject {
+public record ThreadSettings(int numThreadsPerAllocation, int numAllocations) implements ToXContentObject {
 
     private static final ParseField NUM_ALLOCATIONS = new ParseField("num_allocations");
     private static final ParseField NUM_THREADS_PER_ALLOCATION = new ParseField("num_threads_per_allocation");
 
     public static ConstructingObjectParser<ThreadSettings, Void> PARSER = new ConstructingObjectParser<>(
         "thread_settings",
-        a -> new ThreadSettings((int) a[0], (int) a[1], (String) a[2])
+        a -> new ThreadSettings((int) a[0], (int) a[1])
     );
 
     static {
         PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_THREADS_PER_ALLOCATION);
         PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_ALLOCATIONS);
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
     }
 
     @Override
@@ -35,9 +34,6 @@ public record ThreadSettings(int numThreadsPerAllocation, int numAllocations, St
         builder.startObject();
         builder.field(NUM_THREADS_PER_ALLOCATION.getPreferredName(), numThreadsPerAllocation);
         builder.field(NUM_ALLOCATIONS.getPreferredName(), numAllocations);
-        if (requestId != null) {
-            builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
-        }
         builder.endObject();
         return builder;
     }

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -64,7 +64,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         String resultsField = randomAlphaOfLength(10);
         FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
             tokenization,
-            new PyTorchInferenceResult("1", scores, 0L, false),
+            new PyTorchInferenceResult(scores),
             tokenizer,
             4,
             resultsField
@@ -91,7 +91,7 @@ public class FillMaskProcessorTests extends ESTestCase {
             0
         );
 
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(new double[][][] { { {} } });
         expectThrows(
             ElasticsearchStatusException.class,
             () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10))

+ 5 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -72,7 +72,7 @@ public class NerProcessorTests extends ESTestCase {
 
         var e = expectThrows(
             ElasticsearchStatusException.class,
-            () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, false))
+            () -> processor.processResult(tokenization, new PyTorchInferenceResult(null))
         );
         assertThat(e, instanceOf(ElasticsearchStatusException.class));
     }
@@ -113,7 +113,7 @@ public class NerProcessorTests extends ESTestCase {
                     { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
                     { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
                 } };
-            NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+            NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
             assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
             assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -141,7 +141,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
                 { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -178,7 +178,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -211,7 +211,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java

@@ -87,7 +87,7 @@ public class QuestionAnsweringProcessorTests extends ESTestCase {
         assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7));
         double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } };
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores);
         QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
             tokenizationResult,
             pyTorchResult

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
 
     public void testInvalidResult() {
         {
-            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, false);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] {});
             var e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
@@ -41,7 +41,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
             assertThat(e.getMessage(), containsString("Text classification result has no data"));
         }
         {
-            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, false);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0 } } });
             var e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java

@@ -51,7 +51,7 @@ public class TextSimilarityProcessorTests extends ESTestCase {
         assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7));
         double[][][] scores = { { { 42 } } };
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(textSimilarityConfig);
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores);
         TextSimilarityInferenceResults result = (TextSimilarityInferenceResults) resultProcessor.processResult(
             tokenizationResult,
             pyTorchResult
@@ -74,7 +74,7 @@ public class TextSimilarityProcessorTests extends ESTestCase {
         TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer);
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(textSimilarityConfig);
         double[][][] scores = { { { 42 }, { 12 }, { 100 } } };
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores);
         TextSimilarityInferenceResults result = (TextSimilarityInferenceResults) resultProcessor.processResult(
             new BertTokenizationResult(List.of(), List.of(), 1),
             pyTorchResult

+ 48 - 41
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.pytorch.process;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
@@ -24,7 +25,6 @@ import java.util.function.LongSupplier;
 
 import static org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor.REPORTING_PERIOD_MS;
 import static org.hamcrest.Matchers.closeTo;
-import static org.hamcrest.Matchers.comparesEqualTo;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.nullValue;
@@ -37,40 +37,47 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         var settingsHolder = new AtomicReference<ThreadSettings>();
         var processor = new PyTorchResultProcessor("deployment-foo", settingsHolder::set);
 
-        var settings = new ThreadSettings(1, 1, "thread-setting");
+        var settings = new ThreadSettings(1, 1);
         processor.registerRequest("thread-setting", new AssertingResultListener(r -> assertEquals(settings, r.threadSettings())));
 
-        processor.process(mockNativeProcess(List.of(new PyTorchResult(null, settings, null)).iterator()));
+        processor.process(
+            mockNativeProcess(List.of(new PyTorchResult("thread-setting", null, null, null, settings, null, null)).iterator())
+        );
 
         assertEquals(settings, settingsHolder.get());
     }
 
     public void testResultsProcessing() {
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
-        var threadSettings = new ThreadSettings(1, 1, "b");
-        var errorResult = new ErrorResult("c", "a bad thing has happened");
+        var inferenceResult = new PyTorchInferenceResult(null);
+        var threadSettings = new ThreadSettings(1, 1);
+        var ack = new AckResult(true);
+        var errorResult = new ErrorResult("a bad thing has happened");
 
         var inferenceListener = new AssertingResultListener(r -> assertEquals(inferenceResult, r.inferenceResult()));
         var threadSettingsListener = new AssertingResultListener(r -> assertEquals(threadSettings, r.threadSettings()));
+        var ackListener = new AssertingResultListener(r -> assertEquals(ack, r.ackResult()));
         var errorListener = new AssertingResultListener(r -> assertEquals(errorResult, r.errorResult()));
 
         var processor = new PyTorchResultProcessor("foo", s -> {});
         processor.registerRequest("a", inferenceListener);
         processor.registerRequest("b", threadSettingsListener);
-        processor.registerRequest("c", errorListener);
+        processor.registerRequest("c", ackListener);
+        processor.registerRequest("d", errorListener);
 
         processor.process(
             mockNativeProcess(
                 List.of(
-                    new PyTorchResult(inferenceResult, null, null),
-                    new PyTorchResult(null, threadSettings, null),
-                    new PyTorchResult(null, null, errorResult)
+                    new PyTorchResult("a", true, 1000L, inferenceResult, null, null, null),
+                    new PyTorchResult("b", null, null, null, threadSettings, null, null),
+                    new PyTorchResult("c", null, null, null, null, ack, null),
+                    new PyTorchResult("d", null, null, null, null, null, errorResult)
                 ).iterator()
             )
         );
 
         assertTrue(inferenceListener.hasResponse);
         assertTrue(threadSettingsListener.hasResponse);
+        assertTrue(ackListener.hasResponse);
         assertTrue(errorListener.hasResponse);
     }
 
@@ -86,9 +93,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         );
         processor.registerRequest("b", calledOnShutdown);
 
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
+        var inferenceResult = new PyTorchInferenceResult(null);
 
-        processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
+        processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator()));
         assertSame(inferenceResult, resultHolder.get());
         assertTrue(calledOnShutdown.hasResponse);
     }
@@ -100,8 +107,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
 
         processor.ignoreResponseWithoutNotifying("a");
 
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
-        processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
+        var inferenceResult = new PyTorchInferenceResult(null);
+        processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator()));
     }
 
     public void testPendingRequestAreCalledAtShutdown() {
@@ -146,8 +153,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         }
     }
 
-    private PyTorchResult wrapInferenceResult(PyTorchInferenceResult result) {
-        return new PyTorchResult(result, null, null);
+    private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, long timeMs, PyTorchInferenceResult result) {
+        return new PyTorchResult(requestId, isCacheHit, timeMs, result, null, null, null);
     }
 
     public void testsStats() {
@@ -161,33 +168,33 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         processor.registerRequest("b", pendingB);
         processor.registerRequest("c", pendingC);
 
-        var a = wrapInferenceResult(new PyTorchInferenceResult("a", null, 1000L, false));
-        var b = wrapInferenceResult(new PyTorchInferenceResult("b", null, 900L, false));
-        var c = wrapInferenceResult(new PyTorchInferenceResult("c", null, 200L, true));
+        var a = wrapInferenceResult("a", false, 1000L, new PyTorchInferenceResult(null));
+        var b = wrapInferenceResult("b", false, 900L, new PyTorchInferenceResult(null));
+        var c = wrapInferenceResult("c", true, 200L, new PyTorchInferenceResult(null));
 
         processor.processInferenceResult(a);
         var stats = processor.getResultStats();
-        assertThat(stats.errorCount(), comparesEqualTo(0));
+        assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(0L));
-        assertThat(stats.numberOfPendingResults(), comparesEqualTo(2));
-        assertThat(stats.timingStats().getCount(), comparesEqualTo(1L));
-        assertThat(stats.timingStats().getSum(), comparesEqualTo(1000L));
+        assertThat(stats.numberOfPendingResults(), equalTo(2));
+        assertThat(stats.timingStats().getCount(), equalTo(1L));
+        assertThat(stats.timingStats().getSum(), equalTo(1000L));
 
         processor.processInferenceResult(b);
         stats = processor.getResultStats();
-        assertThat(stats.errorCount(), comparesEqualTo(0));
+        assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(0L));
-        assertThat(stats.numberOfPendingResults(), comparesEqualTo(1));
-        assertThat(stats.timingStats().getCount(), comparesEqualTo(2L));
-        assertThat(stats.timingStats().getSum(), comparesEqualTo(1900L));
+        assertThat(stats.numberOfPendingResults(), equalTo(1));
+        assertThat(stats.timingStats().getCount(), equalTo(2L));
+        assertThat(stats.timingStats().getSum(), equalTo(1900L));
 
         processor.processInferenceResult(c);
         stats = processor.getResultStats();
-        assertThat(stats.errorCount(), comparesEqualTo(0));
+        assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(1L));
-        assertThat(stats.numberOfPendingResults(), comparesEqualTo(0));
-        assertThat(stats.timingStats().getCount(), comparesEqualTo(3L));
-        assertThat(stats.timingStats().getSum(), comparesEqualTo(2100L));
+        assertThat(stats.numberOfPendingResults(), equalTo(0));
+        assertThat(stats.timingStats().getCount(), equalTo(3L));
+        assertThat(stats.timingStats().getSum(), equalTo(2100L));
     }
 
     public void testsTimeDependentStats() {
@@ -227,9 +234,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier);
 
         // 1st period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
         // first call has no results as is in the same period
         var stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
@@ -243,7 +250,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.peakThroughput(), equalTo(3L));
 
         // 2nd period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 100L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 100L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -255,7 +262,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
 
         // 4th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 300L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 300L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -263,8 +270,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9])));
 
         // 7th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 410L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 390L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 410L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 390L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
         assertThat(stats.recentStats().avgInferenceTime(), nullValue());
@@ -275,9 +282,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12])));
 
         // 8th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 510L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 500L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 490L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 510L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 500L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 490L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));

+ 35 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class AckResultTests extends AbstractXContentTestCase<AckResult> {
+
+    public static AckResult createRandom() {
+        return new AckResult(randomBoolean());
+    }
+
+    @Override
+    protected AckResult createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AckResult doParseInstance(XContentParser parser) throws IOException {
+        return AckResult.PARSER.parse(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+}

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java

@@ -15,7 +15,7 @@ import java.io.IOException;
 public class ErrorResultTests extends AbstractXContentTestCase<ErrorResult> {
 
     public static ErrorResult createRandom() {
-        return new ErrorResult(randomBoolean() ? null : randomAlphaOfLength(5), randomAlphaOfLength(5));
+        return new ErrorResult(randomAlphaOfLength(50));
     }
 
     @Override

+ 1 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java

@@ -30,8 +30,6 @@ public class PyTorchInferenceResultTests extends AbstractXContentTestCase<PyTorc
     }
 
     public static PyTorchInferenceResult createRandom() {
-        String id = randomAlphaOfLength(6);
-
         int rows = randomIntBetween(1, 10);
         int columns = randomIntBetween(1, 10);
         int depth = randomIntBetween(1, 10);
@@ -43,6 +41,6 @@ public class PyTorchInferenceResultTests extends AbstractXContentTestCase<PyTorc
                 }
             }
         }
-        return new PyTorchInferenceResult(id, arr, randomLong(), randomBoolean());
+        return new PyTorchInferenceResult(arr);
     }
 }

+ 14 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java

@@ -16,11 +16,21 @@ public class PyTorchResultTests extends AbstractXContentTestCase<PyTorchResult>
 
     @Override
     protected PyTorchResult createTestInstance() {
-        int type = randomIntBetween(0, 2);
+        String requestId = randomAlphaOfLength(5);
+        int type = randomIntBetween(0, 3);
         return switch (type) {
-            case 0 -> new PyTorchResult(PyTorchInferenceResultTests.createRandom(), null, null);
-            case 1 -> new PyTorchResult(null, ThreadSettingsTests.createRandom(), null);
-            default -> new PyTorchResult(null, null, ErrorResultTests.createRandom());
+            case 0 -> new PyTorchResult(
+                requestId,
+                randomBoolean(),
+                randomNonNegativeLong(),
+                PyTorchInferenceResultTests.createRandom(),
+                null,
+                null,
+                null
+            );
+            case 1 -> new PyTorchResult(requestId, null, null, null, ThreadSettingsTests.createRandom(), null, null);
+            case 2 -> new PyTorchResult(requestId, null, null, null, null, AckResultTests.createRandom(), null);
+            default -> new PyTorchResult(requestId, null, null, null, null, null, ErrorResultTests.createRandom());
         };
     }
 

+ 1 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java

@@ -15,11 +15,7 @@ import java.io.IOException;
 public class ThreadSettingsTests extends AbstractXContentTestCase<ThreadSettings> {
 
     public static ThreadSettings createRandom() {
-        return new ThreadSettings(
-            randomIntBetween(1, Integer.MAX_VALUE),
-            randomIntBetween(1, Integer.MAX_VALUE),
-            randomBoolean() ? null : randomAlphaOfLength(5)
-        );
+        return new ThreadSettings(randomIntBetween(1, Integer.MAX_VALUE), randomIntBetween(1, Integer.MAX_VALUE));
     }
 
     @Override

+ 0 - 2
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -76,8 +76,6 @@ setup:
 ---
 "Test start and stop deployment with cache":
   - skip:
-      version: all
-      reason: "@AwaitsFix https://github.com/elastic/ml-cpp/pull/2376"
       features: allowed_warnings
 
   - do:

+ 0 - 1
x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java

@@ -31,7 +31,6 @@ import static org.elasticsearch.client.WarningsHandler.PERMISSIVE;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
-@AbstractFullClusterRestartTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
 public class MLModelDeploymentFullClusterRestartIT extends AbstractFullClusterRestartTestCase {
 
     // See PyTorchModelIT for how this model was created

+ 0 - 1
x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java

@@ -29,7 +29,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
-@AbstractUpgradeTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
 public class MLModelDeploymentsUpgradeIT extends AbstractUpgradeTestCase {
 
     // See PyTorchModelIT for how this model was created