|
@@ -44,6 +44,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
randomFrom(TaskType.values()),
|
|
|
randomAlphaOfLength(6),
|
|
|
randomAlphaOfLengthOrNull(10),
|
|
|
+ randomBoolean(),
|
|
|
+ randomIntBetween(0, 10),
|
|
|
randomList(1, 5, () -> randomAlphaOfLength(8)),
|
|
|
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
|
|
|
randomFrom(InputType.values()),
|
|
@@ -85,6 +87,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.TEXT_EMBEDDING,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
null,
|
|
@@ -100,6 +104,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.RERANK,
|
|
|
"model",
|
|
|
"query",
|
|
|
+ Boolean.TRUE,
|
|
|
+ 34,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
null,
|
|
@@ -119,6 +125,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
false
|
|
|
);
|
|
|
ActionRequestValidationException inputNullError = inputNullRequest.validate();
|
|
@@ -131,6 +139,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.TEXT_EMBEDDING,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of(),
|
|
|
null,
|
|
|
null,
|
|
@@ -142,11 +152,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;"));
|
|
|
}
|
|
|
|
|
|
+    /**
+     * A {@code text_embedding} request that sets {@code return_documents} must be rejected
+     * by {@code validate()} with exactly one error naming the field and the task type.
+     */
+    public void testValidation_TextEmbedding_WithReturnDocument() {
+        String expectedMessage = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [text_embedding];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.TEXT_EMBEDDING,
+            "model",
+            null,
+            Boolean.TRUE, // return_documents: the only field under test
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
+    /**
+     * A {@code text_embedding} request that sets {@code top_n} must be rejected by
+     * {@code validate()} with exactly one error naming the field and the task type.
+     */
+    public void testValidation_TextEmbedding_WithTopN() {
+        String expectedMessage = "Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.TEXT_EMBEDDING,
+            "model",
+            null,
+            null,
+            12, // top_n: the only field under test
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
public void testValidation_Rerank_Null() {
|
|
|
InferenceAction.Request queryNullRequest = new InferenceAction.Request(
|
|
|
TaskType.RERANK,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
null,
|
|
@@ -163,6 +214,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.RERANK,
|
|
|
"model",
|
|
|
"",
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
null,
|
|
@@ -179,6 +232,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.RERANK,
|
|
|
"model",
|
|
|
"query",
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
InputType.SEARCH,
|
|
@@ -195,6 +250,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.SPARSE_EMBEDDING,
|
|
|
"model",
|
|
|
"",
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
InputType.SEARCH,
|
|
@@ -209,11 +266,56 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+    /**
+     * A {@code sparse_embedding} request that sets {@code return_documents} (here to
+     * {@code false}) must be rejected by {@code validate()} with exactly one error naming
+     * the field and the task type.
+     */
+    public void testValidation_SparseEmbedding_WithReturnDocument() {
+        String expectedMessage = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [sparse_embedding];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.SPARSE_EMBEDDING,
+            "model",
+            "",
+            Boolean.FALSE, // return_documents: a non-null FALSE still counts as "specified" per the assertion below
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
+    /**
+     * A {@code sparse_embedding} request that sets {@code top_n} must be rejected by
+     * {@code validate()} with exactly one error naming the field and the task type.
+     */
+    public void testValidation_SparseEmbedding_WithTopN() {
+        String expectedMessage = "Validation Failed: 1: Field [top_n] cannot be specified for task type [sparse_embedding];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.SPARSE_EMBEDDING,
+            "model",
+            "",
+            null,
+            22, // top_n: the only field under test
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
public void testValidation_Completion_WithInputType() {
|
|
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
|
|
TaskType.COMPLETION,
|
|
|
"model",
|
|
|
"",
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
InputType.SEARCH,
|
|
@@ -225,11 +327,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
|
|
|
}
|
|
|
|
|
|
+    /**
+     * A {@code completion} request that sets {@code return_documents} must be rejected by
+     * {@code validate()} with exactly one error naming the field and the task type.
+     */
+    public void testValidation_Completion_WithReturnDocuments() {
+        String expectedMessage = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [completion];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.COMPLETION,
+            "model",
+            "",
+            Boolean.TRUE, // return_documents: the only field under test
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
+    /**
+     * A {@code completion} request that sets {@code top_n} must be rejected by
+     * {@code validate()} with exactly one error naming the field and the task type.
+     */
+    public void testValidation_Completion_WithTopN() {
+        String expectedMessage = "Validation Failed: 1: Field [top_n] cannot be specified for task type [completion];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.COMPLETION,
+            "model",
+            "",
+            null,
+            77, // top_n: the only field under test
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
public void testValidation_ChatCompletion_WithInputType() {
|
|
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
|
|
TaskType.CHAT_COMPLETION,
|
|
|
"model",
|
|
|
"",
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
null,
|
|
|
InputType.SEARCH,
|
|
@@ -244,6 +387,45 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+    /**
+     * A {@code chat_completion} request that sets {@code return_documents} must be rejected
+     * by {@code validate()} with exactly one error naming the field and the task type.
+     */
+    public void testValidation_ChatCompletion_WithReturnDocuments() {
+        String expectedMessage = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [chat_completion];";
+
+        InferenceAction.Request request = new InferenceAction.Request(
+            TaskType.CHAT_COMPLETION,
+            "model",
+            "",
+            Boolean.TRUE, // return_documents: the only field under test
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+
+        ActionRequestValidationException validationError = request.validate();
+
+        assertNotNull(validationError);
+        assertThat(validationError.getMessage(), is(expectedMessage));
+    }
|
|
|
+
|
|
|
+    /**
+     * A {@code chat_completion} request that sets {@code top_n} must be rejected by
+     * {@code validate()} with exactly one error naming the field and the task type.
+     *
+     * <p>Every other argument is left valid or unset so that {@code top_n} is the only thing
+     * the validator can object to. In particular the input type is {@code null}: a non-null
+     * input type such as {@code SEARCH} is presumably rejected for this task type as well
+     * (see {@code testValidation_ChatCompletion_WithInputType}), which would make this
+     * single-error assertion depend on the order in which fields are validated.
+     */
+    public void testValidation_ChatCompletion_WithTopN() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.CHAT_COMPLETION,
+            "model",
+            "",
+            null,
+            11, // top_n: the only field under test
+            List.of("input"),
+            null,
+            null, // input type deliberately unset, matching the sibling WithTopN tests
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [chat_completion];"));
+    }
|
|
|
+
|
|
|
public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
|
|
|
String singleInputRequest = """
|
|
|
{
|
|
@@ -271,6 +453,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
nextTask,
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -283,6 +467,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId() + "foo",
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -297,6 +483,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
changedInputs,
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -317,6 +505,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
taskSettings,
|
|
|
instance.getInputType(),
|
|
@@ -331,6 +521,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
nextInputType,
|
|
@@ -343,6 +535,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery() == null ? randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -360,6 +554,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -374,6 +570,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ instance.getReturnDocuments(),
|
|
|
+ instance.getTopN(),
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -395,6 +593,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
instance.getInput().subList(0, 1),
|
|
|
instance.getTaskSettings(),
|
|
|
InputType.UNSPECIFIED,
|
|
@@ -406,6 +606,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
InputType.UNSPECIFIED,
|
|
@@ -420,6 +622,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
InputType.INGEST,
|
|
@@ -432,6 +636,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
InputType.UNSPECIFIED,
|
|
@@ -443,6 +649,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -455,6 +663,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
instance.getTaskType(),
|
|
|
instance.getInferenceEntityId(),
|
|
|
instance.getQuery(),
|
|
|
+ null,
|
|
|
+ null,
|
|
|
instance.getInput(),
|
|
|
instance.getTaskSettings(),
|
|
|
instance.getInputType(),
|
|
@@ -462,9 +672,24 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
false,
|
|
|
InferenceContext.EMPTY_INSTANCE
|
|
|
);
|
|
|
- } else {
|
|
|
- mutated = instance;
|
|
|
- }
|
|
|
+ } else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|
|
|
+ && version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) {
|
|
|
+ mutated = new InferenceAction.Request(
|
|
|
+ instance.getTaskType(),
|
|
|
+ instance.getInferenceEntityId(),
|
|
|
+ instance.getQuery(),
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ instance.getInput(),
|
|
|
+ instance.getTaskSettings(),
|
|
|
+ instance.getInputType(),
|
|
|
+ instance.getInferenceTimeout(),
|
|
|
+ false,
|
|
|
+ instance.getContext()
|
|
|
+ );
|
|
|
+ } else {
|
|
|
+ mutated = instance;
|
|
|
+ }
|
|
|
|
|
|
// We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting
|
|
|
if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
|
|
@@ -481,6 +706,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.TEXT_EMBEDDING,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of(),
|
|
|
Map.of(),
|
|
|
InputType.UNSPECIFIED,
|
|
@@ -503,6 +730,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.TEXT_EMBEDDING,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of(),
|
|
|
Map.of(),
|
|
|
InputType.INGEST,
|
|
@@ -525,6 +754,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.TEXT_EMBEDDING,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
Map.of(),
|
|
|
InputType.UNSPECIFIED,
|
|
@@ -548,6 +779,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
TaskType.TEXT_EMBEDDING,
|
|
|
"model",
|
|
|
null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
List.of("input"),
|
|
|
Map.of(),
|
|
|
InputType.UNSPECIFIED,
|