1
0
Эх сурвалжийг харах

Add script filter to intervals (#36776)

This commit adds the ability to filter out intervals based on their start and end positions, and their
internal gaps:
```
POST _search
{
  "query": {
    "intervals" : {
      "my_text" : {
        "match" : {
          "query" : "hot porridge",
          "filter" : {
            "script" : {
              "source" : "interval.start > 10 && interval.end < 20 && interval.gaps == 0"
            }
          }
        }
      }
    }
  }
}
```
Alan Woodward 6 жил өмнө
parent
commit
344917efab

+ 29 - 0
docs/reference/query-dsl/intervals-query.asciidoc

@@ -154,6 +154,35 @@ Produces intervals that are not contained by an interval from the filter rule
 `not_overlapping`::
 Produces intervals that do not overlap with an interval from the filter rule
 
+[[interval-script-filter]]
+==== Script filters
+
+You can also filter intervals based on their start position, end position and
+internal gap count, using a script. The script has access to an `interval`
+variable, with `start`, `end` and `gaps` methods:
+
+[source,js]
+--------------------------------------------------
+POST _search
+{
+  "query": {
+    "intervals" : {
+      "my_text" : {
+        "match" : {
+          "query" : "hot porridge",
+          "filter" : {
+            "script" : {
+              "source" : "interval.start > 10 && interval.end < 20 && interval.gaps == 0"
+            }
+          }
+        }
+      }
+    }
+  }
+}
+--------------------------------------------------
+// CONSOLE
+
 [[interval-minimization]]
 ==== Minimization
 

+ 6 - 0
modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/org.elasticsearch.txt

@@ -234,6 +234,12 @@ class org.elasticsearch.index.similarity.ScriptedSimilarity$Doc {
   float getFreq()
 }
 
+class org.elasticsearch.index.query.IntervalFilterScript$Interval {
+  int getStart()
+  int getEnd()
+  int getGaps()
+}
+
 # for testing
 class org.elasticsearch.painless.FeatureTest no_import {
   int z

+ 46 - 0
modules/lang-painless/src/test/resources/rest-api-spec/test/painless/90_interval_query_filter.yml

@@ -0,0 +1,46 @@
+setup:
+  - skip:
+      version: " - 6.99.99"
+      reason:  "Implemented in 7.0"
+
+  - do:
+      indices.create:
+        index:  test
+        body:
+          mappings:
+            test:
+              properties:
+                text:
+                  type: text
+                  analyzer: standard
+  - do:
+      bulk:
+        refresh: true
+        body:
+          - '{"index": {"_index": "test", "_type": "test", "_id": "1"}}'
+          - '{"text" : "Some like it hot, some like it cold"}'
+          - '{"index": {"_index": "test", "_type": "test", "_id": "2"}}'
+          - '{"text" : "Its cold outside, theres no kind of atmosphere"}'
+          - '{"index": {"_index": "test", "_type": "test", "_id": "3"}}'
+          - '{"text" : "Baby its cold there outside"}'
+          - '{"index": {"_index": "test", "_type": "test", "_id": "4"}}'
+          - '{"text" : "Outside it is cold and wet"}'
+
+---
+"Test filtering by script":
+  - do:
+      search:
+        index: test
+        body:
+          query:
+            intervals:
+              text:
+                match:
+                  query: "cold"
+                  filter:
+                    script:
+                      source: "interval.start > 3"
+
+  - match: { hits.total.value: 1 }
+
+

+ 60 - 0
server/src/main/java/org/elasticsearch/index/query/IntervalFilterScript.java

@@ -0,0 +1,60 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.index.query;
+
+import org.apache.lucene.search.intervals.IntervalIterator;
+import org.elasticsearch.script.ScriptContext;
+
+/**
+ * Base class for scripts used as interval filters, see {@link IntervalsSourceProvider.IntervalFilter}.
+ * <p>
+ * An implementation is handed an {@link Interval} view and reports its decision through
+ * {@link #execute(Interval)}.
+ */
+public abstract class IntervalFilterScript {
+
+    /** Name of the single argument of {@link #execute(Interval)}. */
+    public static final String[] PARAMETERS = { "interval" };
+
+    /** Script context named {@code interval}, bound to {@link Factory}. */
+    public static final ScriptContext<Factory> CONTEXT = new ScriptContext<>("interval", Factory.class);
+
+    /** Creates new {@link IntervalFilterScript} instances. */
+    public interface Factory {
+        IntervalFilterScript newInstance();
+    }
+
+    /**
+     * Evaluates the script against one interval.
+     *
+     * @param interval view of the interval under test
+     * @return the script's boolean verdict for this interval
+     */
+    public abstract boolean execute(Interval interval);
+
+    /**
+     * View over an {@link IntervalIterator}; every getter delegates to the iterator most
+     * recently supplied through {@link #setIterator(IntervalIterator)}.
+     */
+    public static class Interval {
+
+        private IntervalIterator iterator;
+
+        /** Points this view at the given iterator. */
+        void setIterator(IntervalIterator iterator) {
+            this.iterator = iterator;
+        }
+
+        /** @return the result of {@link IntervalIterator#start()} on the current iterator */
+        public int getStart() {
+            return iterator.start();
+        }
+
+        /** @return the result of {@link IntervalIterator#end()} on the current iterator */
+        public int getEnd() {
+            return iterator.end();
+        }
+
+        /** @return the result of {@link IntervalIterator#gaps()} on the current iterator */
+        public int getGaps() {
+            return iterator.gaps();
+        }
+    }
+
+}

+ 56 - 2
server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java

@@ -19,6 +19,8 @@
 
 package org.elasticsearch.index.query;
 
+import org.apache.lucene.search.intervals.FilteredIntervalsSource;
+import org.apache.lucene.search.intervals.IntervalIterator;
 import org.apache.lucene.search.intervals.Intervals;
 import org.apache.lucene.search.intervals.IntervalsSource;
 import org.elasticsearch.common.ParseField;
@@ -34,6 +36,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.analysis.NamedAnalyzer;
 import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.script.Script;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -387,24 +390,59 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
         }
     }
 
+    /**
+     * A {@link FilteredIntervalsSource} whose {@link #accept(IntervalIterator)} decision is
+     * delegated to an {@link IntervalFilterScript}.
+     */
+    static class ScriptFilterSource extends FilteredIntervalsSource {
+
+        final IntervalFilterScript script;
+        // Reusable holder handed to the script; re-pointed at the iterator under test on every accept() call.
+        // NOTE(review): the holder is shared by every iterator of this source - confirm it is never used concurrently.
+        final IntervalFilterScript.Interval interval = new IntervalFilterScript.Interval();
+
+        /**
+         * @param in     the source whose intervals are filtered
+         * @param name   label embedded in this source's name as {@code FILTER(name)}
+         * @param script the script consulted for each candidate interval
+         */
+        ScriptFilterSource(IntervalsSource in, String name, IntervalFilterScript script) {
+            super("FILTER(" + name + ")", in);
+            this.script = script;
+        }
+
+        @Override
+        protected boolean accept(IntervalIterator it) {
+            interval.setIterator(it);
+            return script.execute(interval);
+        }
+    }
+
     public static class IntervalFilter implements ToXContent, Writeable {
 
         public static final String NAME = "filter";
 
         private final String type;
         private final IntervalsSourceProvider filter;
+        private final Script script;
 
         public IntervalFilter(IntervalsSourceProvider filter, String type) {
             this.filter = filter;
             this.type = type.toLowerCase(Locale.ROOT);
+            this.script = null;
+        }
+
+        IntervalFilter(Script script) {
+            // A script-based filter carries no nested filter rule; its type is fixed to "script"
+            this.type = "script";
+            this.filter = null;
+            this.script = script;
         }
 
         public IntervalFilter(StreamInput in) throws IOException {
+            // Read order must match writeTo: type, optional filter rule, optional script
             this.type = in.readString();
-            this.filter = in.readNamedWriteable(IntervalsSourceProvider.class);
+            this.filter = in.readOptionalNamedWriteable(IntervalsSourceProvider.class);
+            // Consumes a presence flag, then the script itself when the flag is set; null otherwise
+            this.script = in.readOptionalWriteable(Script::new);
         }
 
         public IntervalsSource filter(IntervalsSource input, QueryShardContext context, MappedFieldType fieldType) throws IOException {
+            if (script != null) {
+                IntervalFilterScript ifs = context.getScriptService().compile(script, IntervalFilterScript.CONTEXT).newInstance();
+                return new ScriptFilterSource(input, script.getIdOrCode(), ifs);
+            }
             IntervalsSource filterSource = filter.getSource(context, fieldType);
             switch (type) {
                 case "containing":
@@ -439,7 +477,14 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(type);
-            out.writeNamedWriteable(filter);
+            out.writeOptionalNamedWriteable(filter);
+            if (script == null) {
+                out.writeBoolean(false);
+            }
+            else {
+                out.writeBoolean(true);
+                script.writeTo(out);
+            }
         }
 
         @Override
@@ -458,6 +503,13 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
                 throw new ParsingException(parser.getTokenLocation(), "Expected [FIELD_NAME] but got [" + parser.currentToken() + "]");
             }
             String type = parser.currentName();
+            if (Script.SCRIPT_PARSE_FIELD.match(type, parser.getDeprecationHandler())) {
+                Script script = Script.parse(parser);
+                if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                    throw new ParsingException(parser.getTokenLocation(), "Expected [END_OBJECT] but got [" + parser.currentToken() + "]");
+                }
+                return new IntervalFilter(script);
+            }
             if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
                 throw new ParsingException(parser.getTokenLocation(), "Expected [START_OBJECT] but got [" + parser.currentToken() + "]");
             }
@@ -475,4 +527,6 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
         }
     }
 
+
+
 }

+ 3 - 1
server/src/main/java/org/elasticsearch/script/ScriptModule.java

@@ -21,6 +21,7 @@ package org.elasticsearch.script;
 
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.IntervalFilterScript;
 import org.elasticsearch.plugins.ScriptPlugin;
 import org.elasticsearch.search.aggregations.pipeline.MovingFunctionScript;
 
@@ -60,7 +61,8 @@ public class ScriptModule {
             ScriptedMetricAggContexts.InitScript.CONTEXT,
             ScriptedMetricAggContexts.MapScript.CONTEXT,
             ScriptedMetricAggContexts.CombineScript.CONTEXT,
-            ScriptedMetricAggContexts.ReduceScript.CONTEXT
+            ScriptedMetricAggContexts.ReduceScript.CONTEXT,
+            IntervalFilterScript.CONTEXT
         ).collect(Collectors.toMap(c -> c.name, Function.identity()));
     }
 

+ 49 - 0
server/src/test/java/org/elasticsearch/index/query/IntervalQueryBuilderTests.java

@@ -25,11 +25,16 @@ import org.apache.lucene.search.Query;
 import org.apache.lucene.search.intervals.IntervalQuery;
 import org.apache.lucene.search.intervals.Intervals;
 import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.test.AbstractQueryTestCase;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 import static org.hamcrest.Matchers.equalTo;
@@ -277,4 +282,48 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
         });
         assertThat(e.getMessage(), equalTo("Only one interval rule can be specified, found [match] and [all_of]"));
     }
+
+    /**
+     * Parses an intervals query that carries a script filter and checks that building it yields an
+     * {@link IntervalQuery} over a {@link IntervalsSourceProvider.ScriptFilterSource}.
+     */
+    public void testScriptFilter() throws IOException {
+        final String source = "interval.start > 3";
+
+        // Hand-written stand-in for the compiled form of the script source above
+        IntervalFilterScript.Factory stubFactory = () -> new IntervalFilterScript() {
+            @Override
+            public boolean execute(Interval interval) {
+                return interval.getStart() > 3;
+            }
+        };
+
+        // Script service that checks the compile request and hands back the stub factory
+        ScriptService stubService = new ScriptService(Settings.EMPTY, Collections.emptyMap(), Collections.emptyMap()) {
+            @Override
+            @SuppressWarnings("unchecked")
+            public <FactoryType> FactoryType compile(Script script, ScriptContext<FactoryType> context) {
+                assertEquals(IntervalFilterScript.CONTEXT, context);
+                assertEquals(new Script(source), script);
+                return (FactoryType) stubFactory;
+            }
+        };
+
+        // Rebuild the shard context around the stub script service
+        QueryShardContext shardContext = createShardContext();
+        QueryShardContext scriptedContext = new QueryShardContext(shardContext.getShardId(), shardContext.getIndexSettings(),
+            null, null, shardContext.getMapperService(), null,
+            stubService,
+            null, null, null, null, null, null);
+
+        String json = "{ \"intervals\" : { \"" + STRING_FIELD_NAME + "\": { " +
+            "\"match\" : { " +
+            "   \"query\" : \"term1\"," +
+            "   \"filter\" : { " +
+            "       \"script\" : { " +
+            "            \"source\" : \"interval.start > 3\" } } } } } }";
+
+        Query actual = ((IntervalQueryBuilder) parseQuery(json)).toQuery(scriptedContext);
+
+        // NOTE(review): the expected source is built with a null script, so this comparison presumably
+        // relies on source equality not taking the script into account - confirm
+        IntervalQuery expected = new IntervalQuery(STRING_FIELD_NAME,
+            new IntervalsSourceProvider.ScriptFilterSource(Intervals.term("term1"), source, null));
+        assertEquals(expected, actual);
+    }
+
+
 }

+ 13 - 0
test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java

@@ -21,6 +21,7 @@ package org.elasticsearch.script;
 
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Scorable;
+import org.elasticsearch.index.query.IntervalFilterScript;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Doc;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Field;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Query;
@@ -287,6 +288,9 @@ public class MockScriptEngine implements ScriptEngine {
         } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.ReduceScript.class)) {
             ScriptedMetricAggContexts.ReduceScript.Factory factory = mockCompiled::createMetricAggReduceScript;
             return context.factoryClazz.cast(factory);
+        } else if (context.instanceClazz.equals(IntervalFilterScript.class)) {
+            IntervalFilterScript.Factory factory = mockCompiled::createIntervalFilterScript;
+            return context.factoryClazz.cast(factory);
         }
         ContextCompiler compiler = contexts.get(context);
         if (compiler != null) {
@@ -353,6 +357,15 @@ public class MockScriptEngine implements ScriptEngine {
         public ScriptedMetricAggContexts.ReduceScript createMetricAggReduceScript(Map<String, Object> params, List<Object> states) {
             return new MockMetricAggReduceScript(params, states, script != null ? script : ctx -> 42d);
         }
+
+        /**
+         * Builds the mock interval filter script.
+         *
+         * @return a script whose {@code execute} returns {@code false} for every interval
+         */
+        public IntervalFilterScript createIntervalFilterScript() {
+            final IntervalFilterScript rejectEverything = new IntervalFilterScript() {
+                @Override
+                public boolean execute(Interval interval) {
+                    return false;
+                }
+            };
+            return rejectEverything;
+        }
     }
 
     public static class MockFilterScript implements FilterScript.LeafFactory {