浏览代码

Allow semantic queries to gather inference results on remote clusters (#134956)

Mike Pellegrini 3 周之前
父节点
当前提交
200b08a91f
共有 12 个文件被更改,包括 418 次插入127 次删除
  1. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  2. 32 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java
  3. 11 11
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
  4. 3 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java
  5. 48 21
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java
  6. 11 11
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java
  7. 143 36
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
  8. 53 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java
  9. 15 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
  10. 20 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java
  11. 15 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java
  12. 66 40
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -333,6 +333,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_LOOKUP_JOIN_ON_EXPRESSION = def(9_163_0_00);
     public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED = def(9_164_0_00);
     public static final TransportVersion SEARCH_SOURCE_EXCLUDE_INFERENCE_FIELDS_PARAM = def(9_165_0_00);
+    public static final TransportVersion INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS = def(9_166_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 32 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java

@@ -0,0 +1,32 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public record FullyQualifiedInferenceId(String clusterAlias, String inferenceId) implements Writeable {
+    public FullyQualifiedInferenceId(String clusterAlias, String inferenceId) {
+        this.clusterAlias = Objects.requireNonNull(clusterAlias);
+        this.inferenceId = Objects.requireNonNull(inferenceId);
+    }
+
+    public FullyQualifiedInferenceId(StreamInput in) throws IOException {
+        this(in.readString(), in.readString());
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(clusterAlias);
+        out.writeString(inferenceId);
+    }
+}

+ 11 - 11
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java

@@ -48,9 +48,9 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
         super(in);
     }
 
-    public InterceptedInferenceKnnVectorQueryBuilder(
+    InterceptedInferenceKnnVectorQueryBuilder(
         InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> other,
-        Map<String, InferenceResults> inferenceResultsMap
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
     ) {
         super(other, inferenceResultsMap);
     }
@@ -114,7 +114,7 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
     }
 
     @Override
-    protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
+    protected QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
         return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap);
     }
 
@@ -129,9 +129,9 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
         if (fieldType == null) {
             rewritten = new MatchNoneQueryBuilder();
         } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
-            rewritten = querySemanticTextField(semanticTextFieldType);
+            rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
         } else {
-            rewritten = queryNonSemanticTextField();
+            rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias());
         }
 
         return rewritten;
@@ -166,7 +166,7 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
         return modelId;
     }
 
-    private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
+    private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
         MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
         if (modelSettings == null) {
             // No inference results have been indexed yet
@@ -182,7 +182,7 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
                 inferenceId = semanticTextFieldType.getSearchInferenceId();
             }
 
-            MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(inferenceId);
+            MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, inferenceId);
             queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
         }
 
@@ -202,7 +202,7 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
             .queryName(originalQuery.queryName());
     }
 
-    private QueryBuilder queryNonSemanticTextField() {
+    private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
         VectorData queryVector = originalQuery.queryVector();
         if (queryVector == null) {
             String modelId = getQueryVectorBuilderModelId();
@@ -213,7 +213,7 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
                 throw new IllegalStateException("No query vector or query vector builder model ID specified");
             }
 
-            MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(modelId);
+            MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, modelId);
             queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
         }
 
@@ -231,8 +231,8 @@ public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInfere
         return knnQuery;
     }
 
-    private MlTextEmbeddingResults getTextEmbeddingResults(String inferenceId) {
-        InferenceResults inferenceResults = inferenceResultsMap.get(inferenceId);
+    private MlTextEmbeddingResults getTextEmbeddingResults(String clusterAlias, String inferenceId) {
+        InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId));
         if (inferenceResults == null) {
             throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
         } else if (inferenceResults instanceof MlTextEmbeddingResults == false) {

+ 3 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java

@@ -35,9 +35,9 @@ public class InterceptedInferenceMatchQueryBuilder extends InterceptedInferenceQ
         super(in);
     }
 
-    private InterceptedInferenceMatchQueryBuilder(
+    InterceptedInferenceMatchQueryBuilder(
         InterceptedInferenceQueryBuilder<MatchQueryBuilder> other,
-        Map<String, InferenceResults> inferenceResultsMap
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
     ) {
         super(other, inferenceResultsMap);
     }
@@ -63,7 +63,7 @@ public class InterceptedInferenceMatchQueryBuilder extends InterceptedInferenceQ
     }
 
     @Override
-    protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
+    protected QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
         return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap);
     }
 

+ 48 - 21
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java

@@ -37,9 +37,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
 
+import static org.elasticsearch.TransportVersions.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
 import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
+import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap;
 
 /**
  * <p>
@@ -60,7 +62,7 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
     public static final NodeFeature NEW_SEMANTIC_QUERY_INTERCEPTORS = new NodeFeature("search.new_semantic_query_interceptors");
 
     protected final T originalQuery;
-    protected final Map<String, InferenceResults> inferenceResultsMap;
+    protected final Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap;
 
     protected InterceptedInferenceQueryBuilder(T originalQuery) {
         Objects.requireNonNull(originalQuery, "original query must not be null");
@@ -72,12 +74,20 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
     protected InterceptedInferenceQueryBuilder(StreamInput in) throws IOException {
         super(in);
         this.originalQuery = (T) in.readNamedWriteable(QueryBuilder.class);
-        this.inferenceResultsMap = in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)));
+        if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
+            this.inferenceResultsMap = in.readOptional(
+                i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class))
+            );
+        } else {
+            this.inferenceResultsMap = convertFromBwcInferenceResultsMap(
+                in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)))
+            );
+        }
     }
 
     protected InterceptedInferenceQueryBuilder(
         InterceptedInferenceQueryBuilder<T> other,
-        Map<String, InferenceResults> inferenceResultsMap
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
     ) {
         this.originalQuery = other.originalQuery;
         this.inferenceResultsMap = inferenceResultsMap;
@@ -122,7 +132,7 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
      * @param inferenceResultsMap The inference results map
      * @return A copy of {@code this} with the provided inference results map
      */
-    protected abstract QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap);
+    protected abstract QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap);
 
     /**
      * Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use
@@ -168,7 +178,19 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeNamedWriteable(originalQuery);
-        out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeNamedWriteable), inferenceResultsMap);
+        if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
+            out.writeOptional(
+                (o, v) -> o.writeMap(v, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable),
+                inferenceResultsMap
+            );
+        } else {
+            out.writeOptional((o1, v) -> o1.writeMap(v, (o2, id) -> {
+                if (id.clusterAlias().equals(LOCAL_CLUSTER_GROUP_KEY) == false) {
+                    throw new IllegalArgumentException("Cannot serialize remote cluster inference results in a mixed-version cluster");
+                }
+                o2.writeString(id.inferenceId());
+            }, StreamOutput::writeNamedWriteable), inferenceResultsMap);
+        }
     }
 
     @Override
@@ -227,11 +249,6 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
     }
 
     private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
-        if (this.inferenceResultsMap != null) {
-            inferenceResultsErrorCheck(this.inferenceResultsMap);
-            return this;
-        }
-
         QueryBuilder rewrittenBwC = doRewriteBwC(queryRewriteContext);
         if (rewrittenBwC != this) {
             return rewrittenBwC;
@@ -271,17 +288,27 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
             inferenceIds = Set.of(inferenceIdOverride);
         }
 
-        // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results are
-        // provided by the user.
-        String query = getQuery();
-        Map<String, InferenceResults> inferenceResultsMap = new ConcurrentHashMap<>();
-        if (query != null) {
-            for (String inferenceId : inferenceIds) {
-                SemanticQueryBuilder.registerInferenceAsyncAction(queryRewriteContext, inferenceResultsMap, query, inferenceId);
+        QueryBuilder rewritten = this;
+        if (queryRewriteContext.hasAsyncActions() == false) {
+            // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results
+            // are provided by the user. Ensure that we set an empty inference results map in this case so that it is always non-null after
+            // coordinator node rewrite.
+            Map<FullyQualifiedInferenceId, InferenceResults> modifiedInferenceResultsMap = SemanticQueryBuilder.getInferenceResults(
+                queryRewriteContext,
+                inferenceIds,
+                this.inferenceResultsMap,
+                getQuery()
+            );
+
+            if (modifiedInferenceResultsMap == this.inferenceResultsMap) {
+                // The inference results map is fully populated, so we can perform error checking
+                inferenceResultsErrorCheck(modifiedInferenceResultsMap);
+            } else {
+                rewritten = copy(modifiedInferenceResultsMap);
             }
         }
 
-        return copy(inferenceResultsMap);
+        return rewritten;
     }
 
     private static Set<String> getInferenceIdsForFields(
@@ -360,9 +387,9 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
         inferenceFields.compute(field, (k, v) -> v == null ? weight : v * weight);
     }
 
-    private static void inferenceResultsErrorCheck(Map<String, InferenceResults> inferenceResultsMap) {
+    private static void inferenceResultsErrorCheck(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
         for (var entry : inferenceResultsMap.entrySet()) {
-            String inferenceId = entry.getKey();
+            String inferenceId = entry.getKey().inferenceId();
             InferenceResults inferenceResults = entry.getValue();
 
             if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) {

+ 11 - 11
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java

@@ -47,9 +47,9 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
         super(in);
     }
 
-    public InterceptedInferenceSparseVectorQueryBuilder(
+    InterceptedInferenceSparseVectorQueryBuilder(
         InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> other,
-        Map<String, InferenceResults> inferenceResultsMap
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
     ) {
         super(other, inferenceResultsMap);
     }
@@ -96,7 +96,7 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
     }
 
     @Override
-    protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
+    protected QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
         return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap);
     }
 
@@ -111,9 +111,9 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
         if (fieldType == null) {
             rewritten = new MatchNoneQueryBuilder();
         } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
-            rewritten = querySemanticTextField(semanticTextFieldType);
+            rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
         } else {
-            rewritten = queryNonSemanticTextField();
+            rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias());
         }
 
         return rewritten;
@@ -138,7 +138,7 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
         return originalQuery.getFieldName();
     }
 
-    private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
+    private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
         MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
         if (modelSettings == null) {
             // No inference results have been indexed yet
@@ -154,7 +154,7 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
                 inferenceId = semanticTextFieldType.getSearchInferenceId();
             }
 
-            queryVector = getQueryVector(inferenceId);
+            queryVector = getQueryVector(clusterAlias, inferenceId);
         }
 
         SparseVectorQueryBuilder innerSparseVectorQuery = new SparseVectorQueryBuilder(
@@ -171,7 +171,7 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
             .queryName(originalQuery.queryName());
     }
 
-    private QueryBuilder queryNonSemanticTextField() {
+    private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
         List<WeightedToken> queryVector = originalQuery.getQueryVectors();
         if (queryVector == null) {
             String inferenceId = originalQuery.getInferenceId();
@@ -179,7 +179,7 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
                 throw new IllegalArgumentException("Either query vector or inference ID must be specified");
             }
 
-            queryVector = getQueryVector(inferenceId);
+            queryVector = getQueryVector(clusterAlias, inferenceId);
         }
 
         return new SparseVectorQueryBuilder(
@@ -192,8 +192,8 @@ public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInf
         ).boost(originalQuery.boost()).queryName(originalQuery.queryName());
     }
 
-    private List<WeightedToken> getQueryVector(String inferenceId) {
-        InferenceResults inferenceResults = inferenceResultsMap.get(inferenceId);
+    private List<WeightedToken> getQueryVector(String clusterAlias, String inferenceId) {
+        InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId));
         if (inferenceResults == null) {
             throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
         } else if (inferenceResults instanceof TextExpansionResults == false) {

+ 143 - 36
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

@@ -15,6 +15,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
@@ -40,13 +41,17 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 
 import java.io.IOException;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
 
+import static org.elasticsearch.TransportVersions.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
 import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -58,7 +63,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
     public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature("semantic_query.multiple_inference_ids");
     public static final NodeFeature SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX = new NodeFeature("semantic_query.filter_field_caps_fix");
 
-    private static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV = TransportVersion.fromName(
+    static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV = TransportVersion.fromName(
         "semantic_query_multiple_inference_ids"
     );
 
@@ -84,7 +89,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
 
     private final String fieldName;
     private final String query;
-    private final Map<String, InferenceResults> inferenceResultsMap;
+    private final Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap;
     private final Boolean lenient;
 
     public SemanticQueryBuilder(String fieldName, String query) {
@@ -95,7 +100,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
         this(fieldName, query, lenient, null);
     }
 
-    protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, Map<String, InferenceResults> inferenceResultsMap) {
+    protected SemanticQueryBuilder(
+        String fieldName,
+        String query,
+        Boolean lenient,
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
+    ) {
         if (fieldName == null) {
             throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
         }
@@ -112,11 +122,17 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
         super(in);
         this.fieldName = in.readString();
         this.query = in.readString();
-        if (in.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
-            this.inferenceResultsMap = in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)));
+        if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
+            this.inferenceResultsMap = in.readOptional(
+                i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class))
+            );
+        } else if (in.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
+            this.inferenceResultsMap = convertFromBwcInferenceResultsMap(
+                in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)))
+            );
         } else {
             InferenceResults inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
-            this.inferenceResultsMap = inferenceResults != null ? buildBwcInferenceResultsMap(inferenceResults) : null;
+            this.inferenceResultsMap = inferenceResults != null ? buildSingleResultInferenceResultsMap(inferenceResults) : null;
             in.readBoolean(); // Discard noInferenceResults, it is no longer necessary
         }
         if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) {
@@ -130,8 +146,18 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
     protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeString(fieldName);
         out.writeString(query);
-        if (out.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
-            out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeNamedWriteable), inferenceResultsMap);
+        if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
+            out.writeOptional(
+                (o, v) -> o.writeMap(v, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable),
+                inferenceResultsMap
+            );
+        } else if (out.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
+            out.writeOptional((o1, v) -> o1.writeMap(v, (o2, id) -> {
+                if (id.clusterAlias().equals(LOCAL_CLUSTER_GROUP_KEY) == false) {
+                    throw new IllegalArgumentException("Cannot serialize remote cluster inference results in a mixed-version cluster");
+                }
+                o2.writeString(id.inferenceId());
+            }, StreamOutput::writeNamedWriteable), inferenceResultsMap);
         } else {
             InferenceResults inferenceResults = null;
             if (inferenceResultsMap != null) {
@@ -150,7 +176,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
         }
     }
 
-    private SemanticQueryBuilder(SemanticQueryBuilder other, Map<String, InferenceResults> inferenceResultsMap) {
+    private SemanticQueryBuilder(SemanticQueryBuilder other, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
         this.fieldName = other.fieldName;
         this.query = other.query;
         this.boost = other.boost;
@@ -182,9 +208,63 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
         return PARSER.apply(parser, null);
     }
 
-    public static void registerInferenceAsyncAction(
+    /**
+     * <p>
+     * Get inference results for the provided query using the provided inference IDs. The inference IDs are fully qualified by the
+     * cluster alias in the provided {@link QueryRewriteContext}.
+     * </p>
+     * <p>
+     * This method will return an inference results map that will be asynchronously populated with inference results. If the provided
+     * inference results map already contains all required inference results, the same map instance will be returned. Otherwise, a new map
+     * instance will be returned. It is guaranteed that a non-null map instance will be returned.
+     * </p>
+     *
+     * @param queryRewriteContext The query rewrite context
+     * @param inferenceIds The inference IDs to use to generate inference results
+     * @param inferenceResultsMap The initial inference results map
+     * @param query The query to generate inference results for
+     * @return An inference results map
+     */
+    static Map<FullyQualifiedInferenceId, InferenceResults> getInferenceResults(
+        QueryRewriteContext queryRewriteContext,
+        Set<String> inferenceIds,
+        @Nullable Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
+        @Nullable String query
+    ) {
+        boolean modifiedInferenceResultsMap = false;
+        Map<FullyQualifiedInferenceId, InferenceResults> currentInferenceResultsMap = inferenceResultsMap != null
+            ? inferenceResultsMap
+            : Map.of();
+
+        if (query != null) {
+            for (String inferenceId : inferenceIds) {
+                FullyQualifiedInferenceId fullyQualifiedInferenceId = new FullyQualifiedInferenceId(
+                    queryRewriteContext.getLocalClusterAlias(),
+                    inferenceId
+                );
+                if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) {
+                    if (modifiedInferenceResultsMap == false) {
+                        // Copy the inference results map to ensure it is mutable and thread safe
+                        currentInferenceResultsMap = new ConcurrentHashMap<>(currentInferenceResultsMap);
+                        modifiedInferenceResultsMap = true;
+                    }
+
+                    registerInferenceAsyncAction(
+                        queryRewriteContext,
+                        ((ConcurrentHashMap<FullyQualifiedInferenceId, InferenceResults>) currentInferenceResultsMap),
+                        query,
+                        inferenceId
+                    );
+                }
+            }
+        }
+
+        return currentInferenceResultsMap;
+    }
+
+    static void registerInferenceAsyncAction(
         QueryRewriteContext queryRewriteContext,
-        Map<String, InferenceResults> inferenceResultsMap,
+        ConcurrentHashMap<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
         String query,
         String inferenceId
     ) {
@@ -208,21 +288,38 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
                 InferenceAction.INSTANCE,
                 inferenceRequest,
                 listener.delegateFailureAndWrap((l, inferenceResponse) -> {
-                    inferenceResultsMap.put(inferenceId, validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId));
+                    inferenceResultsMap.put(
+                        new FullyQualifiedInferenceId(queryRewriteContext.getLocalClusterAlias(), inferenceId),
+                        validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId)
+                    );
                     l.onResponse(null);
                 })
             )
         );
     }
 
+    static Map<FullyQualifiedInferenceId, InferenceResults> convertFromBwcInferenceResultsMap(
+        Map<String, InferenceResults> inferenceResultsMap
+    ) {
+        Map<FullyQualifiedInferenceId, InferenceResults> converted = null;
+        if (inferenceResultsMap != null) {
+            converted = Collections.unmodifiableMap(
+                inferenceResultsMap.entrySet()
+                    .stream()
+                    .collect(Collectors.toMap(e -> new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, e.getKey()), Map.Entry::getValue))
+            );
+        }
+        return converted;
+    }
+
     /**
      * Build an inference results map to store a single inference result that is not associated with an inference ID.
      *
      * @param inferenceResults The inference result
      * @return An inference results map
      */
-    protected static Map<String, InferenceResults> buildBwcInferenceResultsMap(InferenceResults inferenceResults) {
-        return Map.of(PLACEHOLDER_INFERENCE_ID, inferenceResults);
+    static Map<FullyQualifiedInferenceId, InferenceResults> buildSingleResultInferenceResultsMap(InferenceResults inferenceResults) {
+        return Map.of(new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, PLACEHOLDER_INFERENCE_ID), inferenceResults);
     }
 
     /**
@@ -232,8 +329,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
      * @param inferenceResultsMap The inference results map
      * @return The inference result
      */
-    private static InferenceResults getBwcInferenceResults(Map<String, InferenceResults> inferenceResultsMap) {
-        return inferenceResultsMap.get(PLACEHOLDER_INFERENCE_ID);
+    private static InferenceResults getSingleInferenceResult(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
+        return inferenceResultsMap.get(new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, PLACEHOLDER_INFERENCE_ID));
     }
 
     @Override
@@ -255,7 +352,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
             return doRewriteBuildSemanticQuery(searchExecutionContext);
         }
 
-        return doRewriteGetInferenceResults(queryRewriteContext);
+        ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
+        if (resolvedIndices != null) {
+            return doRewriteGetInferenceResults(queryRewriteContext);
+        }
+
+        return this;
     }
 
     private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchExecutionContext) {
@@ -271,9 +373,11 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
             }
 
             String inferenceId = semanticTextFieldType.getSearchInferenceId();
-            InferenceResults inferenceResults = getBwcInferenceResults(inferenceResultsMap);
+            InferenceResults inferenceResults = getSingleInferenceResult(inferenceResultsMap);
             if (inferenceResults == null) {
-                inferenceResults = inferenceResultsMap.get(inferenceId);
+                inferenceResults = inferenceResultsMap.get(
+                    new FullyQualifiedInferenceId(searchExecutionContext.getLocalClusterAlias(), inferenceId)
+                );
             }
 
             if (inferenceResults == null) {
@@ -299,27 +403,30 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
     }
 
     private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
-        if (inferenceResultsMap != null) {
-            inferenceResultsErrorCheck();
-            return this;
-        }
-
         ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
-        if (resolvedIndices == null) {
-            throw new IllegalStateException(
-                "Rewriting on the coordinator node requires a query rewrite context with non-null resolved indices"
-            );
-        } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
+        if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
             throw new IllegalArgumentException(NAME + " query does not support cross-cluster search");
         }
 
-        Map<String, InferenceResults> inferenceResultsMap = new ConcurrentHashMap<>();
-        Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
-        for (String inferenceId : inferenceIds) {
-            registerInferenceAsyncAction(queryRewriteContext, inferenceResultsMap, query, inferenceId);
+        SemanticQueryBuilder rewritten = this;
+        if (queryRewriteContext.hasAsyncActions() == false) {
+            Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
+            Map<FullyQualifiedInferenceId, InferenceResults> modifiedInferenceResultsMap = getInferenceResults(
+                queryRewriteContext,
+                inferenceIds,
+                inferenceResultsMap,
+                query
+            );
+
+            if (modifiedInferenceResultsMap == inferenceResultsMap) {
+                // The inference results map is fully populated, so we can perform error checking
+                inferenceResultsErrorCheck(modifiedInferenceResultsMap);
+            } else {
+                rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap);
+            }
         }
 
-        return new SemanticQueryBuilder(this, inferenceResultsMap);
+        return rewritten;
     }
 
     private static InferenceResults validateAndConvertInferenceResults(
@@ -364,9 +471,9 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
         return inferenceResults;
     }
 
-    private void inferenceResultsErrorCheck() {
+    private void inferenceResultsErrorCheck(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
         for (var entry : inferenceResultsMap.entrySet()) {
-            String inferenceId = entry.getKey();
+            String inferenceId = entry.getKey().inferenceId();
             InferenceResults inferenceResults = entry.getValue();
 
             if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) {

+ 53 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.queries;
 
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.MockResolvedIndices;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.ResolvedIndices;
@@ -33,9 +34,11 @@ import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.Rewriteable;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.WeightedToken;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
@@ -49,6 +52,7 @@ import org.elasticsearch.transport.RemoteClusterAware;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -64,7 +68,9 @@ import java.util.Map;
 import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.spy;
@@ -172,7 +178,7 @@ public abstract class AbstractInterceptedInferenceQueryBuilderTestCase<T extends
 
         // Test querying a semantic text field
         final T semanticFieldQuery = createQueryBuilder(field);
-        IllegalArgumentException e = expectThrows(
+        IllegalArgumentException e = assertThrows(
             IllegalArgumentException.class,
             () -> rewriteAndFetch(semanticFieldQuery, queryRewriteContext)
         );
@@ -192,6 +198,47 @@ public abstract class AbstractInterceptedInferenceQueryBuilderTestCase<T extends
         assertCoordinatorNodeRewriteOnNonInferenceField(nonInferenceFieldQuery, coordinatorRewritten);
     }
 
+    public void testSerializationRemoteClusterInferenceResults() throws Exception {
+        InferenceResults inferenceResults1 = new TextExpansionResults(
+            DEFAULT_RESULTS_FIELD,
+            List.of(new WeightedToken("foo", 1.0f)),
+            false
+        );
+        InferenceResults inferenceResults2 = new TextExpansionResults(
+            DEFAULT_RESULTS_FIELD,
+            List.of(new WeightedToken("bar", 2.0f)),
+            false
+        );
+
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap = Map.of(
+            new FullyQualifiedInferenceId(randomAlphaOfLength(5), randomAlphaOfLength(5)),
+            inferenceResults1,
+            new FullyQualifiedInferenceId(randomAlphaOfLength(5), randomAlphaOfLength(5)),
+            inferenceResults2
+        );
+
+        // It doesn't matter that the original query doesn't refer to an inference ID in the inference results map or if the inference
+        // results in the map don't match the type expected by the query. This only tests serialization, so it only matters that both
+        // the original query and the inference results map exists.
+        QueryBuilder interceptedQuery = createInterceptedQueryBuilder(createQueryBuilder(randomAlphaOfLength(5)), inferenceResultsMap);
+
+        // Test with the current transport version, which should work
+        QueryBuilder deserializedQuery = copyNamedWriteable(interceptedQuery, writableRegistry(), QueryBuilder.class);
+        assertThat(deserializedQuery, equalTo(interceptedQuery));
+
+        // Test with a transport version prior to cluster alias support, which should fail
+        TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
+            random(),
+            TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS,
+            TransportVersionUtils.getPreviousVersion(TransportVersions.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)
+        );
+        IllegalArgumentException e = assertThrows(
+            IllegalArgumentException.class,
+            () -> copyNamedWriteable(interceptedQuery, writableRegistry(), QueryBuilder.class, transportVersion)
+        );
+        assertThat(e.getMessage(), equalTo("Cannot serialize remote cluster inference results in a mixed-version cluster"));
+    }
+
     protected List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
         getPlugins().forEach(plugin -> entries.addAll(plugin.getNamedWriteables()));
@@ -208,6 +255,11 @@ public abstract class AbstractInterceptedInferenceQueryBuilderTestCase<T extends
 
     protected abstract T createQueryBuilder(String field);
 
+    protected abstract InterceptedInferenceQueryBuilder<T> createInterceptedQueryBuilder(
+        T originalQuery,
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
+    );
+
     protected abstract QueryRewriteInterceptor createQueryRewriteInterceptor();
 
     protected abstract TransportVersion getMinimalSupportedVersion();

+ 15 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java

@@ -33,6 +33,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT;
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.notNullValue;
@@ -61,6 +62,17 @@ public class InterceptedInferenceKnnVectorQueryBuilderTests extends AbstractInte
             .addFilterQuery(new TermsQueryBuilder(IndexFieldMapper.NAME, randomAlphanumericOfLength(5)));
     }
 
+    @Override
+    protected InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> createInterceptedQueryBuilder(
+        KnnVectorQueryBuilder originalQuery,
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
+    ) {
+        return new InterceptedInferenceKnnVectorQueryBuilder(
+            new InterceptedInferenceKnnVectorQueryBuilder(originalQuery),
+            inferenceResultsMap
+        );
+    }
+
     @Override
     protected QueryRewriteInterceptor createQueryRewriteInterceptor() {
         return new SemanticKnnVectorQueryRewriteInterceptor();
@@ -166,7 +178,9 @@ public class InterceptedInferenceKnnVectorQueryBuilderTests extends AbstractInte
         assertThat(coordinatorIntercepted.inferenceResultsMap, notNullValue());
         assertThat(coordinatorIntercepted.inferenceResultsMap.size(), equalTo(1));
 
-        InferenceResults inferenceResults = coordinatorIntercepted.inferenceResultsMap.get(DENSE_INFERENCE_ID);
+        InferenceResults inferenceResults = coordinatorIntercepted.inferenceResultsMap.get(
+            new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID)
+        );
         assertThat(inferenceResults, notNullValue());
         assertThat(inferenceResults, instanceOf(MlTextEmbeddingResults.class));
         VectorData queryVector = new VectorData(((MlTextEmbeddingResults) inferenceResults).getInferenceAsFloat());

+ 20 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java

@@ -12,10 +12,12 @@ import org.elasticsearch.TransportVersions;
 import org.elasticsearch.index.query.MatchQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
 
 import java.util.Map;
 
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.notNullValue;
@@ -26,6 +28,14 @@ public class InterceptedInferenceMatchQueryBuilderTests extends AbstractIntercep
         return new MatchQueryBuilder(field, "foo").boost(randomFloatBetween(0.1f, 4.0f, true)).queryName(randomAlphanumericOfLength(5));
     }
 
+    @Override
+    protected InterceptedInferenceQueryBuilder<MatchQueryBuilder> createInterceptedQueryBuilder(
+        MatchQueryBuilder originalQuery,
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
+    ) {
+        return new InterceptedInferenceMatchQueryBuilder(new InterceptedInferenceMatchQueryBuilder(originalQuery), inferenceResultsMap);
+    }
+
     @Override
     protected QueryRewriteInterceptor createQueryRewriteInterceptor() {
         return new SemanticMatchQueryRewriteInterceptor();
@@ -99,8 +109,16 @@ public class InterceptedInferenceMatchQueryBuilderTests extends AbstractIntercep
         assertThat(coordinatorIntercepted.originalQuery, equalTo(matchQuery));
         assertThat(coordinatorIntercepted.inferenceResultsMap, notNullValue());
         assertThat(coordinatorIntercepted.inferenceResultsMap.size(), equalTo(2));
-        assertTrue(coordinatorIntercepted.inferenceResultsMap.containsKey(DENSE_INFERENCE_ID));
-        assertTrue(coordinatorIntercepted.inferenceResultsMap.containsKey(SPARSE_INFERENCE_ID));
+        assertTrue(
+            coordinatorIntercepted.inferenceResultsMap.containsKey(
+                new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID)
+            )
+        );
+        assertTrue(
+            coordinatorIntercepted.inferenceResultsMap.containsKey(
+                new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, SPARSE_INFERENCE_ID)
+            )
+        );
 
         final SemanticQueryBuilder expectedSemanticQuery = new SemanticQueryBuilder(
             field,

+ 15 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java

@@ -30,6 +30,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.notNullValue;
@@ -55,6 +56,17 @@ public class InterceptedInferenceSparseVectorQueryBuilderTests extends AbstractI
         ).boost(randomFloatBetween(0.1f, 4.0f, true)).queryName(randomAlphanumericOfLength(5));
     }
 
+    @Override
+    protected InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> createInterceptedQueryBuilder(
+        SparseVectorQueryBuilder originalQuery,
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
+    ) {
+        return new InterceptedInferenceSparseVectorQueryBuilder(
+            new InterceptedInferenceSparseVectorQueryBuilder(originalQuery),
+            inferenceResultsMap
+        );
+    }
+
     @Override
     protected QueryRewriteInterceptor createQueryRewriteInterceptor() {
         return new SemanticSparseVectorQueryRewriteInterceptor();
@@ -134,7 +146,9 @@ public class InterceptedInferenceSparseVectorQueryBuilderTests extends AbstractI
         assertThat(coordinatorIntercepted.inferenceResultsMap, notNullValue());
         assertThat(coordinatorIntercepted.inferenceResultsMap.size(), equalTo(1));
 
-        InferenceResults inferenceResults = coordinatorIntercepted.inferenceResultsMap.get(SPARSE_INFERENCE_ID);
+        InferenceResults inferenceResults = coordinatorIntercepted.inferenceResultsMap.get(
+            new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, SPARSE_INFERENCE_ID)
+        );
         assertThat(inferenceResults, notNullValue());
         assertThat(inferenceResults, instanceOf(TextExpansionResults.class));
         TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults;

+ 66 - 40
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

@@ -33,10 +33,7 @@ import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.compress.CompressedXContent;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
-import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
@@ -94,8 +91,10 @@ import java.util.function.Supplier;
 
 import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
 import static org.apache.lucene.search.BooleanClause.Occur.MUST;
+import static org.elasticsearch.TransportVersions.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
-import static org.hamcrest.Matchers.containsString;
+import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.notNullValue;
@@ -108,10 +107,6 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
     private static final String INFERENCE_ID = "test_service";
     private static final String SEARCH_INFERENCE_ID = "search_test_service";
 
-    private static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = TransportVersion.fromName(
-        "semantic_query_multiple_inference_ids"
-    );
-
     private static InferenceResultType inferenceResultType;
     private static DenseVectorFieldMapper.ElementType denseVectorElementType;
     private static boolean useSearchInferenceId;
@@ -375,6 +370,36 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
         }
     }
 
+    public void testSerializationRemoteClusterInferenceResults() throws IOException {
+        InferenceResults inferenceResults1 = new TextExpansionResults(
+            DEFAULT_RESULTS_FIELD,
+            List.of(new WeightedToken("foo", 1.0f)),
+            false
+        );
+        InferenceResults inferenceResults2 = new TextExpansionResults(
+            DEFAULT_RESULTS_FIELD,
+            List.of(new WeightedToken("bar", 2.0f)),
+            false
+        );
+
+        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap = Map.of(
+            new FullyQualifiedInferenceId(randomAlphaOfLength(5), randomAlphaOfLength(5)),
+            inferenceResults1,
+            new FullyQualifiedInferenceId(randomAlphaOfLength(5), randomAlphaOfLength(5)),
+            inferenceResults2
+        );
+
+        SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(
+            randomAlphaOfLength(5),
+            randomAlphaOfLength(5),
+            null,
+            inferenceResultsMap
+        );
+
+        QueryBuilder deserializedQuery = copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class);
+        assertThat(deserializedQuery, equalTo(originalQuery));
+    }
+
     public void testSerializationBwc() throws IOException {
         InferenceResults inferenceResults1 = new TextExpansionResults(
             DEFAULT_RESULTS_FIELD,
@@ -396,26 +421,18 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
                 fieldName,
                 query,
                 null,
-                Map.of(randomAlphaOfLength(5), inferenceResults)
+                Map.of(new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, randomAlphaOfLength(5)), inferenceResults)
             );
             SemanticQueryBuilder bwcQuery = new SemanticQueryBuilder(
                 fieldName,
                 query,
                 null,
-                SemanticQueryBuilder.buildBwcInferenceResultsMap(inferenceResults)
+                SemanticQueryBuilder.buildSingleResultInferenceResultsMap(inferenceResults)
             );
 
-            try (BytesStreamOutput output = new BytesStreamOutput()) {
-                output.setTransportVersion(version);
-                output.writeNamedWriteable(originalQuery);
-                try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
-                    in.setTransportVersion(version);
-                    QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
-
-                    SemanticQueryBuilder expectedQuery = version.supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS) ? originalQuery : bwcQuery;
-                    assertThat(deserializedQuery, equalTo(expectedQuery));
-                }
-            }
+            QueryBuilder deserializedQuery = copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class, version);
+            SemanticQueryBuilder expectedQuery = version.supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV) ? originalQuery : bwcQuery;
+            assertThat(deserializedQuery, equalTo(expectedQuery));
         };
 
         for (int i = 0; i < 100; i++) {
@@ -431,8 +448,14 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
         CheckedBiConsumer<List<InferenceResults>, TransportVersion, IOException> assertMultipleInferenceResults = (
             inferenceResultsList,
             version) -> {
-            Map<String, InferenceResults> inferenceResultsMap = new HashMap<>(inferenceResultsList.size());
-            inferenceResultsList.forEach(result -> inferenceResultsMap.put(randomAlphaOfLength(5), result));
+            boolean remoteCluster = randomBoolean();
+            Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap = new HashMap<>(inferenceResultsList.size());
+            inferenceResultsList.forEach(
+                result -> inferenceResultsMap.put(
+                    new FullyQualifiedInferenceId(remoteCluster ? randomAlphaOfLength(5) : LOCAL_CLUSTER_GROUP_KEY, randomAlphaOfLength(5)),
+                    result
+                )
+            );
             SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(
                 randomAlphaOfLength(5),
                 randomAlphaOfLength(5),
@@ -440,23 +463,26 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
                 inferenceResultsMap
             );
 
-            try (BytesStreamOutput output = new BytesStreamOutput()) {
-                output.setTransportVersion(version);
-
-                if (version.supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) {
-                    output.writeNamedWriteable(originalQuery);
-                    try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
-                        in.setTransportVersion(version);
-                        QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
-                        assertThat(deserializedQuery, equalTo(originalQuery));
-                    }
-                } else {
-                    IllegalArgumentException e = assertThrows(
-                        IllegalArgumentException.class,
-                        () -> output.writeNamedWriteable(originalQuery)
-                    );
-                    assertThat(e.getMessage(), containsString("Cannot query multiple inference IDs in a mixed-version cluster"));
-                }
+            String expectedErrorMessage;
+            if (version.supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
+                expectedErrorMessage = null;
+            } else if (version.supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
+                expectedErrorMessage = remoteCluster
+                    ? "Cannot serialize remote cluster inference results in a mixed-version cluster"
+                    : null;
+            } else {
+                expectedErrorMessage = "Cannot query multiple inference IDs in a mixed-version cluster";
+            }
+
+            if (expectedErrorMessage != null) {
+                IllegalArgumentException e = assertThrows(
+                    IllegalArgumentException.class,
+                    () -> copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class, version)
+                );
+                assertThat(e.getMessage(), equalTo(expectedErrorMessage));
+            } else {
+                QueryBuilder deserializedQuery = copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class, version);
+                assertThat(deserializedQuery, equalTo(originalQuery));
             }
         };