|
@@ -49,6 +49,7 @@ import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.Matchers.nullValue;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.anyBoolean;
|
|
|
import static org.mockito.Mockito.doAnswer;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.when;
|
|
@@ -855,12 +856,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
|
|
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
|
|
|
|
|
|
doAnswer(invocation -> {
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
|
|
|
+ ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
|
|
listener.onResponse(new InferenceTextEmbeddingFloatResults(List.of()));
|
|
|
|
|
|
return Void.TYPE;
|
|
|
- }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
|
|
|
+ }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
|
|
|
|
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
|
|
getEmbeddingSize(model, service, listener);
|
|
@@ -878,12 +878,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
|
|
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
|
|
|
|
|
|
doAnswer(invocation -> {
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
|
|
|
+ ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
|
|
listener.onResponse(new InferenceTextEmbeddingByteResults(List.of()));
|
|
|
|
|
|
return Void.TYPE;
|
|
|
- }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
|
|
|
+ }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
|
|
|
|
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
|
|
getEmbeddingSize(model, service, listener);
|
|
@@ -903,12 +902,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
|
|
var textEmbedding = TextEmbeddingResultsTests.createRandomResults();
|
|
|
|
|
|
doAnswer(invocation -> {
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
|
|
|
+ ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
|
|
listener.onResponse(textEmbedding);
|
|
|
|
|
|
return Void.TYPE;
|
|
|
- }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
|
|
|
+ }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
|
|
|
|
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
|
|
getEmbeddingSize(model, service, listener);
|
|
@@ -927,12 +925,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
|
|
var textEmbedding = InferenceTextEmbeddingByteResultsTests.createRandomResults();
|
|
|
|
|
|
doAnswer(invocation -> {
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
|
|
|
+ ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
|
|
listener.onResponse(textEmbedding);
|
|
|
|
|
|
return Void.TYPE;
|
|
|
- }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
|
|
|
+ }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
|
|
|
|
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
|
|
getEmbeddingSize(model, service, listener);
|