Browse Source

Move score script context from SearchScript to its own class (#30816)

Martijn van Groningen 7 years ago
parent
commit
ae2f021f1c

+ 43 - 0
modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScriptEngine.java

@@ -23,8 +23,10 @@ import org.apache.lucene.expressions.Expression;
 import org.apache.lucene.expressions.SimpleBindings;
 import org.apache.lucene.expressions.js.JavascriptCompiler;
 import org.apache.lucene.expressions.js.VariableContext;
+import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.queries.function.ValueSource;
 import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
+import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.SortField;
 import org.elasticsearch.SpecialPermission;
 import org.elasticsearch.common.Nullable;
@@ -39,12 +41,14 @@ import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.script.ClassPermission;
 import org.elasticsearch.script.ExecutableScript;
 import org.elasticsearch.script.FilterScript;
+import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.ScriptContext;
 import org.elasticsearch.script.ScriptEngine;
 import org.elasticsearch.script.ScriptException;
 import org.elasticsearch.script.SearchScript;
 import org.elasticsearch.search.lookup.SearchLookup;
 
+import java.io.IOException;
 import java.security.AccessControlContext;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
@@ -111,6 +115,9 @@ public class ExpressionScriptEngine extends AbstractComponent implements ScriptE
         } else if (context.instanceClazz.equals(FilterScript.class)) {
             FilterScript.Factory factory = (p, lookup) -> newFilterScript(expr, lookup, p);
             return context.factoryClazz.cast(factory);
+        } else if (context.instanceClazz.equals(ScoreScript.class)) {
+            ScoreScript.Factory factory = (p, lookup) -> newScoreScript(expr, lookup, p);
+            return context.factoryClazz.cast(factory);
         }
         throw new IllegalArgumentException("expression engine does not know how to handle script context [" + context.name + "]");
     }
@@ -260,6 +267,42 @@ public class ExpressionScriptEngine extends AbstractComponent implements ScriptE
             };
         };
     }
+    
+    private ScoreScript.LeafFactory newScoreScript(Expression expr, SearchLookup lookup, @Nullable Map<String, Object> vars) {
+        SearchScript.LeafFactory searchLeafFactory = newSearchScript(expr, lookup, vars);
+        return new ScoreScript.LeafFactory() {
+            @Override
+            public boolean needs_score() {
+                return searchLeafFactory.needs_score();
+            }
+
+            @Override
+            public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
+                SearchScript script = searchLeafFactory.newInstance(ctx);
+                return new ScoreScript(vars, lookup, ctx) {
+                    @Override
+                    public double execute() {
+                        return script.runAsDouble();
+                    }
+    
+                    @Override
+                    public void setDocument(int docid) {
+                        script.setDocument(docid);
+                    }
+    
+                    @Override
+                    public void setScorer(Scorer scorer) {
+                        script.setScorer(scorer);
+                    }
+    
+                    @Override
+                    public double get_score() {
+                        return script.getScore();
+                    }
+                };
+            }
+        };
+    }
 
     /**
      * converts a ParseException at compile-time or link-time to a ScriptException

+ 8 - 8
plugins/examples/script-expert-scoring/src/main/java/org/elasticsearch/example/expertscript/ExpertScriptPlugin.java

@@ -30,9 +30,9 @@ import org.apache.lucene.index.Term;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.ScriptPlugin;
+import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.ScriptContext;
 import org.elasticsearch.script.ScriptEngine;
-import org.elasticsearch.script.SearchScript;
 
 /**
  * An example script plugin that adds a {@link ScriptEngine} implementing expert scoring.
@@ -54,12 +54,12 @@ public class ExpertScriptPlugin extends Plugin implements ScriptPlugin {
 
         @Override
         public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
-            if (context.equals(SearchScript.SCRIPT_SCORE_CONTEXT) == false) {
+            if (context.equals(ScoreScript.CONTEXT) == false) {
                 throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
             }
             // we use the script "source" as the script identifier
             if ("pure_df".equals(scriptSource)) {
-                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
+                ScoreScript.Factory factory = (p, lookup) -> new ScoreScript.LeafFactory() {
                     final String field;
                     final String term;
                     {
@@ -74,18 +74,18 @@ public class ExpertScriptPlugin extends Plugin implements ScriptPlugin {
                     }
 
                     @Override
-                    public SearchScript newInstance(LeafReaderContext context) throws IOException {
+                    public ScoreScript newInstance(LeafReaderContext context) throws IOException {
                         PostingsEnum postings = context.reader().postings(new Term(field, term));
                         if (postings == null) {
                             // the field and/or term don't exist in this segment, so always return 0
-                            return new SearchScript(p, lookup, context) {
+                            return new ScoreScript(p, lookup, context) {
                                 @Override
-                                public double runAsDouble() {
+                                public double execute() {
                                     return 0.0d;
                                 }
                             };
                         }
-                        return new SearchScript(p, lookup, context) {
+                        return new ScoreScript(p, lookup, context) {
                             int currentDocid = -1;
                             @Override
                             public void setDocument(int docid) {
@@ -100,7 +100,7 @@ public class ExpertScriptPlugin extends Plugin implements ScriptPlugin {
                                 currentDocid = docid;
                             }
                             @Override
-                            public double runAsDouble() {
+                            public double execute() {
                                 if (postings.docID() != currentDocid) {
                                     // advance moved past the current doc, so this doc has no occurrences of the term
                                     return 0.0d;

+ 5 - 5
server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java

@@ -24,8 +24,8 @@ import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.Scorer;
 import org.elasticsearch.script.ExplainableSearchScript;
+import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.Script;
-import org.elasticsearch.script.SearchScript;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -58,10 +58,10 @@ public class ScriptScoreFunction extends ScoreFunction {
 
     private final Script sScript;
 
-    private final SearchScript.LeafFactory script;
+    private final ScoreScript.LeafFactory script;
 
 
-    public ScriptScoreFunction(Script sScript, SearchScript.LeafFactory script) {
+    public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) {
         super(CombineFunction.REPLACE);
         this.sScript = sScript;
         this.script = script;
@@ -69,7 +69,7 @@ public class ScriptScoreFunction extends ScoreFunction {
 
     @Override
     public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
-        final SearchScript leafScript = script.newInstance(ctx);
+        final ScoreScript leafScript = script.newInstance(ctx);
         final CannedScorer scorer = new CannedScorer();
         leafScript.setScorer(scorer);
         return new LeafScoreFunction() {
@@ -78,7 +78,7 @@ public class ScriptScoreFunction extends ScoreFunction {
                 leafScript.setDocument(docId);
                 scorer.docid = docId;
                 scorer.score = subQueryScore;
-                double result = leafScript.runAsDouble();
+                double result = leafScript.execute();
                 return result;
             }
 

+ 3 - 2
server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java

@@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.query.QueryShardException;
+import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.script.SearchScript;
 
@@ -92,8 +93,8 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
     @Override
     protected ScoreFunction doToFunction(QueryShardContext context) {
         try {
-            SearchScript.Factory factory = context.getScriptService().compile(script, SearchScript.SCRIPT_SCORE_CONTEXT);
-            SearchScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
+            ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT);
+            ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
             return new ScriptScoreFunction(script, searchScript);
         } catch (Exception e) {
             throw new QueryShardException(context, "script_score: the script could not be loaded", e);

+ 102 - 0
server/src/main/java/org/elasticsearch/script/ScoreScript.java

@@ -0,0 +1,102 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.script;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.Scorer;
+import org.elasticsearch.index.fielddata.ScriptDocValues;
+import org.elasticsearch.search.lookup.LeafSearchLookup;
+import org.elasticsearch.search.lookup.SearchLookup;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Map;
+import java.util.function.DoubleSupplier;
+
+/**
+ * A script used for adjusting the score on a per document basis.
+ */
+public abstract class ScoreScript {
+    
+    public static final String[] PARAMETERS = new String[]{};
+    
+    /** The generic runtime parameters for the script. */
+    private final Map<String, Object> params;
+    
+    /** A leaf lookup for the bound segment this script will operate on. */
+    private final LeafSearchLookup leafLookup;
+    
+    private DoubleSupplier scoreSupplier = () -> 0.0;
+    
+    public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
+        this.params = params;
+        this.leafLookup = lookup.getLeafSearchLookup(leafContext);
+    }
+    
+    public abstract double execute();
+    
+    /** Return the parameters for this script. */
+    public Map<String, Object> getParams() {
+        return params;
+    }
+    
+    /** The doc lookup for the Lucene segment this script was created for. */
+    public final Map<String, ScriptDocValues<?>> getDoc() {
+        return leafLookup.doc();
+    }
+    
+    /** Set the current document to run the script on next. */
+    public void setDocument(int docid) {
+        leafLookup.setDocument(docid);
+    }
+    
+    public void setScorer(Scorer scorer) {
+        this.scoreSupplier = () -> {
+            try {
+                return scorer.score();
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+        };
+    }
+    
+    public double get_score() {
+        return scoreSupplier.getAsDouble();
+    }
+    
+    /** A factory to construct {@link ScoreScript} instances. */
+    public interface LeafFactory {
+    
+        /**
+         * Return {@code true} if the script needs {@code _score} calculated, or {@code false} otherwise.
+         */
+        boolean needs_score();
+        
+        ScoreScript newInstance(LeafReaderContext ctx) throws IOException;
+    }
+    
+    /** A factory to construct stateful {@link ScoreScript} factories for a specific index. */
+    public interface Factory {
+        
+        ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup);
+        
+    }
+    
+    public static final ScriptContext<ScoreScript.Factory> CONTEXT = new ScriptContext<>("score", ScoreScript.Factory.class);
+}

+ 1 - 1
server/src/main/java/org/elasticsearch/script/ScriptModule.java

@@ -42,7 +42,7 @@ public class ScriptModule {
         CORE_CONTEXTS = Stream.of(
             SearchScript.CONTEXT,
             SearchScript.AGGS_CONTEXT,
-            SearchScript.SCRIPT_SCORE_CONTEXT,
+            ScoreScript.CONTEXT,
             SearchScript.SCRIPT_SORT_CONTEXT,
             SearchScript.TERMS_SET_QUERY_CONTEXT,
             ExecutableScript.CONTEXT,

+ 0 - 2
server/src/main/java/org/elasticsearch/script/SearchScript.java

@@ -162,8 +162,6 @@ public abstract class SearchScript implements ScorerAware, ExecutableScript {
     public static final ScriptContext<Factory> AGGS_CONTEXT = new ScriptContext<>("aggs", Factory.class);
     // Can return a double. (For ScriptSortType#NUMBER only, for ScriptSortType#STRING normal CONTEXT should be used)
     public static final ScriptContext<Factory> SCRIPT_SORT_CONTEXT = new ScriptContext<>("sort", Factory.class);
-    // Can return a float
-    public static final ScriptContext<Factory> SCRIPT_SCORE_CONTEXT = new ScriptContext<>("score", Factory.class);
     // Can return a long
     public static final ScriptContext<Factory> TERMS_SET_QUERY_CONTEXT = new ScriptContext<>("terms_set", Factory.class);
 }

+ 16 - 22
server/src/test/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java

@@ -30,14 +30,14 @@ import org.elasticsearch.index.fielddata.ScriptDocValues;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.ScriptPlugin;
 import org.elasticsearch.script.ExplainableSearchScript;
+import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptContext;
 import org.elasticsearch.script.ScriptEngine;
 import org.elasticsearch.script.ScriptType;
-import org.elasticsearch.script.SearchScript;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
-import org.elasticsearch.search.lookup.LeafDocLookup;
+import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
 import org.elasticsearch.test.ESIntegTestCase.Scope;
@@ -76,16 +76,17 @@ public class ExplainableScriptIT extends ESIntegTestCase {
                 @Override
                 public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
                     assert scriptSource.equals("explainable_script");
-                    assert context == SearchScript.SCRIPT_SCORE_CONTEXT;
-                    SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
-                        @Override
-                        public SearchScript newInstance(LeafReaderContext context) throws IOException {
-                            return new MyScript(lookup.doc().getLeafDocLookup(context));
-                        }
+                    assert context == ScoreScript.CONTEXT;
+                    ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() {
                         @Override
                         public boolean needs_score() {
                             return false;
                         }
+
+                        @Override
+                        public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
+                            return new MyScript(params1, lookup, ctx);
+                        }
                     };
                     return context.factoryClazz.cast(factory);
                 }
@@ -93,28 +94,21 @@ public class ExplainableScriptIT extends ESIntegTestCase {
         }
     }
 
-    static class MyScript extends SearchScript implements ExplainableSearchScript {
-        LeafDocLookup docLookup;
+    static class MyScript extends ScoreScript implements ExplainableSearchScript {
 
-        MyScript(LeafDocLookup docLookup) {
-            super(null, null, null);
-            this.docLookup = docLookup;
+        MyScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
+            super(params, lookup, leafContext);
         }
-
-        @Override
-        public void setDocument(int doc) {
-            docLookup.setDocument(doc);
-        }
-
+    
         @Override
         public Explanation explain(Explanation subQueryScore) throws IOException {
             Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
-            return Explanation.match((float) (runAsDouble()), "This script returned " + runAsDouble(), scoreExp);
+            return Explanation.match((float) (execute()), "This script returned " + execute(), scoreExp);
         }
 
         @Override
-        public double runAsDouble() {
-            return ((Number) ((ScriptDocValues) docLookup.get("number_field")).getValues().get(0)).doubleValue();
+        public double execute() {
+            return ((Number) ((ScriptDocValues) getDoc().get("number_field")).getValues().get(0)).doubleValue();
         }
     }
 

+ 43 - 2
test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java

@@ -25,7 +25,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Doc;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Field;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Query;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Term;
-import org.elasticsearch.index.similarity.SimilarityService;
 import org.elasticsearch.search.aggregations.pipeline.movfn.MovingFunctionScript;
 import org.elasticsearch.search.aggregations.pipeline.movfn.MovingFunctions;
 import org.elasticsearch.search.lookup.LeafSearchLookup;
@@ -36,7 +35,6 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Function;
-import java.util.function.Predicate;
 
 import static java.util.Collections.emptyMap;
 
@@ -114,6 +112,9 @@ public class MockScriptEngine implements ScriptEngine {
         } else if (context.instanceClazz.equals(MovingFunctionScript.class)) {
             MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript;
             return context.factoryClazz.cast(factory);
+        } else if (context.instanceClazz.equals(ScoreScript.class)) {
+            ScoreScript.Factory factory = new MockScoreScript(script);
+            return context.factoryClazz.cast(factory);
         }
         throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]");
     }
@@ -342,5 +343,45 @@ public class MockScriptEngine implements ScriptEngine {
             return MovingFunctions.unweightedAvg(values);
         }
     }
+    
+    public class MockScoreScript implements ScoreScript.Factory {
+    
+        private final Function<Map<String, Object>, Object> scripts;
+        
+        MockScoreScript(Function<Map<String, Object>, Object> scripts) {
+            this.scripts = scripts;
+        }
+        
+        @Override
+        public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
+            return new ScoreScript.LeafFactory() {
+                @Override
+                public boolean needs_score() {
+                    return true;
+                }
+    
+                @Override
+                public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
+                    Scorer[] scorerHolder = new Scorer[1];
+                    return new ScoreScript(params, lookup, ctx) {
+                        @Override
+                        public double execute() {
+                            Map<String, Object> vars = new HashMap<>(getParams());
+                            vars.put("doc", getDoc());
+                            if (scorerHolder[0] != null) {
+                                vars.put("_score", new ScoreAccessor(scorerHolder[0]));
+                            }
+                            return ((Number) scripts.apply(vars)).doubleValue();
+                        }
+    
+                        @Override
+                        public void setScorer(Scorer scorer) {
+                            scorerHolder[0] = scorer;
+                        }
+                    };
+                }
+            };
+        }
+    }
 
 }