Browse Source

EQL: Deal with internally created IN in a different way for EQL (#132167) (#132393)

Andrei Stefan 2 months ago
parent
commit
e282bb8217

+ 6 - 0
docs/changelog/132167.yaml

@@ -0,0 +1,6 @@
+pr: 132167
+summary: Deal with internally created IN in a different way for EQL
+area: EQL
+type: bug
+issues:
+ - 118621

+ 7 - 0
x-pack/plugin/eql/qa/common/build.gradle

@@ -8,3 +8,10 @@ dependencies {
   // TOML parser for EqlActionIT tests
   api 'io.ous:jtoml:2.0.0'
 }
+
+tasks.register("loadTestData", JavaExec) {
+  group = "Execution"
+  description = "Loads EQL Spec Tests data on a running stand-alone instance"
+  classpath = sourceSets.main.runtimeClasspath
+  mainClass = "org.elasticsearch.test.eql.DataLoader"
+}

+ 36 - 14
x-pack/plugin/eql/qa/common/src/main/java/org/elasticsearch/test/eql/DataLoader.java

@@ -76,39 +76,60 @@ public class DataLoader {
     public static void main(String[] args) throws IOException {
         main = true;
         try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200)).build()) {
-            loadDatasetIntoEs(client, DataLoader::createParser);
+            loadDatasetIntoEsWithIndexCreator(client, DataLoader::createParser, (restClient, indexName, indexMapping) -> {
+                // don't use ESRestTestCase methods here or, if you do, test running the main method before making the change
+                StringBuilder jsonBody = new StringBuilder("{");
+                jsonBody.append("\"settings\":{\"number_of_shards\":1},");
+                jsonBody.append("\"mappings\":");
+                jsonBody.append(indexMapping);
+                jsonBody.append("}");
+
+                Request request = new Request("PUT", "/" + indexName);
+                request.setJsonEntity(jsonBody.toString());
+                restClient.performRequest(request);
+            });
         }
     }
 
     public static void loadDatasetIntoEs(RestClient client, CheckedBiFunction<XContent, InputStream, XContentParser, IOException> p)
         throws IOException {
+        loadDatasetIntoEsWithIndexCreator(client, p, (restClient, indexName, indexMapping) -> {
+            ESRestTestCase.createIndex(restClient, indexName, Settings.builder().put("number_of_shards", 1).build(), indexMapping, null);
+        });
+    }
+
+    private static void loadDatasetIntoEsWithIndexCreator(
+        RestClient client,
+        CheckedBiFunction<XContent, InputStream, XContentParser, IOException> p,
+        IndexCreator indexCreator
+    ) throws IOException {
 
         //
         // Main Index
         //
-        load(client, TEST_INDEX, null, DataLoader::timestampToUnixMillis, p);
+        load(client, TEST_INDEX, null, DataLoader::timestampToUnixMillis, p, indexCreator);
         //
         // Aux Index
         //
-        load(client, TEST_EXTRA_INDEX, null, null, p);
+        load(client, TEST_EXTRA_INDEX, null, null, p, indexCreator);
         //
         // Date_Nanos index
         //
         // The data for this index is loaded from the same endgame-140.data sample, only having the mapping for @timestamp changed: the
         // chosen Windows filetime timestamps (2017+) can coincidentally also be readily used as nano-resolution unix timestamps (1973+).
         // There are mixed values with and without nanos precision so that the filtering is properly tested for both cases.
-        load(client, TEST_NANOS_INDEX, TEST_INDEX, DataLoader::timestampToUnixNanos, p);
-        load(client, TEST_SAMPLE, null, null, p);
+        load(client, TEST_NANOS_INDEX, TEST_INDEX, DataLoader::timestampToUnixNanos, p, indexCreator);
+        load(client, TEST_SAMPLE, null, null, p, indexCreator);
         //
         // missing_events index
         //
-        load(client, TEST_MISSING_EVENTS_INDEX, null, null, p);
-        load(client, TEST_SAMPLE_MULTI, null, null, p);
+        load(client, TEST_MISSING_EVENTS_INDEX, null, null, p, indexCreator);
+        load(client, TEST_SAMPLE_MULTI, null, null, p, indexCreator);
         //
         // index with a runtime field ("broken", type long) that causes shard failures.
         // the rest of the mapping is the same as TEST_INDEX
         //
-        load(client, TEST_SHARD_FAILURES_INDEX, null, DataLoader::timestampToUnixMillis, p);
+        load(client, TEST_SHARD_FAILURES_INDEX, null, DataLoader::timestampToUnixMillis, p, indexCreator);
     }
 
     private static void load(
@@ -116,7 +137,8 @@ public class DataLoader {
         String indexNames,
         String dataName,
         Consumer<Map<String, Object>> datasetTransform,
-        CheckedBiFunction<XContent, InputStream, XContentParser, IOException> p
+        CheckedBiFunction<XContent, InputStream, XContentParser, IOException> p,
+        IndexCreator indexCreator
     ) throws IOException {
         String[] splitNames = indexNames.split(",");
         for (String indexName : splitNames) {
@@ -130,15 +152,11 @@ public class DataLoader {
             if (data == null) {
                 throw new IllegalArgumentException("Cannot find resource " + name);
             }
-            createTestIndex(client, indexName, readMapping(mapping));
+            indexCreator.createIndex(client, indexName, readMapping(mapping));
             loadData(client, indexName, datasetTransform, data, p);
         }
     }
 
-    private static void createTestIndex(RestClient client, String indexName, String mapping) throws IOException {
-        ESRestTestCase.createIndex(client, indexName, Settings.builder().put("number_of_shards", 1).build(), mapping, null);
-    }
-
     /**
      * Reads the mapping file, ignoring comments and replacing placeholders for random types.
      */
@@ -236,4 +254,8 @@ public class DataLoader {
         NamedXContentRegistry contentRegistry = new NamedXContentRegistry(ClusterModule.getNamedXWriteables());
         return xContent.createParser(contentRegistry, LoggingDeprecationHandler.INSTANCE, data);
     }
+
+    private interface IndexCreator {
+        void createIndex(RestClient client, String indexName, String mapping) throws IOException;
+    }
 }

+ 8 - 1
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java

@@ -44,7 +44,6 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BinaryComparisonSimplification;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanSimplification;
-import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ConstantFolding;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerRule;
@@ -252,6 +251,14 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
 
     }
 
+    static class CombineDisjunctionsToIn extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn {
+
+        @Override
+        protected boolean shouldValidateIn() {
+            return true;
+        }
+    }
+
     static class PruneFilters extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneFilters {
 
         @Override

+ 84 - 0
x-pack/plugin/eql/src/test/resources/querytranslator_tests.txt

@@ -123,6 +123,90 @@ process where process_name in ("python.exe", "SMSS.exe", "explorer.exe")
 "terms":{"process_name":["python.exe","SMSS.exe","explorer.exe"],
 ;
 
+mutipleOrEquals_As_InTranslation1
+process where process_name == "python.exe" or process_name == "SMSS.exe" or process_name == "explorer.exe"
+;
+"terms":{"process_name":["python.exe","SMSS.exe","explorer.exe"],
+;
+
+multipleOrAndEquals_As_InTranslation
+process where process_name == "python.exe" and process_name == "SMSS.exe" or process_name == "explorer.exe" or process_name == "test.exe"
+;
+{"bool":{"should":[{"bool":{"must":[{"term":{"process_name":{"value":"python.exe"}}},{"term":{"process_name":{"value":"SMSS.exe"}}}],"boost":1.0}},{"terms":{"process_name":["explorer.exe","test.exe"],"boost":1.0}}],"boost":1.0}}
+;
+
+mutipleOrEquals_As_InTranslation2
+process where source_address == "123.12.1.1" or (opcode == 123 or opcode == 127)
+;
+{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"terms":{"opcode":[123,127],"boost":1.0}}],"boost":1.0}}
+;
+
+mutipleOrEquals_As_InTranslation3
+process where (source_address == "123.12.1.1" or source_address == "127.0.0.1") and (opcode == 123 or opcode == 127)
+;
+{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"127.0.0.1"}}}],"boost":1.0}},{"terms":{"opcode":[123,127],"boost":1.0}}
+;
+
+mutipleOrEquals_As_InTranslation4
+process where (source_address == "123.12.1.1" or source_address == "127.0.0.1") and (opcode == 123 or opcode == 127)
+;
+"must":[{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"127.0.0.1"}}}],"boost":1.0}},{"terms":{"opcode":[123,127],"boost":1.0}},{"term":{"event.category":{"value":"process"}}}]
+;
+
+multipleOrIncompatibleTypes1
+process where process_name == "python.exe" or process_name == 2 or process_name == "3"
+;
+{"bool":{"should":[{"term":{"process_name":{"value":"python.exe"}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":"3"}}}],"boost":1.0}}
+;
+
+multipleOrIncompatibleTypes2
+process where process_name == "1" or process_name == 2 or process_name == "3"
+;
+{"bool":{"should":[{"term":{"process_name":{"value":"1"}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":"3"}}}],"boost":1.0}}
+;
+
+multipleOrIncompatibleTypes3
+process where process_name == 1.2 or process_name == 2 or process_name == "3"
+;
+{"bool":{"should":[{"term":{"process_name":{"value":1.2}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":"3"}}}],"boost":1.0}}
+;
+
+// this query as an equivalent with
+// process where process_name in (1.2, 2, 3)
+// will result in a user error: 1st argument of [process_name in (1.2, 2, 3)] must be [keyword], found value [1.2] type [double]
+multipleOrIncompatibleTypes4
+process where process_name == 1.2 or process_name == 2 or process_name == 3
+;
+{"bool":{"should":[{"term":{"process_name":{"value":1.2}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":3}}}],"boost":1.0}}
+;
+
+// this query as an equivalent with
+// process where source_address in ("123.12.1.1", "123.12.1.2")
+// will result in a user error: 1st argument of [source_address in ("123.12.1.1", "123.12.1.2")] must be [ip], found value ["123.12.1.1"] type [keyword]
+multipleOrIncompatibleTypes5
+process where source_address == "123.12.1.1" or source_address == "123.12.1.2"
+;
+{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"123.12.1.2"}}}],"boost":1.0}}
+;
+
+multipleOrIncompatibleTypes6
+process where source_address == "123.12.1.1" or source_address == concat("123.12.","1.2")
+;
+{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"123.12.1.2"}}}],"boost":1.0}}
+;
+
+multipleOrIncompatibleTypes7
+process where source_address == "123.12.1.1" and (source_address == "123.12.1.2" or source_address >= "127.0.0.1")
+;
+"must":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.2"}}},{"range":{"source_address":{"gte":"127.0.0.1","boost":1.0}}}],"boost":1.0}},{"term":{"event.category":{"value":"process"}}}]
+;
+
+multipleOrIncompatibleTypes8
+process where source_address == "123.12.1.1" and (source_address == "123.12.1.2" or source_address == "127.0.0.1")
+;
+"must":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.2"}}},{"term":{"source_address":{"value":"127.0.0.1"}}}],"boost":1.0}},{"term":{"event.category":{"value":"process"}}}]
+;
+
 inFilterWithScripting
 process where substring(command_line, 5) in ("test*","best")
 ;

+ 4 - 0
x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/In.java

@@ -177,6 +177,10 @@ public class In extends ScalarFunction {
         return super.resolveType();
     }
 
+    public TypeResolution validateInTypes() {
+        return resolveType();
+    }
+
     @Override
     public int hashCode() {
         return Objects.hash(value, list);

+ 47 - 12
x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java

@@ -1203,8 +1203,8 @@ public final class OptimizerRules {
      * 2. a == 1 OR a IN (2) becomes a IN (1, 2)
      * 3. a IN (1) OR a IN (2) becomes a IN (1, 2)
      *
-     * This rule does NOT check for type compatibility as that phase has been
-     * already be verified in the analyzer.
+     * By default (see {@link #shouldValidateIn()}), this rule does NOT check for type compatibility as that phase has
+     * already been verified in the analyzer, but this behavior can be changed by subclasses.
      */
     public static class CombineDisjunctionsToIn extends OptimizerExpressionRule<Or> {
         public CombineDisjunctionsToIn() {
@@ -1214,18 +1214,24 @@ public final class OptimizerRules {
         @Override
         protected Expression rule(Or or) {
             Expression e = or;
-            // look only at equals and In
+            // look only at Equals and In
             List<Expression> exps = splitOr(e);
 
             Map<Expression, Set<Expression>> found = new LinkedHashMap<>();
+            Map<Expression, List<Expression>> originalOrs = new LinkedHashMap<>();
             ZoneId zoneId = null;
             List<Expression> ors = new LinkedList<>();
 
             for (Expression exp : exps) {
                 if (exp instanceof Equals eq) {
-                    // consider only equals against foldables
+                    // consider only Equals against foldables
                     if (eq.right().foldable()) {
                         found.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right());
+                        if (shouldValidateIn()) {
+                            // in case there is an optimized In being built and its validation fails, rebuild the original ORs
+                            // so, keep around the original Expressions
+                            originalOrs.computeIfAbsent(eq.left(), k -> new ArrayList<>()).add(eq);
+                        }
                     } else {
                         ors.add(exp);
                     }
@@ -1234,6 +1240,11 @@ public final class OptimizerRules {
                     }
                 } else if (exp instanceof In in) {
                     found.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list());
+                    if (shouldValidateIn()) {
+                        // in case there is an optimized In being built and its validation fails, rebuild the original ORs
+                        // so, keep around the original Expressions
+                        originalOrs.computeIfAbsent(in.value(), k -> new ArrayList<>()).add(in);
+                    }
                     if (zoneId == null) {
                         zoneId = in.zoneId();
                     }
@@ -1243,11 +1254,31 @@ public final class OptimizerRules {
             }
 
             if (found.isEmpty() == false) {
-                // combine equals alongside the existing ors
+                // combine Equals alongside the existing ORs
                 final ZoneId finalZoneId = zoneId;
-                found.forEach(
-                    (k, v) -> { ors.add(v.size() == 1 ? createEquals(k, v, finalZoneId) : createIn(k, new ArrayList<>(v), finalZoneId)); }
-                );
+                found.forEach((k, v) -> {
+                    if (v.size() == 1) {
+                        ors.add(createEquals(k, v.iterator().next(), finalZoneId));
+                    } else {
+                        In in = createIn(k, new ArrayList<>(v), finalZoneId);
+                        // IN has its own particularities when it comes to type resolution and not all implementations
+                        // double check the validity of an internally created IN (like the one created here). EQL is one where the IN
+                        // implementation is like this mechanism here has been specifically created for it
+                        if (shouldValidateIn()) {
+                            Expression.TypeResolution resolution = in.validateInTypes();
+                            if (resolution.unresolved()) {
+                                // if the internally created In is not valid, fall back to the original ORs
+                                assert originalOrs.containsKey(k);
+                                assert originalOrs.get(k).isEmpty() == false;
+                                ors.add(combineOr(originalOrs.get(k)));
+                            } else {
+                                ors.add(in);
+                            }
+                        } else {
+                            ors.add(in);
+                        }
+                    }
+                });
 
                 Expression combineOr = combineOr(ors);
                 // check the result semantically since the result might different in order
@@ -1261,13 +1292,17 @@ public final class OptimizerRules {
             return e;
         }
 
-        protected Equals createEquals(Expression k, Set<Expression> v, ZoneId finalZoneId) {
-            return new Equals(k.source(), k, v.iterator().next(), finalZoneId);
-        }
-
         protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) {
             return new In(key.source(), key, values, zoneId);
         }
+
+        protected boolean shouldValidateIn() {
+            return false;
+        }
+
+        private Equals createEquals(Expression key, Expression value, ZoneId finalZoneId) {
+            return new Equals(key.source(), key, value, finalZoneId);
+        }
     }
 
     public static class PushDownAndCombineFilters extends OptimizerRule<Filter> {

+ 217 - 36
x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java

@@ -66,6 +66,8 @@ import org.elasticsearch.xpack.ql.util.StringUtils;
 import java.time.ZoneId;
 import java.util.Collections;
 import java.util.List;
+import java.util.Set;
+import java.util.function.Consumer;
 
 import static java.util.Arrays.asList;
 import static java.util.Collections.emptyList;
@@ -105,6 +107,9 @@ public class OptimizerRulesTests extends ESTestCase {
     private static final Literal FOUR = L(4);
     private static final Literal FIVE = L(5);
     private static final Literal SIX = L(6);
+    private static final Literal TEXT_A = L("A");
+    private static final Literal TEXT_B = L("B");
+    private static final Literal TEXT_C = L("C");
 
     public static class DummyBooleanExpression extends Expression {
 
@@ -1491,48 +1496,71 @@ public class OptimizerRulesTests extends ESTestCase {
     //
     // CombineDisjunction in Equals
     //
+
+    // CombineDisjunctionsToIn with shouldValidateIn as true
+    private final class ValidateableCombineDisjunctionsToIn extends CombineDisjunctionsToIn {
+        @Override
+        protected boolean shouldValidateIn() {
+            return true;
+        }
+    };
+
+    private void assertCombineDisjunctionsToIn(Consumer<CombineDisjunctionsToIn> tester) {
+        for (CombineDisjunctionsToIn rule : Set.of(new CombineDisjunctionsToIn(), new ValidateableCombineDisjunctionsToIn())) {
+            tester.accept(rule);
+        }
+    }
+
     public void testTwoEqualsWithOr() throws Exception {
         FieldAttribute fa = getFieldAttribute();
 
         Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO));
-        Expression e = new CombineDisjunctionsToIn().rule(or);
-        assertEquals(In.class, e.getClass());
-        In in = (In) e;
-        assertEquals(fa, in.value());
-        assertThat(in.list(), contains(ONE, TWO));
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(or);
+            assertEquals(In.class, e.getClass());
+            In in = (In) e;
+            assertEquals(fa, in.value());
+            assertThat(in.list(), contains(ONE, TWO));
+        });
     }
 
     public void testTwoEqualsWithSameValue() throws Exception {
         FieldAttribute fa = getFieldAttribute();
 
         Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE));
-        Expression e = new CombineDisjunctionsToIn().rule(or);
-        assertEquals(Equals.class, e.getClass());
-        Equals eq = (Equals) e;
-        assertEquals(fa, eq.left());
-        assertEquals(ONE, eq.right());
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(or);
+            assertEquals(Equals.class, e.getClass());
+            Equals eq = (Equals) e;
+            assertEquals(fa, eq.left());
+            assertEquals(ONE, eq.right());
+        });
     }
 
     public void testOneEqualsOneIn() throws Exception {
         FieldAttribute fa = getFieldAttribute();
 
         Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, singletonList(TWO)));
-        Expression e = new CombineDisjunctionsToIn().rule(or);
-        assertEquals(In.class, e.getClass());
-        In in = (In) e;
-        assertEquals(fa, in.value());
-        assertThat(in.list(), contains(ONE, TWO));
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(or);
+            assertEquals(In.class, e.getClass());
+            In in = (In) e;
+            assertEquals(fa, in.value());
+            assertThat(in.list(), contains(ONE, TWO));
+        });
     }
 
     public void testOneEqualsOneInWithSameValue() throws Exception {
         FieldAttribute fa = getFieldAttribute();
 
         Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO)));
-        Expression e = new CombineDisjunctionsToIn().rule(or);
-        assertEquals(In.class, e.getClass());
-        In in = (In) e;
-        assertEquals(fa, in.value());
-        assertThat(in.list(), contains(ONE, TWO));
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(or);
+            assertEquals(In.class, e.getClass());
+            In in = (In) e;
+            assertEquals(fa, in.value());
+            assertThat(in.list(), contains(ONE, TWO));
+        });
     }
 
     public void testSingleValueInToEquals() throws Exception {
@@ -1540,8 +1568,10 @@ public class OptimizerRulesTests extends ESTestCase {
 
         Equals equals = equalsOf(fa, ONE);
         Or or = new Or(EMPTY, equals, new In(EMPTY, fa, singletonList(ONE)));
-        Expression e = new CombineDisjunctionsToIn().rule(or);
-        assertEquals(equals, e);
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(or);
+            assertEquals(equals, e);
+        });
     }
 
     public void testEqualsBehindAnd() throws Exception {
@@ -1549,9 +1579,11 @@ public class OptimizerRulesTests extends ESTestCase {
 
         And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO));
         Filter dummy = new Filter(EMPTY, relation(), and);
-        LogicalPlan transformed = new CombineDisjunctionsToIn().apply(dummy);
-        assertSame(dummy, transformed);
-        assertEquals(and, ((Filter) transformed).condition());
+        assertCombineDisjunctionsToIn((rule) -> {
+            LogicalPlan transformed = rule.apply(dummy);
+            assertSame(dummy, transformed);
+            assertEquals(and, ((Filter) transformed).condition());
+        });
     }
 
     public void testTwoEqualsDifferentFields() throws Exception {
@@ -1559,8 +1591,10 @@ public class OptimizerRulesTests extends ESTestCase {
         FieldAttribute fieldTwo = TestUtils.getFieldAttribute("TWO");
 
         Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO));
-        Expression e = new CombineDisjunctionsToIn().rule(or);
-        assertEquals(or, e);
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(or);
+            assertEquals(or, e);
+        });
     }
 
     public void testMultipleIn() throws Exception {
@@ -1568,11 +1602,13 @@ public class OptimizerRulesTests extends ESTestCase {
 
         Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(ONE)), new In(EMPTY, fa, singletonList(TWO)));
         Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(THREE)));
-        Expression e = new CombineDisjunctionsToIn().rule(secondOr);
-        assertEquals(In.class, e.getClass());
-        In in = (In) e;
-        assertEquals(fa, in.value());
-        assertThat(in.list(), contains(ONE, TWO, THREE));
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(secondOr);
+            assertEquals(In.class, e.getClass());
+            In in = (In) e;
+            assertEquals(fa, in.value());
+            assertThat(in.list(), contains(ONE, TWO, THREE));
+        });
     }
 
     public void testOrWithNonCombinableExpressions() throws Exception {
@@ -1580,14 +1616,159 @@ public class OptimizerRulesTests extends ESTestCase {
 
         Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(ONE)), lessThanOf(fa, TWO));
         Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(THREE)));
-        Expression e = new CombineDisjunctionsToIn().rule(secondOr);
+        assertCombineDisjunctionsToIn((rule) -> {
+            Expression e = rule.rule(secondOr);
+            assertEquals(Or.class, e.getClass());
+            Or or = (Or) e;
+            assertEquals(or.left(), firstOr.right());
+            assertEquals(In.class, or.right().getClass());
+            In in = (In) or.right();
+            assertEquals(fa, in.value());
+            assertThat(in.list(), contains(ONE, THREE));
+        });
+    }
+
+    public void testDontCombineSimpleDifferentTypes() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or or = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TEXT_A));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineDifferentTypes() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or or = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TEXT_A));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    // See https://github.com/elastic/elasticsearch/issues/118621
+    public void testDontCombineStringTypesForIPField() throws Exception {
+        FieldAttribute fa = TestUtils.getFieldAttribute("ip", DataTypes.IP);
+
+        Or or = new Or(EMPTY, new Equals(EMPTY, fa, TEXT_A), new Equals(EMPTY, fa, TEXT_B));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineForIncompatibleFieldType() throws Exception {
+        FieldAttribute fa = TestUtils.getFieldAttribute("boolean", BOOLEAN);
+
+        Or or = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TWO));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineTwoCompatibleAndOneIncompatible() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or firstOr = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TWO));
+        Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, fa, TEXT_A));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(secondOr);
+        assertEquals(secondOr, e);
+    }
+
+    public void testDontCombineOneIncompatibleEqualsWithCompatibleIn() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or or = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE, TWO)), new Equals(EMPTY, fa, TEXT_A));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineTwoIncompatibleIns1() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or or = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE, TWO)), new In(EMPTY, fa, List.of(TEXT_A, TEXT_B, TEXT_C)));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineTwoIncompatibleIns2() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or or = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TEXT_A)));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineTwoIncompatibleIns3() throws Exception {
+        FieldAttribute fa1 = TestUtils.getFieldAttribute("field1");
+        FieldAttribute fa2 = TestUtils.getFieldAttribute("field2");
+
+        Or or = new Or(EMPTY, new In(EMPTY, fa1, List.of(ONE, TWO)), new In(EMPTY, fa2, List.of(THREE, FOUR)));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testDontCombineIncompatibleInWithTwoCompatibleEquals() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+
+        Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(TEXT_A, TEXT_B)), new Equals(EMPTY, fa, THREE));
+        Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, fa, FOUR));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(secondOr);
+        assertEquals(secondOr, e);
+    }
+
+    public void testCombineOnlyEqualsExpressions() throws Exception {
+        FieldAttribute faIn = TestUtils.getFieldAttribute("field_for_in");
+        FieldAttribute faEquals = TestUtils.getFieldAttribute("field_for_equals");
+
+        Or firstOr = new Or(EMPTY, new In(EMPTY, faIn, List.of(ONE, TWO)), new Equals(EMPTY, faEquals, THREE));
+        Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, faEquals, FOUR));
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(secondOr);
         assertEquals(Or.class, e.getClass());
         Or or = (Or) e;
-        assertEquals(or.left(), firstOr.right());
+        assertEquals(or.left(), firstOr.left());
         assertEquals(In.class, or.right().getClass());
         In in = (In) or.right();
-        assertEquals(fa, in.value());
-        assertThat(in.list(), contains(ONE, THREE));
+        assertEquals(faEquals, in.value());
+        assertThat(in.list(), contains(THREE, FOUR));
+    }
+
+    public void testCombineOnlyCompatibleEqualsExpressions() throws Exception {
+        FieldAttribute faEquals1 = TestUtils.getFieldAttribute("field_for_equals1");
+        FieldAttribute faEquals2 = TestUtils.getFieldAttribute("field_for_equals2");
+
+        Equals equalsA = new Equals(EMPTY, faEquals2, TEXT_A);
+        Equals equalsB = new Equals(EMPTY, faEquals2, TEXT_B);
+        Or firstOr = new Or(EMPTY, new Equals(EMPTY, faEquals1, ONE), equalsA);
+        Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, faEquals1, TWO));
+        Or thirdOr = new Or(EMPTY, secondOr, equalsB);
+
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(thirdOr);
+        assertEquals(Or.class, e.getClass());
+        Or or = (Or) e;
+        assertEquals(In.class, or.left().getClass());
+        In in = (In) or.left();
+        assertThat(in.list(), contains(ONE, TWO));
+
+        assertEquals(Or.class, or.right().getClass());
+        or = (Or) or.right();
+        assertEquals(or.left(), equalsA);
+        assertEquals(or.right(), equalsB);
+    }
+
+    public void testCombineTwoCompatiblePairsOrEqualsExpressions() throws Exception {
+        FieldAttribute faEquals1 = TestUtils.getFieldAttribute("field_for_equals1");
+        FieldAttribute faEquals2 = TestUtils.getFieldAttribute("field_for_equals2");
+
+        Or firstOr = new Or(EMPTY, new Equals(EMPTY, faEquals1, ONE), new Equals(EMPTY, faEquals2, THREE));
+        Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, faEquals1, TWO));
+        Or thirdOr = new Or(EMPTY, secondOr, new Equals(EMPTY, faEquals2, FOUR));
+
+        Expression e = new ValidateableCombineDisjunctionsToIn().rule(thirdOr);
+        assertEquals(Or.class, e.getClass());
+        Or or = (Or) e;
+        assertEquals(In.class, or.left().getClass());
+        In in = (In) or.left();
+        assertThat(in.list(), contains(ONE, TWO));
+
+        assertEquals(In.class, or.right().getClass());
+        in = (In) or.right();
+        assertThat(in.list(), contains(THREE, FOUR));
     }
 
     // Null folding