|
@@ -71,31 +71,29 @@ public abstract class SenderService implements InferenceService {
|
|
|
ActionListener<InferenceServiceResults> listener
|
|
|
) {
|
|
|
init();
|
|
|
- var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
|
|
|
- var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
|
|
|
+ var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
|
|
|
doInfer(model, inferenceInput, taskSettings, timeout, listener);
|
|
|
}
|
|
|
|
|
|
private static InferenceInputs createInput(
|
|
|
SenderService service,
|
|
|
Model model,
|
|
|
- List<ChunkInferenceInput> input,
|
|
|
+ List<String> input,
|
|
|
InputType inputType,
|
|
|
@Nullable String query,
|
|
|
@Nullable Boolean returnDocuments,
|
|
|
@Nullable Integer topN,
|
|
|
boolean stream
|
|
|
) {
|
|
|
- List<String> textInput = ChunkInferenceInput.inputs(input);
|
|
|
return switch (model.getTaskType()) {
|
|
|
- case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream);
|
|
|
+ case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
|
|
|
case RERANK -> {
|
|
|
ValidationException validationException = new ValidationException();
|
|
|
service.validateRerankParameters(returnDocuments, topN, validationException);
|
|
|
if (validationException.validationErrors().isEmpty() == false) {
|
|
|
throw validationException;
|
|
|
}
|
|
|
- yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream);
|
|
|
+ yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
|
|
|
}
|
|
|
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
|
|
|
ValidationException validationException = new ValidationException();
|
|
@@ -103,7 +101,7 @@ public abstract class SenderService implements InferenceService {
|
|
|
if (validationException.validationErrors().isEmpty() == false) {
|
|
|
throw validationException;
|
|
|
}
|
|
|
- yield new EmbeddingsInput(input, inputType, stream);
|
|
|
+ yield new EmbeddingsInput(input, null, inputType, stream);
|
|
|
}
|
|
|
default -> throw new ElasticsearchStatusException(
|
|
|
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
|