|
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.pytorch.process;
|
|
|
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
|
|
|
import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
|
|
|
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
|
|
|
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
|
|
@@ -24,7 +25,6 @@ import java.util.function.LongSupplier;
|
|
|
|
|
|
import static org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor.REPORTING_PERIOD_MS;
|
|
|
import static org.hamcrest.Matchers.closeTo;
|
|
|
-import static org.hamcrest.Matchers.comparesEqualTo;
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.nullValue;
|
|
@@ -37,40 +37,47 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
var settingsHolder = new AtomicReference<ThreadSettings>();
|
|
|
var processor = new PyTorchResultProcessor("deployment-foo", settingsHolder::set);
|
|
|
|
|
|
- var settings = new ThreadSettings(1, 1, "thread-setting");
|
|
|
+ var settings = new ThreadSettings(1, 1);
|
|
|
processor.registerRequest("thread-setting", new AssertingResultListener(r -> assertEquals(settings, r.threadSettings())));
|
|
|
|
|
|
- processor.process(mockNativeProcess(List.of(new PyTorchResult(null, settings, null)).iterator()));
|
|
|
+ processor.process(
|
|
|
+ mockNativeProcess(List.of(new PyTorchResult("thread-setting", null, null, null, settings, null, null)).iterator())
|
|
|
+ );
|
|
|
|
|
|
assertEquals(settings, settingsHolder.get());
|
|
|
}
|
|
|
|
|
|
public void testResultsProcessing() {
|
|
|
- var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
|
|
|
- var threadSettings = new ThreadSettings(1, 1, "b");
|
|
|
- var errorResult = new ErrorResult("c", "a bad thing has happened");
|
|
|
+ var inferenceResult = new PyTorchInferenceResult(null);
|
|
|
+ var threadSettings = new ThreadSettings(1, 1);
|
|
|
+ var ack = new AckResult(true);
|
|
|
+ var errorResult = new ErrorResult("a bad thing has happened");
|
|
|
|
|
|
var inferenceListener = new AssertingResultListener(r -> assertEquals(inferenceResult, r.inferenceResult()));
|
|
|
var threadSettingsListener = new AssertingResultListener(r -> assertEquals(threadSettings, r.threadSettings()));
|
|
|
+ var ackListener = new AssertingResultListener(r -> assertEquals(ack, r.ackResult()));
|
|
|
var errorListener = new AssertingResultListener(r -> assertEquals(errorResult, r.errorResult()));
|
|
|
|
|
|
var processor = new PyTorchResultProcessor("foo", s -> {});
|
|
|
processor.registerRequest("a", inferenceListener);
|
|
|
processor.registerRequest("b", threadSettingsListener);
|
|
|
- processor.registerRequest("c", errorListener);
|
|
|
+ processor.registerRequest("c", ackListener);
|
|
|
+ processor.registerRequest("d", errorListener);
|
|
|
|
|
|
processor.process(
|
|
|
mockNativeProcess(
|
|
|
List.of(
|
|
|
- new PyTorchResult(inferenceResult, null, null),
|
|
|
- new PyTorchResult(null, threadSettings, null),
|
|
|
- new PyTorchResult(null, null, errorResult)
|
|
|
+ new PyTorchResult("a", true, 1000L, inferenceResult, null, null, null),
|
|
|
+ new PyTorchResult("b", null, null, null, threadSettings, null, null),
|
|
|
+ new PyTorchResult("c", null, null, null, null, ack, null),
|
|
|
+ new PyTorchResult("d", null, null, null, null, null, errorResult)
|
|
|
).iterator()
|
|
|
)
|
|
|
);
|
|
|
|
|
|
assertTrue(inferenceListener.hasResponse);
|
|
|
assertTrue(threadSettingsListener.hasResponse);
|
|
|
+ assertTrue(ackListener.hasResponse);
|
|
|
assertTrue(errorListener.hasResponse);
|
|
|
}
|
|
|
|
|
@@ -86,9 +93,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
);
|
|
|
processor.registerRequest("b", calledOnShutdown);
|
|
|
|
|
|
- var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
|
|
|
+ var inferenceResult = new PyTorchInferenceResult(null);
|
|
|
|
|
|
- processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
|
|
|
+ processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator()));
|
|
|
assertSame(inferenceResult, resultHolder.get());
|
|
|
assertTrue(calledOnShutdown.hasResponse);
|
|
|
}
|
|
@@ -100,8 +107,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
|
|
|
processor.ignoreResponseWithoutNotifying("a");
|
|
|
|
|
|
- var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
|
|
|
- processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
|
|
|
+ var inferenceResult = new PyTorchInferenceResult(null);
|
|
|
+ processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator()));
|
|
|
}
|
|
|
|
|
|
public void testPendingRequestAreCalledAtShutdown() {
|
|
@@ -146,8 +153,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private PyTorchResult wrapInferenceResult(PyTorchInferenceResult result) {
|
|
|
- return new PyTorchResult(result, null, null);
|
|
|
+ private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, long timeMs, PyTorchInferenceResult result) {
|
|
|
+ return new PyTorchResult(requestId, isCacheHit, timeMs, result, null, null, null);
|
|
|
}
|
|
|
|
|
|
public void testsStats() {
|
|
@@ -161,33 +168,33 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
processor.registerRequest("b", pendingB);
|
|
|
processor.registerRequest("c", pendingC);
|
|
|
|
|
|
- var a = wrapInferenceResult(new PyTorchInferenceResult("a", null, 1000L, false));
|
|
|
- var b = wrapInferenceResult(new PyTorchInferenceResult("b", null, 900L, false));
|
|
|
- var c = wrapInferenceResult(new PyTorchInferenceResult("c", null, 200L, true));
|
|
|
+ var a = wrapInferenceResult("a", false, 1000L, new PyTorchInferenceResult(null));
|
|
|
+ var b = wrapInferenceResult("b", false, 900L, new PyTorchInferenceResult(null));
|
|
|
+ var c = wrapInferenceResult("c", true, 200L, new PyTorchInferenceResult(null));
|
|
|
|
|
|
processor.processInferenceResult(a);
|
|
|
var stats = processor.getResultStats();
|
|
|
- assertThat(stats.errorCount(), comparesEqualTo(0));
|
|
|
+ assertThat(stats.errorCount(), equalTo(0));
|
|
|
assertThat(stats.cacheHitCount(), equalTo(0L));
|
|
|
- assertThat(stats.numberOfPendingResults(), comparesEqualTo(2));
|
|
|
- assertThat(stats.timingStats().getCount(), comparesEqualTo(1L));
|
|
|
- assertThat(stats.timingStats().getSum(), comparesEqualTo(1000L));
|
|
|
+ assertThat(stats.numberOfPendingResults(), equalTo(2));
|
|
|
+ assertThat(stats.timingStats().getCount(), equalTo(1L));
|
|
|
+ assertThat(stats.timingStats().getSum(), equalTo(1000L));
|
|
|
|
|
|
processor.processInferenceResult(b);
|
|
|
stats = processor.getResultStats();
|
|
|
- assertThat(stats.errorCount(), comparesEqualTo(0));
|
|
|
+ assertThat(stats.errorCount(), equalTo(0));
|
|
|
assertThat(stats.cacheHitCount(), equalTo(0L));
|
|
|
- assertThat(stats.numberOfPendingResults(), comparesEqualTo(1));
|
|
|
- assertThat(stats.timingStats().getCount(), comparesEqualTo(2L));
|
|
|
- assertThat(stats.timingStats().getSum(), comparesEqualTo(1900L));
|
|
|
+ assertThat(stats.numberOfPendingResults(), equalTo(1));
|
|
|
+ assertThat(stats.timingStats().getCount(), equalTo(2L));
|
|
|
+ assertThat(stats.timingStats().getSum(), equalTo(1900L));
|
|
|
|
|
|
processor.processInferenceResult(c);
|
|
|
stats = processor.getResultStats();
|
|
|
- assertThat(stats.errorCount(), comparesEqualTo(0));
|
|
|
+ assertThat(stats.errorCount(), equalTo(0));
|
|
|
assertThat(stats.cacheHitCount(), equalTo(1L));
|
|
|
- assertThat(stats.numberOfPendingResults(), comparesEqualTo(0));
|
|
|
- assertThat(stats.timingStats().getCount(), comparesEqualTo(3L));
|
|
|
- assertThat(stats.timingStats().getSum(), comparesEqualTo(2100L));
|
|
|
+ assertThat(stats.numberOfPendingResults(), equalTo(0));
|
|
|
+ assertThat(stats.timingStats().getCount(), equalTo(3L));
|
|
|
+ assertThat(stats.timingStats().getSum(), equalTo(2100L));
|
|
|
}
|
|
|
|
|
|
public void testsTimeDependentStats() {
|
|
@@ -227,9 +234,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier);
|
|
|
|
|
|
// 1st period
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
|
|
|
// first call has no results as is in the same period
|
|
|
var stats = processor.getResultStats();
|
|
|
assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
|
|
@@ -243,7 +250,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
assertThat(stats.peakThroughput(), equalTo(3L));
|
|
|
|
|
|
// 2nd period
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 100L, false)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 100L, new PyTorchInferenceResult(null)));
|
|
|
stats = processor.getResultStats();
|
|
|
assertNotNull(stats.recentStats());
|
|
|
assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
|
|
@@ -255,7 +262,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
|
|
|
|
|
|
// 4th period
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 300L, false)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 300L, new PyTorchInferenceResult(null)));
|
|
|
stats = processor.getResultStats();
|
|
|
assertNotNull(stats.recentStats());
|
|
|
assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
|
|
@@ -263,8 +270,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9])));
|
|
|
|
|
|
// 7th period
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 410L, false)));
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 390L, false)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 410L, new PyTorchInferenceResult(null)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 390L, new PyTorchInferenceResult(null)));
|
|
|
stats = processor.getResultStats();
|
|
|
assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
|
|
|
assertThat(stats.recentStats().avgInferenceTime(), nullValue());
|
|
@@ -275,9 +282,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
|
|
|
assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12])));
|
|
|
|
|
|
// 8th period
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 510L, false)));
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 500L, false)));
|
|
|
- processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 490L, false)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 510L, new PyTorchInferenceResult(null)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 500L, new PyTorchInferenceResult(null)));
|
|
|
+ processor.processInferenceResult(wrapInferenceResult("foo", false, 490L, new PyTorchInferenceResult(null)));
|
|
|
stats = processor.getResultStats();
|
|
|
assertNotNull(stats.recentStats());
|
|
|
assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
|