|
@@ -18,13 +18,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
|
|
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
|
|
|
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
|
+import org.elasticsearch.common.breaker.CircuitBreaker;
|
|
|
+import org.elasticsearch.common.breaker.CircuitBreakingException;
|
|
|
import org.elasticsearch.common.bytes.BytesReference;
|
|
|
-import org.elasticsearch.common.collect.Tuple;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.transport.TransportAddress;
|
|
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
|
|
import org.elasticsearch.common.unit.TimeValue;
|
|
|
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
|
|
import org.elasticsearch.common.xcontent.XContentType;
|
|
@@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
|
|
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
|
|
+import org.elasticsearch.xpack.ml.MachineLearning;
|
|
|
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
|
|
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
|
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
|
@@ -61,6 +62,7 @@ import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.instanceOf;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.Matchers.not;
|
|
|
import static org.hamcrest.Matchers.nullValue;
|
|
@@ -83,6 +85,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
private ClusterService clusterService;
|
|
|
private InferenceAuditor auditor;
|
|
|
private TrainedModelStatsService trainedModelStatsService;
|
|
|
+ private CircuitBreaker circuitBreaker;
|
|
|
|
|
|
@Before
|
|
|
public void setUpComponents() {
|
|
@@ -97,6 +100,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
|
|
|
doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class));
|
|
|
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build());
|
|
|
+ circuitBreaker = new CustomCircuitBreaker(1000);
|
|
|
}
|
|
|
|
|
|
@After
|
|
@@ -116,10 +120,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
auditor,
|
|
|
threadPool,
|
|
|
clusterService,
|
|
|
- NamedXContentRegistry.EMPTY,
|
|
|
trainedModelStatsService,
|
|
|
Settings.EMPTY,
|
|
|
- "test-node");
|
|
|
+ "test-node",
|
|
|
+ circuitBreaker);
|
|
|
|
|
|
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
|
|
|
|
@@ -163,10 +167,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
auditor,
|
|
|
threadPool,
|
|
|
clusterService,
|
|
|
- NamedXContentRegistry.EMPTY,
|
|
|
trainedModelStatsService,
|
|
|
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
|
|
|
- "test-node");
|
|
|
+ "test-node",
|
|
|
+ circuitBreaker);
|
|
|
|
|
|
// We want to be notified when the models are loaded which happens in a background thread
|
|
|
ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds));
|
|
@@ -279,10 +283,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
auditor,
|
|
|
threadPool,
|
|
|
clusterService,
|
|
|
- NamedXContentRegistry.EMPTY,
|
|
|
trainedModelStatsService,
|
|
|
Settings.EMPTY,
|
|
|
- "test-node");
|
|
|
+ "test-node",
|
|
|
+ circuitBreaker);
|
|
|
|
|
|
modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));
|
|
|
|
|
@@ -304,10 +308,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
auditor,
|
|
|
threadPool,
|
|
|
clusterService,
|
|
|
- NamedXContentRegistry.EMPTY,
|
|
|
trainedModelStatsService,
|
|
|
Settings.EMPTY,
|
|
|
- "test-node");
|
|
|
+ "test-node",
|
|
|
+ circuitBreaker);
|
|
|
modelLoadingService.clusterChanged(ingestChangedEvent(model));
|
|
|
|
|
|
PlainActionFuture<Model> future = new PlainActionFuture<>();
|
|
@@ -332,10 +336,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
auditor,
|
|
|
threadPool,
|
|
|
clusterService,
|
|
|
- NamedXContentRegistry.EMPTY,
|
|
|
trainedModelStatsService,
|
|
|
Settings.EMPTY,
|
|
|
- "test-node");
|
|
|
+ "test-node",
|
|
|
+ circuitBreaker);
|
|
|
|
|
|
PlainActionFuture<Model> future = new PlainActionFuture<>();
|
|
|
modelLoadingService.getModel(model, future);
|
|
@@ -355,10 +359,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
auditor,
|
|
|
threadPool,
|
|
|
clusterService,
|
|
|
- NamedXContentRegistry.EMPTY,
|
|
|
trainedModelStatsService,
|
|
|
Settings.EMPTY,
|
|
|
- "test-node");
|
|
|
+ "test-node",
|
|
|
+ circuitBreaker);
|
|
|
|
|
|
for(int i = 0; i < 3; i++) {
|
|
|
PlainActionFuture<Model> future = new PlainActionFuture<>();
|
|
@@ -370,6 +374,50 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
|
|
|
}
|
|
|
|
|
|
+    /**
+     * Checks that a model which cannot fit under the breaker limit is rejected with a
+     * {@link CircuitBreakingException}, while the models that do fit are accounted for.
+     *
+     * The breaker limit is 11 bytes: the two 5-byte models fit together (10 bytes), but the
+     * 12-byte model can never be admitted, whatever order the loads run in.
+     */
+    public void testCircuitBreakerBreak() throws Exception {
+        final String smallModelA = "test-circuit-break-model-1";
+        final String smallModelB = "test-circuit-break-model-2";
+        final String oversizedModel = "test-circuit-break-model-3";
+        withTrainedModel(smallModelA, 5L);
+        withTrainedModel(smallModelB, 5L);
+        withTrainedModel(oversizedModel, 12L);
+
+        // A dedicated breaker with a tiny limit; deliberately not the shared one from setUpComponents().
+        final CircuitBreaker tinyBreaker = new CustomCircuitBreaker(11);
+        final ModelLoadingService service = new ModelLoadingService(
+            trainedModelProvider,
+            auditor,
+            threadPool,
+            clusterService,
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node",
+            tinyBreaker);
+
+        // The oversized model must be reported as a failure caused by the breaker.
+        service.addModelLoadedListener(oversizedModel, ActionListener.wrap(
+            loaded -> fail("Should not have succeeded to load model as breaker should be reached"),
+            failure -> assertThat(failure, instanceOf(CircuitBreakingException.class))
+        ));
+
+        service.clusterChanged(ingestChangedEvent(smallModelA, smallModelB, oversizedModel));
+
+        // The cluster change event triggers the loads, but neither the load order nor the
+        // current cache contents are known, so only check that each config was requested once.
+        assertBusy(() -> {
+            verify(trainedModelProvider, times(1)).getTrainedModel(eq(smallModelA), eq(false), any());
+            verify(trainedModelProvider, times(1)).getTrainedModel(eq(smallModelB), eq(false), any());
+            verify(trainedModelProvider, times(1)).getTrainedModel(eq(oversizedModel), eq(false), any());
+        });
+
+        // Only the two small models are accounted for, and the breaker tripped exactly once.
+        assertBusy(() -> {
+            assertThat(tinyBreaker.getUsed(), equalTo(10L));
+            assertThat(tinyBreaker.getTrippedCount(), equalTo(1L));
+        });
+
+        // Once only the first model is referenced, the accounted usage should fall to its 5 bytes.
+        service.clusterChanged(ingestChangedEvent(smallModelA));
+
+        assertBusy(() -> assertThat(tinyBreaker.getUsed(), equalTo(5L)));
+    }
|
|
|
+
|
|
|
@SuppressWarnings("unchecked")
|
|
|
private void withTrainedModel(String modelId, long size) {
|
|
|
InferenceDefinition definition = mock(InferenceDefinition.class);
|
|
@@ -378,15 +426,48 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
when(trainedModelConfig.getModelId()).thenReturn(modelId);
|
|
|
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
|
|
|
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
|
|
|
+ when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size);
|
|
|
doAnswer(invocationOnMock -> {
|
|
|
@SuppressWarnings("rawtypes")
|
|
|
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
|
|
- listener.onResponse(Tuple.tuple(trainedModelConfig, definition));
|
|
|
+ listener.onResponse(definition);
|
|
|
return null;
|
|
|
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ @SuppressWarnings("rawtypes")
|
|
|
+ ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
|
|
+ listener.onResponse(trainedModelConfig);
|
|
|
+ return null;
|
|
|
+ }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
|
|
}
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
private void withMissingModel(String modelId) {
|
|
|
+ if (randomBoolean()) {
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ @SuppressWarnings("rawtypes")
|
|
|
+ ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
|
|
+ listener.onFailure(new ResourceNotFoundException(
|
|
|
+ Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
|
|
+ return null;
|
|
|
+ }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
|
|
+ } else {
|
|
|
+ TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
|
|
|
+ when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ @SuppressWarnings("rawtypes")
|
|
|
+ ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
|
|
+ listener.onResponse(trainedModelConfig);
|
|
|
+ return null;
|
|
|
+ }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ @SuppressWarnings("rawtypes")
|
|
|
+ ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
|
|
+ listener.onFailure(new ResourceNotFoundException(
|
|
|
+ Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
|
|
+ return null;
|
|
|
+ }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
|
|
+ }
|
|
|
doAnswer(invocationOnMock -> {
|
|
|
@SuppressWarnings("rawtypes")
|
|
|
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
|
@@ -438,6 +519,79 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private static class CustomCircuitBreaker implements CircuitBreaker {
|
|
|
+
|
|
|
+ private final long maxBytes;
|
|
|
+ private long currentBytes = 0;
|
|
|
+ private long trippedCount = 0;
|
|
|
+
|
|
|
+ CustomCircuitBreaker(long maxBytes) {
|
|
|
+ this.maxBytes = maxBytes;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void circuitBreak(String fieldName, long bytesNeeded) {
|
|
|
+ throw new CircuitBreakingException(fieldName, Durability.TRANSIENT);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
|
|
|
+ synchronized (this) {
|
|
|
+ if (bytes + currentBytes >= maxBytes) {
|
|
|
+ trippedCount++;
|
|
|
+ circuitBreak(label, bytes);
|
|
|
+ }
|
|
|
+ currentBytes += bytes;
|
|
|
+ return currentBytes;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public long addWithoutBreaking(long bytes) {
|
|
|
+ synchronized (this) {
|
|
|
+ currentBytes += bytes;
|
|
|
+ return currentBytes;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public long getUsed() {
|
|
|
+ return currentBytes;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public long getLimit() {
|
|
|
+ return maxBytes;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public double getOverhead() {
|
|
|
+ return 1.0;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public long getTrippedCount() {
|
|
|
+ synchronized (this) {
|
|
|
+ return trippedCount;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String getName() {
|
|
|
+ return MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Durability getDurability() {
|
|
|
+ return Durability.TRANSIENT;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void setLimitAndOverhead(long limit, double overhead) {
|
|
|
+ throw new UnsupportedOperationException("boom");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private static class ModelLoadedTracker {
|
|
|
private final Set<String> expectedModelIds;
|
|
|
|