|
|
@@ -57,6 +57,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
public static class Request extends ActionRequest implements IndicesRequest.Replaceable {
|
|
|
|
|
|
public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed
|
|
|
+ public static final ParseField TEXT_EMBEDDING_CONFIG = new ParseField("text_embedding_config");
|
|
|
|
|
|
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME);
|
|
|
|
|
|
@@ -67,15 +68,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
PARSER.declareObject(
|
|
|
Request.Builder::setUpdate,
|
|
|
(p, c) -> TextEmbeddingConfigUpdate.fromXContentStrict(p),
|
|
|
- InferTrainedModelDeploymentAction.Request.INFERENCE_CONFIG
|
|
|
+ TEXT_EMBEDDING_CONFIG
|
|
|
);
|
|
|
- PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
|
|
|
- PARSER.declareFieldArray(
|
|
|
- Request.Builder::setFilters,
|
|
|
+ PARSER.declareObject(
|
|
|
+ Request.Builder::setQueryBuilder,
|
|
|
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
|
|
|
- KnnSearchBuilder.FILTER_FIELD,
|
|
|
- ObjectParser.ValueType.OBJECT_ARRAY
|
|
|
+ SearchSourceBuilder.QUERY_FIELD
|
|
|
);
|
|
|
+ PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
|
|
|
PARSER.declareField(
|
|
|
(p, request, c) -> request.setFetchSource(FetchSourceContext.fromXContent(p)),
|
|
|
SearchSourceBuilder._SOURCE_FIELD,
|
|
|
@@ -99,16 +99,21 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
SearchSourceBuilder.STORED_FIELDS_FIELD,
|
|
|
ObjectParser.ValueType.STRING_ARRAY
|
|
|
);
|
|
|
+ PARSER.declareInt(Request.Builder::setSize, SearchSourceBuilder.SIZE_FIELD);
|
|
|
}
|
|
|
|
|
|
public static Request parseRestRequest(RestRequest restRequest) throws IOException {
|
|
|
Builder builder = new Builder(Strings.splitStringByCommaToArray(restRequest.param("index")));
|
|
|
- builder.setRouting(restRequest.param("routing"));
|
|
|
if (restRequest.hasContentOrSourceParam()) {
|
|
|
try (XContentParser contentParser = restRequest.contentOrSourceParamParser()) {
|
|
|
PARSER.parse(contentParser, builder, null);
|
|
|
}
|
|
|
}
|
|
|
+ // Query parameters are preferred to body parameters.
|
|
|
+ if (restRequest.hasParam("size")) {
|
|
|
+ builder.setSize(restRequest.paramAsInt("size", -1));
|
|
|
+ }
|
|
|
+ builder.setRouting(restRequest.param("routing"));
|
|
|
return builder.build();
|
|
|
}
|
|
|
|
|
|
@@ -117,13 +122,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
private final String queryString;
|
|
|
private final String modelId;
|
|
|
private final TimeValue inferenceTimeout;
|
|
|
+ private final QueryBuilder query;
|
|
|
private final KnnQueryOptions knnQueryOptions;
|
|
|
private final TextEmbeddingConfigUpdate embeddingConfig;
|
|
|
- private final List<QueryBuilder> filters;
|
|
|
private final FetchSourceContext fetchSource;
|
|
|
private final List<FieldAndFormat> fields;
|
|
|
private final List<FieldAndFormat> docValueFields;
|
|
|
private final StoredFieldsContext storedFields;
|
|
|
+ private final int size;
|
|
|
|
|
|
public Request(StreamInput in) throws IOException {
|
|
|
super(in);
|
|
|
@@ -132,17 +138,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
queryString = in.readString();
|
|
|
modelId = in.readString();
|
|
|
inferenceTimeout = in.readOptionalTimeValue();
|
|
|
+ query = in.readOptionalNamedWriteable(QueryBuilder.class);
|
|
|
knnQueryOptions = new KnnQueryOptions(in);
|
|
|
embeddingConfig = in.readOptionalWriteable(TextEmbeddingConfigUpdate::new);
|
|
|
- if (in.readBoolean()) {
|
|
|
- filters = in.readNamedWriteableList(QueryBuilder.class);
|
|
|
- } else {
|
|
|
- filters = null;
|
|
|
- }
|
|
|
fetchSource = in.readOptionalWriteable(FetchSourceContext::readFrom);
|
|
|
fields = in.readOptionalList(FieldAndFormat::new);
|
|
|
docValueFields = in.readOptionalList(FieldAndFormat::new);
|
|
|
storedFields = in.readOptionalWriteable(StoredFieldsContext::new);
|
|
|
+ size = in.readInt();
|
|
|
}
|
|
|
|
|
|
Request(
|
|
|
@@ -150,27 +153,29 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
String routing,
|
|
|
String queryString,
|
|
|
String modelId,
|
|
|
+ QueryBuilder query,
|
|
|
KnnQueryOptions knnQueryOptions,
|
|
|
TextEmbeddingConfigUpdate embeddingConfig,
|
|
|
TimeValue inferenceTimeout,
|
|
|
- List<QueryBuilder> filters,
|
|
|
FetchSourceContext fetchSource,
|
|
|
List<FieldAndFormat> fields,
|
|
|
List<FieldAndFormat> docValueFields,
|
|
|
- StoredFieldsContext storedFields
|
|
|
+ StoredFieldsContext storedFields,
|
|
|
+ int size
|
|
|
) {
|
|
|
this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
|
|
|
this.routing = routing;
|
|
|
this.queryString = queryString;
|
|
|
this.modelId = modelId;
|
|
|
+ this.query = query;
|
|
|
this.knnQueryOptions = knnQueryOptions;
|
|
|
this.embeddingConfig = embeddingConfig;
|
|
|
this.inferenceTimeout = inferenceTimeout;
|
|
|
- this.filters = filters;
|
|
|
this.fetchSource = fetchSource;
|
|
|
this.fields = fields;
|
|
|
this.docValueFields = docValueFields;
|
|
|
this.storedFields = storedFields;
|
|
|
+ this.size = size;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
@@ -181,18 +186,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
out.writeString(queryString);
|
|
|
out.writeString(modelId);
|
|
|
out.writeOptionalTimeValue(inferenceTimeout);
|
|
|
+ out.writeOptionalNamedWriteable(query);
|
|
|
knnQueryOptions.writeTo(out);
|
|
|
out.writeOptionalWriteable(embeddingConfig);
|
|
|
- if (filters != null) {
|
|
|
- out.writeBoolean(true);
|
|
|
- out.writeNamedWriteableList(filters);
|
|
|
- } else {
|
|
|
- out.writeBoolean(false);
|
|
|
- }
|
|
|
out.writeOptionalWriteable(fetchSource);
|
|
|
out.writeOptionalCollection(fields);
|
|
|
out.writeOptionalCollection(docValueFields);
|
|
|
out.writeOptionalWriteable(storedFields);
|
|
|
+ out.writeInt(size);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
@@ -231,6 +232,10 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
return inferenceTimeout;
|
|
|
}
|
|
|
|
|
|
+ public QueryBuilder getQuery() {
|
|
|
+ return query;
|
|
|
+ }
|
|
|
+
|
|
|
public KnnQueryOptions getKnnQueryOptions() {
|
|
|
return knnQueryOptions;
|
|
|
}
|
|
|
@@ -239,10 +244,6 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
return embeddingConfig;
|
|
|
}
|
|
|
|
|
|
- public List<QueryBuilder> getFilters() {
|
|
|
- return filters;
|
|
|
- }
|
|
|
-
|
|
|
public FetchSourceContext getFetchSource() {
|
|
|
return fetchSource;
|
|
|
}
|
|
|
@@ -259,6 +260,10 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
return storedFields;
|
|
|
}
|
|
|
|
|
|
+ public int getSize() {
|
|
|
+ return size;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public boolean equals(Object o) {
|
|
|
if (this == o) return true;
|
|
|
@@ -269,13 +274,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
&& Objects.equals(queryString, request.queryString)
|
|
|
&& Objects.equals(modelId, request.modelId)
|
|
|
&& Objects.equals(inferenceTimeout, request.inferenceTimeout)
|
|
|
+ && Objects.equals(query, request.query)
|
|
|
&& Objects.equals(knnQueryOptions, request.knnQueryOptions)
|
|
|
&& Objects.equals(embeddingConfig, request.embeddingConfig)
|
|
|
- && Objects.equals(filters, request.filters)
|
|
|
&& Objects.equals(fetchSource, request.fetchSource)
|
|
|
&& Objects.equals(fields, request.fields)
|
|
|
&& Objects.equals(docValueFields, request.docValueFields)
|
|
|
- && Objects.equals(storedFields, request.storedFields);
|
|
|
+ && Objects.equals(storedFields, request.storedFields)
|
|
|
+ && size == request.size;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
@@ -285,13 +291,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
queryString,
|
|
|
modelId,
|
|
|
inferenceTimeout,
|
|
|
+ query,
|
|
|
knnQueryOptions,
|
|
|
embeddingConfig,
|
|
|
- filters,
|
|
|
fetchSource,
|
|
|
fields,
|
|
|
docValueFields,
|
|
|
- storedFields
|
|
|
+ storedFields,
|
|
|
+ size
|
|
|
);
|
|
|
result = 31 * result + Arrays.hashCode(indices);
|
|
|
return result;
|
|
|
@@ -321,12 +328,13 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
private String queryString;
|
|
|
private TimeValue timeout;
|
|
|
private TextEmbeddingConfigUpdate update;
|
|
|
+ private QueryBuilder queryBuilder;
|
|
|
private KnnQueryOptions knnSearchBuilder;
|
|
|
- private List<QueryBuilder> filters;
|
|
|
private FetchSourceContext fetchSource;
|
|
|
private List<FieldAndFormat> fields;
|
|
|
private List<FieldAndFormat> docValueFields;
|
|
|
private StoredFieldsContext storedFields;
|
|
|
+ private int size = -1;
|
|
|
|
|
|
Builder(String[] indices) {
|
|
|
this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
|
|
|
@@ -360,8 +368,8 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
this.knnSearchBuilder = knnSearchBuilder;
|
|
|
}
|
|
|
|
|
|
- private void setFilters(List<QueryBuilder> filters) {
|
|
|
- this.filters = filters;
|
|
|
+ void setQueryBuilder(QueryBuilder queryBuilder) {
|
|
|
+ this.queryBuilder = queryBuilder;
|
|
|
}
|
|
|
|
|
|
private void setFetchSource(FetchSourceContext fetchSource) {
|
|
|
@@ -380,20 +388,25 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
this.storedFields = storedFields;
|
|
|
}
|
|
|
|
|
|
+ private void setSize(int size) {
|
|
|
+ this.size = size;
|
|
|
+ }
|
|
|
+
|
|
|
Request build() {
|
|
|
return new Request(
|
|
|
indices,
|
|
|
routing,
|
|
|
queryString,
|
|
|
modelId,
|
|
|
+ queryBuilder,
|
|
|
knnSearchBuilder,
|
|
|
update,
|
|
|
timeout,
|
|
|
- filters,
|
|
|
fetchSource,
|
|
|
fields,
|
|
|
docValueFields,
|
|
|
- storedFields
|
|
|
+ storedFields,
|
|
|
+ size
|
|
|
);
|
|
|
}
|
|
|
}
|
|
|
@@ -528,7 +541,12 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
|
|
|
if (queryVector == null) {
|
|
|
throw new IllegalStateException("[query_vector] not set on the Knn query");
|
|
|
}
|
|
|
- return new KnnSearchBuilder(field, queryVector, k, numCands);
|
|
|
+ var builder = new KnnSearchBuilder(field, queryVector, k, numCands);
|
|
|
+ builder.boost(boost);
|
|
|
+ if (filterQueries.isEmpty() == false) {
|
|
|
+ builder.addFilterQueries(filterQueries);
|
|
|
+ }
|
|
|
+ return builder;
|
|
|
}
|
|
|
|
|
|
@Override
|