|
@@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.lucene.search.Queries;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
|
+import org.elasticsearch.features.NodeFeature;
|
|
|
import org.elasticsearch.index.mapper.MappedFieldType;
|
|
|
import org.elasticsearch.index.mapper.NestedObjectMapper;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
@@ -52,11 +53,14 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
|
|
|
* {@link org.apache.lucene.search.KnnByteVectorQuery}.
|
|
|
*/
|
|
|
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
|
|
|
+ public static final NodeFeature K_PARAM_SUPPORTED = new NodeFeature("search.vectors.k_param_supported"); // NOTE(review): presumably advertises that a node understands the "k" parameter of the knn query; where it is registered and checked is not visible here, confirm
|
|
|
+
|
|
|
public static final String NAME = "knn";
|
|
|
private static final int NUM_CANDS_LIMIT = 10_000;
|
|
|
private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;
|
|
|
|
|
|
public static final ParseField FIELD_FIELD = new ParseField("field");
|
|
|
+ public static final ParseField K_FIELD = new ParseField("k"); // request field "k": declared on PARSER as an optional int and rejected by the constructor when it is less than 1
|
|
|
public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
|
|
|
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
|
|
|
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
|
|
@@ -69,10 +73,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
args -> new KnnVectorQueryBuilder(
|
|
|
(String) args[0],
|
|
|
(VectorData) args[1],
|
|
|
- (QueryVectorBuilder) args[4],
|
|
|
+ (QueryVectorBuilder) args[5],
|
|
|
null,
|
|
|
(Integer) args[2],
|
|
|
- (Float) args[3]
|
|
|
+ (Integer) args[3],
|
|
|
+ (Float) args[4]
|
|
|
)
|
|
|
);
|
|
|
|
|
@@ -84,6 +89,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
QUERY_VECTOR_FIELD,
|
|
|
ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
|
|
|
);
|
|
|
+ PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
|
|
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
|
|
|
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
|
|
|
PARSER.declareNamedObject(
|
|
@@ -106,26 +112,33 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
|
|
|
private final String fieldName;
|
|
|
private final VectorData queryVector;
|
|
|
+ private final Integer k; // nullable: when null, doToQuery derives the size from the search request (or DEFAULT_SIZE) instead
|
|
|
private Integer numCands;
|
|
|
private final List<QueryBuilder> filterQueries = new ArrayList<>();
|
|
|
private final Float vectorSimilarity;
|
|
|
private final QueryVectorBuilder queryVectorBuilder;
|
|
|
private final Supplier<float[]> queryVectorSupplier;
|
|
|
|
|
|
- public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) {
|
|
|
- this(fieldName, VectorData.fromFloats(queryVector), null, null, numCands, vectorSimilarity);
|
|
|
+ public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { // float-vector overload; k may be null and is validated by the private constructor this delegates to
|
|
|
+ this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity); // wraps the floats in VectorData; the query vector builder and supplier arguments are both null
|
|
|
}
|
|
|
|
|
|
- protected KnnVectorQueryBuilder(String fieldName, QueryVectorBuilder queryVectorBuilder, Integer numCands, Float vectorSimilarity) {
|
|
|
- this(fieldName, null, queryVectorBuilder, null, numCands, vectorSimilarity);
|
|
|
+ protected KnnVectorQueryBuilder( // overload that takes a QueryVectorBuilder in place of a concrete query vector
|
|
|
+ String fieldName,
|
|
|
+ QueryVectorBuilder queryVectorBuilder,
|
|
|
+ Integer k, // optional; null is accepted, values below 1 are rejected by the delegated constructor
|
|
|
+ Integer numCands,
|
|
|
+ Float vectorSimilarity
|
|
|
+ ) {
|
|
|
+ this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity); // the query vector and supplier arguments are both null on this path
|
|
|
}
|
|
|
|
|
|
- public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) {
|
|
|
- this(fieldName, VectorData.fromBytes(queryVector), null, null, numCands, vectorSimilarity);
|
|
|
+ public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { // byte-vector overload; k may be null and is validated by the private constructor this delegates to
|
|
|
+ this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, vectorSimilarity); // wraps the bytes in VectorData; the query vector builder and supplier arguments are both null
|
|
|
}
|
|
|
|
|
|
- public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) {
|
|
|
- this(fieldName, queryVector, null, null, numCands, vectorSimilarity);
|
|
|
+ public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer k, Integer numCands, Float vectorSimilarity) { // overload for an already-built VectorData; k may be null
|
|
|
+ this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity); // the query vector builder and supplier arguments are both null
|
|
|
}
|
|
|
|
|
|
private KnnVectorQueryBuilder(
|
|
@@ -133,12 +146,21 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
VectorData queryVector,
|
|
|
QueryVectorBuilder queryVectorBuilder,
|
|
|
Supplier<float[]> queryVectorSupplier,
|
|
|
+ Integer k,
|
|
|
Integer numCands,
|
|
|
Float vectorSimilarity
|
|
|
) {
|
|
|
+ if (k != null && k < 1) {
|
|
|
+ throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
|
|
|
+ }
|
|
|
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
|
|
|
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
|
|
|
}
|
|
|
+ if (k != null && numCands != null && numCands < k) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"
|
|
|
+ );
|
|
|
+ }
|
|
|
if (queryVector == null && queryVectorBuilder == null) {
|
|
|
throw new IllegalArgumentException(
|
|
|
format(
|
|
@@ -158,6 +180,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
}
|
|
|
this.fieldName = fieldName;
|
|
|
this.queryVector = queryVector;
|
|
|
+ this.k = k;
|
|
|
this.numCands = numCands;
|
|
|
this.vectorSimilarity = vectorSimilarity;
|
|
|
this.queryVectorBuilder = queryVectorBuilder;
|
|
@@ -167,6 +190,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
public KnnVectorQueryBuilder(StreamInput in) throws IOException {
|
|
|
super(in);
|
|
|
this.fieldName = in.readString();
|
|
|
+ if (in.getTransportVersion().onOrAfter(TransportVersions.K_FOR_KNN_QUERY_ADDED)) {
|
|
|
+ this.k = in.readOptionalVInt();
|
|
|
+ } else {
|
|
|
+ this.k = null;
|
|
|
+ }
|
|
|
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
|
|
|
this.numCands = in.readOptionalVInt();
|
|
|
} else {
|
|
@@ -214,6 +242,10 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
return vectorSimilarity;
|
|
|
}
|
|
|
|
|
|
+ public Integer k() { // accessor for the optional k value; returns null when k was not supplied
|
|
|
+ return k;
|
|
|
+ }
|
|
|
+
|
|
|
public Integer numCands() {
|
|
|
return numCands;
|
|
|
}
|
|
@@ -245,6 +277,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
throw new IllegalStateException("missing a rewriteAndFetch?");
|
|
|
}
|
|
|
out.writeString(fieldName);
|
|
|
+ if (out.getTransportVersion().onOrAfter(TransportVersions.K_FOR_KNN_QUERY_ADDED)) {
|
|
|
+ out.writeOptionalVInt(k);
|
|
|
+ }
|
|
|
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
|
|
|
out.writeOptionalVInt(numCands);
|
|
|
} else {
|
|
@@ -302,6 +337,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
if (queryVector != null) {
|
|
|
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
|
|
|
}
|
|
|
+ if (k != null) {
|
|
|
+ builder.field(K_FIELD.getPreferredName(), k);
|
|
|
+ }
|
|
|
if (numCands != null) {
|
|
|
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
|
|
|
}
|
|
@@ -335,7 +373,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
if (queryVectorSupplier.get() == null) {
|
|
|
return this;
|
|
|
}
|
|
|
- return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), numCands, vectorSimilarity).boost(boost)
|
|
|
+ return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, vectorSimilarity).boost(boost)
|
|
|
.queryName(queryName)
|
|
|
.addFilterQueries(filterQueries);
|
|
|
}
|
|
@@ -357,7 +395,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
}
|
|
|
ll.onResponse(null);
|
|
|
})));
|
|
|
- return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, numCands, vectorSimilarity).boost(
|
|
|
+ return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, k, numCands, vectorSimilarity).boost(
|
|
|
boost
|
|
|
).queryName(queryName).addFilterQueries(filterQueries);
|
|
|
}
|
|
@@ -377,7 +415,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
rewrittenQueries.add(rewrittenQuery);
|
|
|
}
|
|
|
if (changed) {
|
|
|
- return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, numCands, vectorSimilarity)
|
|
|
+ return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, k, numCands, vectorSimilarity)
|
|
|
.boost(boost)
|
|
|
.queryName(queryName)
|
|
|
.addFilterQueries(rewrittenQueries);
|
|
@@ -388,7 +426,12 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
@Override
|
|
|
protected Query doToQuery(SearchExecutionContext context) throws IOException {
|
|
|
MappedFieldType fieldType = context.getFieldType(fieldName);
|
|
|
- int requestSize = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
|
|
|
+ int requestSize;
|
|
|
+ if (k != null) {
|
|
|
+ requestSize = k;
|
|
|
+ } else {
|
|
|
+ requestSize = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
|
|
|
+ }
|
|
|
int adjustedNumCands = numCands == null
|
|
|
? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * requestSize, NUM_CANDS_LIMIT))
|
|
|
: numCands;
|
|
@@ -446,20 +489,21 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
// Now join the filterQuery & parentFilter to provide the matching blocks of children
|
|
|
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
|
|
|
}
|
|
|
- return vectorFieldType.createKnnQuery(queryVector, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
|
|
|
+ return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
|
|
|
}
|
|
|
- return vectorFieldType.createKnnQuery(queryVector, adjustedNumCands, filterQuery, vectorSimilarity, null);
|
|
|
+ return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
protected int doHashCode() {
|
|
|
- return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity, queryVectorBuilder);
|
|
|
+ return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder); // k is hashed because doEquals also compares k
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
protected boolean doEquals(KnnVectorQueryBuilder other) {
|
|
|
return Objects.equals(fieldName, other.fieldName)
|
|
|
&& Objects.equals(queryVector, other.queryVector)
|
|
|
+ && Objects.equals(k, other.k)
|
|
|
&& Objects.equals(numCands, other.numCands)
|
|
|
&& Objects.equals(filterQueries, other.filterQueries)
|
|
|
&& Objects.equals(vectorSimilarity, other.vectorSimilarity)
|