Browse Source

[ML] Fix streaming test regex (#115589) (#116102)

Replace regex with string parsing.

Fix #114788

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Pat Whelan 11 months ago
parent
commit
f910b9215e

+ 12 - 11
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

@@ -64,7 +64,6 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Predicate;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.function.Supplier;
-import java.util.regex.Pattern;
 
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.is;
@@ -414,17 +413,19 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
         assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
         assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
     }
     }
 
 
-    public void testNoStream() throws IOException {
-        var pattern = Pattern.compile("^\uFEFFevent: message\ndata: \\{\"result\":\".*\"}\n\n\uFEFFevent: message\ndata: \\[DONE]\n\n$");
+    public void testNoStream() {
+        var collector = new RandomStringCollector();
+        var expectedTestCount = randomIntBetween(2, 30);
         var request = new Request(RestRequest.Method.POST.name(), NO_STREAM_ROUTE);
         var request = new Request(RestRequest.Method.POST.name(), NO_STREAM_ROUTE);
-        var response = getRestClient().performRequest(request);
-        assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_OK));
-        var responseString = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
-
-        assertThat(
-            "Expected " + responseString + " to match pattern " + pattern.pattern(),
-            pattern.matcher(responseString).matches(),
-            is(true)
+        request.setOptions(
+            RequestOptions.DEFAULT.toBuilder()
+                .setHttpAsyncResponseConsumerFactory(() -> new AsyncResponseConsumer(collector))
+                .addParameter(REQUEST_COUNT, String.valueOf(expectedTestCount))
+                .build()
         );
         );
+        var response = callAsync(request);
+        assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_OK));
+        assertThat(collector.stringsVerified.size(), equalTo(2)); // single payload count + done byte
+        assertThat(collector.stringsVerified.peekLast(), equalTo("[DONE]"));
     }
     }
 }
 }