|
@@ -15,6 +15,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
|
|
|
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
+import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.features.NodeFeature;
|
|
|
import org.elasticsearch.index.mapper.MappedFieldType;
|
|
|
import org.elasticsearch.index.query.AbstractQueryBuilder;
|
|
@@ -40,13 +41,17 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.Collection;
|
|
|
+import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Objects;
|
|
|
import java.util.Set;
|
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
+import static org.elasticsearch.TransportVersions.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
|
|
|
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
|
|
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
|
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
@@ -58,7 +63,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature("semantic_query.multiple_inference_ids");
|
|
|
public static final NodeFeature SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX = new NodeFeature("semantic_query.filter_field_caps_fix");
|
|
|
|
|
|
- private static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV = TransportVersion.fromName(
|
|
|
+ static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV = TransportVersion.fromName(
|
|
|
"semantic_query_multiple_inference_ids"
|
|
|
);
|
|
|
|
|
@@ -84,7 +89,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
|
|
|
private final String fieldName;
|
|
|
private final String query;
|
|
|
- private final Map<String, InferenceResults> inferenceResultsMap;
|
|
|
+ private final Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap;
|
|
|
private final Boolean lenient;
|
|
|
|
|
|
public SemanticQueryBuilder(String fieldName, String query) {
|
|
@@ -95,7 +100,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
this(fieldName, query, lenient, null);
|
|
|
}
|
|
|
|
|
|
- protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, Map<String, InferenceResults> inferenceResultsMap) {
|
|
|
+ protected SemanticQueryBuilder(
|
|
|
+ String fieldName,
|
|
|
+ String query,
|
|
|
+ Boolean lenient,
|
|
|
+ Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
|
|
|
+ ) {
|
|
|
if (fieldName == null) {
|
|
|
throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
|
|
|
}
|
|
@@ -112,11 +122,17 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
super(in);
|
|
|
this.fieldName = in.readString();
|
|
|
this.query = in.readString();
|
|
|
- if (in.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
|
|
|
- this.inferenceResultsMap = in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)));
|
|
|
+ if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
|
|
|
+ this.inferenceResultsMap = in.readOptional(
|
|
|
+ i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class))
|
|
|
+ );
|
|
|
+ } else if (in.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
|
|
|
+ this.inferenceResultsMap = convertFromBwcInferenceResultsMap(
|
|
|
+ in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)))
|
|
|
+ );
|
|
|
} else {
|
|
|
InferenceResults inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
|
|
|
- this.inferenceResultsMap = inferenceResults != null ? buildBwcInferenceResultsMap(inferenceResults) : null;
|
|
|
+ this.inferenceResultsMap = inferenceResults != null ? buildSingleResultInferenceResultsMap(inferenceResults) : null;
|
|
|
in.readBoolean(); // Discard noInferenceResults, it is no longer necessary
|
|
|
}
|
|
|
if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) {
|
|
@@ -130,8 +146,18 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
protected void doWriteTo(StreamOutput out) throws IOException {
|
|
|
out.writeString(fieldName);
|
|
|
out.writeString(query);
|
|
|
- if (out.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
|
|
|
- out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeNamedWriteable), inferenceResultsMap);
|
|
|
+ if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
|
|
|
+ out.writeOptional(
|
|
|
+ (o, v) -> o.writeMap(v, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable),
|
|
|
+ inferenceResultsMap
|
|
|
+ );
|
|
|
+ } else if (out.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
|
|
|
+ out.writeOptional((o1, v) -> o1.writeMap(v, (o2, id) -> {
|
|
|
+ if (id.clusterAlias().equals(LOCAL_CLUSTER_GROUP_KEY) == false) {
|
|
|
+ throw new IllegalArgumentException("Cannot serialize remote cluster inference results in a mixed-version cluster");
|
|
|
+ }
|
|
|
+ o2.writeString(id.inferenceId());
|
|
|
+ }, StreamOutput::writeNamedWriteable), inferenceResultsMap);
|
|
|
} else {
|
|
|
InferenceResults inferenceResults = null;
|
|
|
if (inferenceResultsMap != null) {
|
|
@@ -150,7 +176,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private SemanticQueryBuilder(SemanticQueryBuilder other, Map<String, InferenceResults> inferenceResultsMap) {
|
|
|
+ private SemanticQueryBuilder(SemanticQueryBuilder other, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
|
|
|
this.fieldName = other.fieldName;
|
|
|
this.query = other.query;
|
|
|
this.boost = other.boost;
|
|
@@ -182,9 +208,63 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
return PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- public static void registerInferenceAsyncAction(
|
|
|
+ /**
|
|
|
+ * <p>
|
|
|
+ * Get inference results for the provided query using the provided inference IDs. The inference IDs are fully qualified by the
|
|
|
+ * cluster alias in the provided {@link QueryRewriteContext}.
|
|
|
+ * </p>
|
|
|
+ * <p>
|
|
|
+ * This method will return an inference results map that will be asynchronously populated with inference results. If the provided
|
|
|
+ * inference results map already contains all required inference results, the same map instance will be returned. Otherwise, a new map
|
|
|
+ * instance will be returned. It is guaranteed that a non-null map instance will be returned.
|
|
|
+ * </p>
|
|
|
+ *
|
|
|
+ * @param queryRewriteContext The query rewrite context
|
|
|
+ * @param inferenceIds The inference IDs to use to generate inference results
|
|
|
+ * @param inferenceResultsMap The initial inference results map
|
|
|
+ * @param query The query to generate inference results for
|
|
|
+ * @return An inference results map
|
|
|
+ */
|
|
|
+ static Map<FullyQualifiedInferenceId, InferenceResults> getInferenceResults(
|
|
|
+ QueryRewriteContext queryRewriteContext,
|
|
|
+ Set<String> inferenceIds,
|
|
|
+ @Nullable Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
|
|
|
+ @Nullable String query
|
|
|
+ ) {
|
|
|
+ boolean modifiedInferenceResultsMap = false;
|
|
|
+ Map<FullyQualifiedInferenceId, InferenceResults> currentInferenceResultsMap = inferenceResultsMap != null
|
|
|
+ ? inferenceResultsMap
|
|
|
+ : Map.of();
|
|
|
+
|
|
|
+ if (query != null) {
|
|
|
+ for (String inferenceId : inferenceIds) {
|
|
|
+ FullyQualifiedInferenceId fullyQualifiedInferenceId = new FullyQualifiedInferenceId(
|
|
|
+ queryRewriteContext.getLocalClusterAlias(),
|
|
|
+ inferenceId
|
|
|
+ );
|
|
|
+ if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) {
|
|
|
+ if (modifiedInferenceResultsMap == false) {
|
|
|
+ // Copy the inference results map to ensure it is mutable and thread safe
|
|
|
+ currentInferenceResultsMap = new ConcurrentHashMap<>(currentInferenceResultsMap);
|
|
|
+ modifiedInferenceResultsMap = true;
|
|
|
+ }
|
|
|
+
|
|
|
+ registerInferenceAsyncAction(
|
|
|
+ queryRewriteContext,
|
|
|
+ ((ConcurrentHashMap<FullyQualifiedInferenceId, InferenceResults>) currentInferenceResultsMap),
|
|
|
+ query,
|
|
|
+ inferenceId
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return currentInferenceResultsMap;
|
|
|
+ }
|
|
|
+
|
|
|
+ static void registerInferenceAsyncAction(
|
|
|
QueryRewriteContext queryRewriteContext,
|
|
|
- Map<String, InferenceResults> inferenceResultsMap,
|
|
|
+ ConcurrentHashMap<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
|
|
|
String query,
|
|
|
String inferenceId
|
|
|
) {
|
|
@@ -208,21 +288,38 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
InferenceAction.INSTANCE,
|
|
|
inferenceRequest,
|
|
|
listener.delegateFailureAndWrap((l, inferenceResponse) -> {
|
|
|
- inferenceResultsMap.put(inferenceId, validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId));
|
|
|
+ inferenceResultsMap.put(
|
|
|
+ new FullyQualifiedInferenceId(queryRewriteContext.getLocalClusterAlias(), inferenceId),
|
|
|
+ validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId)
|
|
|
+ );
|
|
|
l.onResponse(null);
|
|
|
})
|
|
|
)
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ static Map<FullyQualifiedInferenceId, InferenceResults> convertFromBwcInferenceResultsMap(
|
|
|
+ Map<String, InferenceResults> inferenceResultsMap
|
|
|
+ ) {
|
|
|
+ Map<FullyQualifiedInferenceId, InferenceResults> converted = null;
|
|
|
+ if (inferenceResultsMap != null) {
|
|
|
+ converted = Collections.unmodifiableMap(
|
|
|
+ inferenceResultsMap.entrySet()
|
|
|
+ .stream()
|
|
|
+ .collect(Collectors.toMap(e -> new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, e.getKey()), Map.Entry::getValue))
|
|
|
+ );
|
|
|
+ }
|
|
|
+ return converted;
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Build an inference results map to store a single inference result that is not associated with an inference ID.
|
|
|
*
|
|
|
* @param inferenceResults The inference result
|
|
|
* @return An inference results map
|
|
|
*/
|
|
|
- protected static Map<String, InferenceResults> buildBwcInferenceResultsMap(InferenceResults inferenceResults) {
|
|
|
- return Map.of(PLACEHOLDER_INFERENCE_ID, inferenceResults);
|
|
|
+ static Map<FullyQualifiedInferenceId, InferenceResults> buildSingleResultInferenceResultsMap(InferenceResults inferenceResults) {
|
|
|
+ return Map.of(new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, PLACEHOLDER_INFERENCE_ID), inferenceResults);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -232,8 +329,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
* @param inferenceResultsMap The inference results map
|
|
|
* @return The inference result
|
|
|
*/
|
|
|
- private static InferenceResults getBwcInferenceResults(Map<String, InferenceResults> inferenceResultsMap) {
|
|
|
- return inferenceResultsMap.get(PLACEHOLDER_INFERENCE_ID);
|
|
|
+ private static InferenceResults getSingleInferenceResult(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
|
|
|
+ return inferenceResultsMap.get(new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, PLACEHOLDER_INFERENCE_ID));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -255,7 +352,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
return doRewriteBuildSemanticQuery(searchExecutionContext);
|
|
|
}
|
|
|
|
|
|
- return doRewriteGetInferenceResults(queryRewriteContext);
|
|
|
+ ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
|
|
|
+ if (resolvedIndices != null) {
|
|
|
+ return doRewriteGetInferenceResults(queryRewriteContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ return this;
|
|
|
}
|
|
|
|
|
|
private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchExecutionContext) {
|
|
@@ -271,9 +373,11 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
}
|
|
|
|
|
|
String inferenceId = semanticTextFieldType.getSearchInferenceId();
|
|
|
- InferenceResults inferenceResults = getBwcInferenceResults(inferenceResultsMap);
|
|
|
+ InferenceResults inferenceResults = getSingleInferenceResult(inferenceResultsMap);
|
|
|
if (inferenceResults == null) {
|
|
|
- inferenceResults = inferenceResultsMap.get(inferenceId);
|
|
|
+ inferenceResults = inferenceResultsMap.get(
|
|
|
+ new FullyQualifiedInferenceId(searchExecutionContext.getLocalClusterAlias(), inferenceId)
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
if (inferenceResults == null) {
|
|
@@ -299,27 +403,30 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
}
|
|
|
|
|
|
private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
|
|
|
- if (inferenceResultsMap != null) {
|
|
|
- inferenceResultsErrorCheck();
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
|
|
|
- if (resolvedIndices == null) {
|
|
|
- throw new IllegalStateException(
|
|
|
- "Rewriting on the coordinator node requires a query rewrite context with non-null resolved indices"
|
|
|
- );
|
|
|
- } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
|
|
|
+ if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
|
|
|
throw new IllegalArgumentException(NAME + " query does not support cross-cluster search");
|
|
|
}
|
|
|
|
|
|
- Map<String, InferenceResults> inferenceResultsMap = new ConcurrentHashMap<>();
|
|
|
- Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
|
|
|
- for (String inferenceId : inferenceIds) {
|
|
|
- registerInferenceAsyncAction(queryRewriteContext, inferenceResultsMap, query, inferenceId);
|
|
|
+ SemanticQueryBuilder rewritten = this;
|
|
|
+ if (queryRewriteContext.hasAsyncActions() == false) {
|
|
|
+ Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
|
|
|
+ Map<FullyQualifiedInferenceId, InferenceResults> modifiedInferenceResultsMap = getInferenceResults(
|
|
|
+ queryRewriteContext,
|
|
|
+ inferenceIds,
|
|
|
+ inferenceResultsMap,
|
|
|
+ query
|
|
|
+ );
|
|
|
+
|
|
|
+ if (modifiedInferenceResultsMap == inferenceResultsMap) {
|
|
|
+ // The inference results map is fully populated, so we can perform error checking
|
|
|
+ inferenceResultsErrorCheck(modifiedInferenceResultsMap);
|
|
|
+ } else {
|
|
|
+ rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- return new SemanticQueryBuilder(this, inferenceResultsMap);
|
|
|
+ return rewritten;
|
|
|
}
|
|
|
|
|
|
private static InferenceResults validateAndConvertInferenceResults(
|
|
@@ -364,9 +471,9 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
|
|
return inferenceResults;
|
|
|
}
|
|
|
|
|
|
- private void inferenceResultsErrorCheck() {
|
|
|
+ private void inferenceResultsErrorCheck(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
|
|
|
for (var entry : inferenceResultsMap.entrySet()) {
|
|
|
- String inferenceId = entry.getKey();
|
|
|
+ String inferenceId = entry.getKey().inferenceId();
|
|
|
InferenceResults inferenceResults = entry.getValue();
|
|
|
|
|
|
if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) {
|