|
@@ -8,6 +8,9 @@
|
|
|
|
|
|
package org.elasticsearch.search.vectors;
|
|
|
|
|
|
+import org.apache.lucene.util.SetOnce;
|
|
|
+import org.elasticsearch.TransportVersion;
|
|
|
+import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.io.stream.Writeable;
|
|
@@ -27,8 +30,11 @@ import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.List;
|
|
|
import java.util.Objects;
|
|
|
+import java.util.function.Supplier;
|
|
|
|
|
|
+import static org.elasticsearch.common.Strings.format;
|
|
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
|
|
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
|
|
|
|
|
/**
|
|
|
* Defines a kNN search to run in the search request.
|
|
@@ -39,6 +45,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
public static final ParseField K_FIELD = new ParseField("k");
|
|
|
public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
|
|
|
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
|
|
|
+ public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
|
|
|
public static final ParseField FILTER_FIELD = new ParseField("filter");
|
|
|
public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;
|
|
|
|
|
@@ -46,18 +53,28 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
@SuppressWarnings("unchecked")
|
|
|
// TODO optimize parsing for when BYTE values are provided
|
|
|
List<Float> vector = (List<Float>) args[1];
|
|
|
- float[] vectorArray = new float[vector.size()];
|
|
|
- for (int i = 0; i < vector.size(); i++) {
|
|
|
- vectorArray[i] = vector.get(i);
|
|
|
+ final float[] vectorArray;
|
|
|
+ if (vector != null) {
|
|
|
+ vectorArray = new float[vector.size()];
|
|
|
+ for (int i = 0; i < vector.size(); i++) {
|
|
|
+ vectorArray[i] = vector.get(i);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ vectorArray = null;
|
|
|
}
|
|
|
- return new KnnSearchBuilder((String) args[0], vectorArray, (int) args[2], (int) args[3]);
|
|
|
+ return new KnnSearchBuilder((String) args[0], vectorArray, (QueryVectorBuilder) args[4], (int) args[2], (int) args[3]);
|
|
|
});
|
|
|
|
|
|
static {
|
|
|
PARSER.declareString(constructorArg(), FIELD_FIELD);
|
|
|
- PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD);
|
|
|
+ PARSER.declareFloatArray(optionalConstructorArg(), QUERY_VECTOR_FIELD);
|
|
|
PARSER.declareInt(constructorArg(), K_FIELD);
|
|
|
PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD);
|
|
|
+ PARSER.declareNamedObject(
|
|
|
+ optionalConstructorArg(),
|
|
|
+ (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
|
|
|
+ QUERY_VECTOR_BUILDER_FIELD
|
|
|
+ );
|
|
|
PARSER.declareFieldArray(
|
|
|
KnnSearchBuilder::addFilterQueries,
|
|
|
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
|
|
@@ -73,6 +90,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
|
|
|
final String field;
|
|
|
final float[] queryVector;
|
|
|
+ final QueryVectorBuilder queryVectorBuilder;
|
|
|
+ private final Supplier<float[]> querySupplier;
|
|
|
final int k;
|
|
|
final int numCands;
|
|
|
final List<QueryBuilder> filterQueries;
|
|
@@ -87,6 +106,27 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
* @param numCands the number of nearest neighbor candidates to consider per shard
|
|
|
*/
|
|
|
public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands) {
    // This overload requires a literal query vector; null is rejected here, before delegating.
    // The message names the field through getPreferredName(), like the other validation messages in this class.
    this(
        field,
        Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD.getPreferredName())),
        null,
        k,
        numCands
    );
}
|
|
|
+
|
|
|
/**
 * Creates a kNN search whose query vector is not supplied directly but obtained from a
 * {@link QueryVectorBuilder}.
 *
 * @param field              the name of the vector field to search against
 * @param queryVectorBuilder the builder that provides the query vector; must not be {@code null}
 * @param k                  the final number of nearest neighbors to return as top hits
 * @param numCands           the number of nearest neighbor candidates to consider per shard
 * @throws NullPointerException if {@code queryVectorBuilder} is {@code null}
 */
public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
    this(
        field,
        null, // no literal vector is passed on this path
        Objects.requireNonNull(
            queryVectorBuilder,
            format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())
        ),
        k,
        numCands
    );
}
|
|
|
+
|
|
|
+ private KnnSearchBuilder(String field, float[] queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
|
|
|
if (k < 1) {
|
|
|
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
|
|
|
}
|
|
@@ -98,11 +138,41 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
if (numCands > NUM_CANDS_LIMIT) {
|
|
|
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
|
|
|
}
|
|
|
+ if (queryVector == null && queryVectorBuilder == null) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ format(
|
|
|
+ "either [%s] or [%s] must be provided",
|
|
|
+ QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
|
|
|
+ QUERY_VECTOR_FIELD.getPreferredName()
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ if (queryVector != null && queryVectorBuilder != null) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ format(
|
|
|
+ "cannot provide both [%s] and [%s]",
|
|
|
+ QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
|
|
|
+ QUERY_VECTOR_FIELD.getPreferredName()
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
this.field = field;
|
|
|
- this.queryVector = queryVector;
|
|
|
+ this.queryVector = queryVector == null ? new float[0] : queryVector;
|
|
|
+ this.queryVectorBuilder = queryVectorBuilder;
|
|
|
this.k = k;
|
|
|
this.numCands = numCands;
|
|
|
this.filterQueries = new ArrayList<>();
|
|
|
+ this.querySupplier = null;
|
|
|
+ }
|
|
|
+
|
|
|
/**
 * Internal constructor for an instance whose query vector is read from a supplier.
 * <p>
 * No argument validation happens here. The filter list is stored by reference, not copied.
 *
 * @param field         the name of the vector field to search against
 * @param querySupplier supplier the query vector is read from
 * @param k             the final number of nearest neighbors to return as top hits
 * @param numCands      the number of nearest neighbor candidates to consider per shard
 * @param filterQueries the filter queries to carry over
 */
private KnnSearchBuilder(String field, Supplier<float[]> querySupplier, int k, int numCands, List<QueryBuilder> filterQueries) {
    // Vector source: only the supplier is set; the literal vector is an empty placeholder and there is no builder.
    this.querySupplier = querySupplier;
    this.queryVectorBuilder = null;
    this.queryVector = new float[0];

    // Remaining search parameters are taken over unchanged.
    this.field = field;
    this.k = k;
    this.numCands = numCands;
    this.filterQueries = filterQueries;
}
|
|
|
|
|
|
public KnnSearchBuilder(StreamInput in) throws IOException {
|
|
@@ -112,6 +182,12 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
this.queryVector = in.readFloatArray();
|
|
|
this.filterQueries = in.readNamedWriteableList(QueryBuilder.class);
|
|
|
this.boost = in.readFloat();
|
|
|
+ if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
|
|
|
+ this.queryVectorBuilder = in.readOptionalNamedWriteable(QueryVectorBuilder.class);
|
|
|
+ } else {
|
|
|
+ this.queryVectorBuilder = null;
|
|
|
+ }
|
|
|
+ this.querySupplier = null;
|
|
|
}
|
|
|
|
|
|
public int k() {
|
|
@@ -140,6 +216,32 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
|
|
|
@Override
|
|
|
public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
|
|
+ if (querySupplier != null) {
|
|
|
+ if (querySupplier.get() == null) {
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+ return new KnnSearchBuilder(field, querySupplier.get(), k, numCands).boost(boost).addFilterQueries(filterQueries);
|
|
|
+ }
|
|
|
+ if (queryVectorBuilder != null) {
|
|
|
+ SetOnce<float[]> toSet = new SetOnce<>();
|
|
|
+ ctx.registerAsyncAction((c, l) -> queryVectorBuilder.buildVector(c, ActionListener.wrap(v -> {
|
|
|
+ toSet.set(v);
|
|
|
+ if (v == null) {
|
|
|
+ l.onFailure(
|
|
|
+ new IllegalArgumentException(
|
|
|
+ format(
|
|
|
+ "[%s] with name [%s] returned null query_vector",
|
|
|
+ QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
|
|
|
+ queryVectorBuilder.getWriteableName()
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ l.onResponse(null);
|
|
|
+ }, l::onFailure)));
|
|
|
+ return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries).boost(boost);
|
|
|
+ }
|
|
|
boolean changed = false;
|
|
|
List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
|
|
|
for (QueryBuilder query : filterQueries) {
|
|
@@ -156,6 +258,9 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
}
|
|
|
|
|
|
public KnnVectorQueryBuilder toQueryBuilder() {
|
|
|
+ if (queryVectorBuilder != null) {
|
|
|
+ throw new IllegalArgumentException("missing rewrite");
|
|
|
+ }
|
|
|
return new KnnVectorQueryBuilder(field, queryVector, numCands).boost(boost).addFilterQueries(filterQueries);
|
|
|
}
|
|
|
|
|
@@ -168,21 +273,38 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
&& numCands == that.numCands
|
|
|
&& Objects.equals(field, that.field)
|
|
|
&& Arrays.equals(queryVector, that.queryVector)
|
|
|
+ && Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
|
|
|
+ && Objects.equals(querySupplier, that.querySupplier)
|
|
|
&& Objects.equals(filterQueries, that.filterQueries)
|
|
|
&& boost == that.boost;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(field, k, numCands, Arrays.hashCode(queryVector), Objects.hashCode(filterQueries), boost);
|
|
|
+ return Objects.hash(
|
|
|
+ field,
|
|
|
+ k,
|
|
|
+ numCands,
|
|
|
+ querySupplier,
|
|
|
+ queryVectorBuilder,
|
|
|
+ Arrays.hashCode(queryVector),
|
|
|
+ Objects.hashCode(filterQueries),
|
|
|
+ boost
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
builder.field(FIELD_FIELD.getPreferredName(), field)
|
|
|
.field(K_FIELD.getPreferredName(), k)
|
|
|
- .field(NUM_CANDS_FIELD.getPreferredName(), numCands)
|
|
|
- .array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
|
|
|
+ .field(NUM_CANDS_FIELD.getPreferredName(), numCands);
|
|
|
+ if (queryVectorBuilder != null) {
|
|
|
+ builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName());
|
|
|
+ builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder);
|
|
|
+ builder.endObject();
|
|
|
+ } else {
|
|
|
+ builder.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
|
|
|
+ }
|
|
|
|
|
|
if (filterQueries.isEmpty() == false) {
|
|
|
builder.startArray(FILTER_FIELD.getPreferredName());
|
|
@@ -201,11 +323,26 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
|
|
|
|
|
|
@Override
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
+ if (querySupplier != null) {
|
|
|
+ throw new IllegalStateException("missing a rewriteAndFetch?");
|
|
|
+ }
|
|
|
out.writeString(field);
|
|
|
out.writeVInt(k);
|
|
|
out.writeVInt(numCands);
|
|
|
out.writeFloatArray(queryVector);
|
|
|
out.writeNamedWriteableList(filterQueries);
|
|
|
out.writeFloat(boost);
|
|
|
+ if (out.getTransportVersion().before(TransportVersion.V_8_7_0) && queryVectorBuilder != null) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ format(
|
|
|
+ "cannot serialize [%s] to older node of version [%s]",
|
|
|
+ QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
|
|
|
+ out.getTransportVersion()
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
|
|
|
+ out.writeOptionalNamedWriteable(queryVectorBuilder);
|
|
|
+ }
|
|
|
}
|
|
|
}
|