|
@@ -452,7 +452,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- public void testReferenceCounting() throws ExecutionException, InterruptedException, IOException {
|
|
|
+ public void testReferenceCounting() throws Exception {
|
|
|
String modelId = "test-reference-counting";
|
|
|
withTrainedModel(modelId, 1L);
|
|
|
|
|
@@ -469,25 +469,24 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
|
|
|
PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
|
|
|
modelLoadingService.getModelForPipeline(modelId, forPipeline);
|
|
|
- LocalModel model = forPipeline.get();
|
|
|
- assertEquals(2, model.getReferenceCount());
|
|
|
+ final LocalModel model = forPipeline.get();
|
|
|
+ assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
|
|
|
|
|
PlainActionFuture<LocalModel> forSearch = new PlainActionFuture<>();
|
|
|
modelLoadingService.getModelForPipeline(modelId, forSearch);
|
|
|
- model = forSearch.get();
|
|
|
- assertEquals(3, model.getReferenceCount());
|
|
|
+ forSearch.get();
|
|
|
+ assertBusy(() -> assertEquals(3, model.getReferenceCount()));
|
|
|
|
|
|
model.release();
|
|
|
- assertEquals(2, model.getReferenceCount());
|
|
|
+ assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
|
|
|
|
|
PlainActionFuture<LocalModel> forSearch2 = new PlainActionFuture<>();
|
|
|
modelLoadingService.getModelForPipeline(modelId, forSearch2);
|
|
|
- model = forSearch2.get();
|
|
|
- assertEquals(3, model.getReferenceCount());
|
|
|
+ forSearch2.get();
|
|
|
+ assertBusy(() -> assertEquals(3, model.getReferenceCount()));
|
|
|
}
|
|
|
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59445")
|
|
|
- public void testReferenceCountingForPipeline() throws ExecutionException, InterruptedException, IOException {
|
|
|
+ public void testReferenceCountingForPipeline() throws Exception {
|
|
|
String modelId = "test-reference-counting-for-pipeline";
|
|
|
withTrainedModel(modelId, 1L);
|
|
|
|
|
@@ -504,18 +503,17 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|
|
|
|
|
PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
|
|
|
modelLoadingService.getModelForPipeline(modelId, forPipeline);
|
|
|
- LocalModel model = forPipeline.get();
|
|
|
- assertEquals(2, model.getReferenceCount());
|
|
|
+ final LocalModel model = forPipeline.get();
|
|
|
+ assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
|
|
|
|
|
PlainActionFuture<LocalModel> forPipeline2 = new PlainActionFuture<>();
|
|
|
modelLoadingService.getModelForPipeline(modelId, forPipeline2);
|
|
|
- model = forPipeline2.get();
|
|
|
- assertEquals(3, model.getReferenceCount());
|
|
|
+ forPipeline2.get();
|
|
|
+ assertBusy(() -> assertEquals(3, model.getReferenceCount()));
|
|
|
|
|
|
// will cause the model to be evicted
|
|
|
modelLoadingService.clusterChanged(ingestChangedEvent());
|
|
|
-
|
|
|
- assertEquals(2, model.getReferenceCount());
|
|
|
+ assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
|
|
}
|
|
|
|
|
|
public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, InterruptedException {
|