Browse Source

[ML] fixing inference reference counting tests (#59453)

Benjamin Trent 5 years ago
parent
commit
ac8715acc7

+ 14 - 16
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java

@@ -452,7 +452,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         });
     }
 
-    public void testReferenceCounting() throws ExecutionException, InterruptedException, IOException {
+    public void testReferenceCounting() throws Exception {
         String modelId = "test-reference-counting";
         withTrainedModel(modelId, 1L);
 
@@ -469,25 +469,24 @@ public class ModelLoadingServiceTests extends ESTestCase {
 
         PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
         modelLoadingService.getModelForPipeline(modelId, forPipeline);
-        LocalModel model = forPipeline.get();
-        assertEquals(2, model.getReferenceCount());
+        final LocalModel model = forPipeline.get();
+        assertBusy(() -> assertEquals(2, model.getReferenceCount()));
 
         PlainActionFuture<LocalModel> forSearch = new PlainActionFuture<>();
         modelLoadingService.getModelForPipeline(modelId, forSearch);
-        model = forSearch.get();
-        assertEquals(3, model.getReferenceCount());
+        forSearch.get();
+        assertBusy(() -> assertEquals(3, model.getReferenceCount()));
 
         model.release();
-        assertEquals(2, model.getReferenceCount());
+        assertBusy(() -> assertEquals(2, model.getReferenceCount()));
 
         PlainActionFuture<LocalModel> forSearch2 = new PlainActionFuture<>();
         modelLoadingService.getModelForPipeline(modelId, forSearch2);
-        model = forSearch2.get();
-        assertEquals(3, model.getReferenceCount());
+        forSearch2.get();
+        assertBusy(() -> assertEquals(3, model.getReferenceCount()));
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59445")
-    public void testReferenceCountingForPipeline() throws ExecutionException, InterruptedException, IOException {
+    public void testReferenceCountingForPipeline() throws Exception {
         String modelId = "test-reference-counting-for-pipeline";
         withTrainedModel(modelId, 1L);
 
@@ -504,18 +503,17 @@ public class ModelLoadingServiceTests extends ESTestCase {
 
         PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
         modelLoadingService.getModelForPipeline(modelId, forPipeline);
-        LocalModel model = forPipeline.get();
-        assertEquals(2, model.getReferenceCount());
+        final LocalModel model = forPipeline.get();
+        assertBusy(() -> assertEquals(2, model.getReferenceCount()));
 
         PlainActionFuture<LocalModel> forPipeline2 = new PlainActionFuture<>();
         modelLoadingService.getModelForPipeline(modelId, forPipeline2);
-        model = forPipeline2.get();
-        assertEquals(3, model.getReferenceCount());
+        forPipeline2.get();
+        assertBusy(() -> assertEquals(3, model.getReferenceCount()));
 
         // will cause the model to be evicted
         modelLoadingService.clusterChanged(ingestChangedEvent());
-
-        assertEquals(2, model.getReferenceCount());
+        assertBusy(() -> assertEquals(2, model.getReferenceCount()));
     }
 
     public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, InterruptedException {