Browse Source

[ES|QL] COMPLETION command physical plan (#126766)

Aurélien FOUCRET 6 months ago
parent
commit
6702afc96c

+ 2 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

@@ -51,6 +51,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
 import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
 import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
 
 import java.util.ArrayList;
@@ -94,6 +95,7 @@ public class PlanWritables {
     public static List<NamedWriteableRegistry.Entry> physical() {
         return List.of(
             AggregateExec.ENTRY,
+            CompletionExec.ENTRY,
             DissectExec.ENTRY,
             EnrichExec.ENTRY,
             EsQueryExec.ENTRY,

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

@@ -29,7 +29,7 @@ public abstract class InferencePlan<PlanType extends InferencePlan<PlanType>> ex
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        Source.EMPTY.writeTo(out);
+        source().writeTo(out);
         out.writeNamedWriteable(child());
         out.writeNamedWriteable(inferenceId());
     }

+ 114 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExec.java

@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
+
+public class CompletionExec extends InferenceExec {
+
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "CompletionExec",
+        CompletionExec::new
+    );
+
+    private final Expression prompt;
+    private final Attribute targetField;
+    private List<Attribute> lazyOutput;
+
+    public CompletionExec(Source source, PhysicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
+        super(source, child, inferenceId);
+        this.prompt = prompt;
+        this.targetField = targetField;
+    }
+
+    public CompletionExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(PhysicalPlan.class),
+            in.readNamedWriteable(Expression.class),
+            in.readNamedWriteable(Expression.class),
+            in.readNamedWriteable(Attribute.class)
+        );
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeNamedWriteable(prompt);
+        out.writeNamedWriteable(targetField);
+    }
+
+    public Expression prompt() {
+        return prompt;
+    }
+
+    public Attribute targetField() {
+        return targetField;
+    }
+
+    @Override
+    protected NodeInfo<? extends PhysicalPlan> info() {
+        return NodeInfo.create(this, CompletionExec::new, child(), inferenceId(), prompt, targetField);
+    }
+
+    @Override
+    public UnaryExec replaceChild(PhysicalPlan newChild) {
+        return new CompletionExec(source(), newChild, inferenceId(), prompt, targetField);
+    }
+
+    @Override
+    public List<Attribute> output() {
+        if (lazyOutput == null) {
+            lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
+        }
+
+        return lazyOutput;
+    }
+
+    @Override
+    protected AttributeSet computeReferences() {
+        return prompt.references();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
+        CompletionExec completion = (CompletionExec) o;
+
+        return Objects.equals(prompt, completion.prompt) && Objects.equals(targetField, completion.targetField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), prompt, targetField);
+    }
+}

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java

@@ -30,7 +30,7 @@ public abstract class InferenceExec extends UnaryExec {
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        Source.EMPTY.writeTo(out);
+        source().writeTo(out);
         out.writeNamedWriteable(child());
         out.writeNamedWriteable(inferenceId());
     }

+ 6 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project;
 import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
 import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
 import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
 import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
 import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
 import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
@@ -43,6 +44,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
 import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
 import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
 import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
 import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
 import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
 
@@ -99,6 +101,10 @@ class MapperUtils {
             );
         }
 
+        if (p instanceof Completion completion) {
+            return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField());
+        }
+
         if (p instanceof Enrich enrich) {
             return new EnrichExec(
                 enrich.source(),

+ 3 - 10
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/CompletionSerializationTests.java

@@ -12,7 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests;
+import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
 import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 
@@ -22,9 +22,7 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati
 
     @Override
     protected Completion createTestInstance() {
-        Source source = randomSource();
-        LogicalPlan child = randomChild(0);
-        return new Completion(source, child, randomInferenceId(), randomPrompt(), randomAttribute());
+        return new Completion(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
     }
 
     @Override
@@ -43,11 +41,6 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati
         return new Completion(instance.source(), child, inferenceId, prompt, targetField);
     }
 
-    @Override
-    protected boolean alwaysEmptySource() {
-        return true;
-    }
-
     private Literal randomInferenceId() {
         return new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD);
     }
@@ -57,6 +50,6 @@ public class CompletionSerializationTests extends AbstractLogicalPlanSerializati
     }
 
     private Attribute randomAttribute() {
-        return FieldAttributeTests.createFieldAttribute(3, randomBoolean());
+        return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
     }
 }

+ 0 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java

@@ -47,11 +47,6 @@ public class RerankSerializationTests extends AbstractLogicalPlanSerializationTe
         return new Rerank(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute());
     }
 
-    @Override
-    protected boolean alwaysEmptySource() {
-        return true;
-    }
-
     private List<Alias> randomFields() {
         return randomList(0, 10, AliasTests::randomAlias);
     }

+ 54 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/CompletionExecSerializationTests.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical.inference;
+
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
+import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+
+import java.io.IOException;
+
+public class CompletionExecSerializationTests extends AbstractPhysicalPlanSerializationTests<CompletionExec> {
+    @Override
+    protected CompletionExec createTestInstance() {
+        return new CompletionExec(randomSource(), randomChild(0), randomInferenceId(), randomPrompt(), randomAttribute());
+    }
+
+    @Override
+    protected CompletionExec mutateInstance(CompletionExec instance) throws IOException {
+        PhysicalPlan child = instance.child();
+        Expression inferenceId = instance.inferenceId();
+        Expression prompt = instance.prompt();
+        Attribute targetField = instance.targetField();
+
+        switch (between(0, 3)) {
+            case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
+            case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
+            case 2 -> prompt = randomValueOtherThan(prompt, this::randomPrompt);
+            case 3 -> targetField = randomValueOtherThan(targetField, this::randomAttribute);
+        }
+        return new CompletionExec(instance.source(), child, inferenceId, prompt, targetField);
+    }
+
+    private Literal randomInferenceId() {
+        return new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD);
+    }
+
+    private Expression randomPrompt() {
+        return randomBoolean() ? new Literal(Source.EMPTY, randomIdentifier(), DataType.KEYWORD) : randomAttribute();
+    }
+
+    private Attribute randomAttribute() {
+        return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
+    }
+}

+ 0 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java

@@ -47,11 +47,6 @@ public class RerankExecSerializationTests extends AbstractPhysicalPlanSerializat
         return new RerankExec(instance.source(), child, inferenceId, queryText, fields, scoreAttribute());
     }
 
-    @Override
-    protected boolean alwaysEmptySource() {
-        return true;
-    }
-
     private List<Alias> randomFields() {
         return randomList(0, 10, AliasTests::randomAlias);
     }