Browse Source

Support DoubleValues expression scripts in lang-expression (#89895)

This allows for custom DoubleValues scripts to be made
with lang-expressions. The change also adds a way for plugin
integration and unit tests to be able to create class loaders,
after this permission was removed in 8.0.
Nikola Grcevski 3 years ago
parent
commit
a23999e700

+ 5 - 0
docs/changelog/89895.yaml

@@ -0,0 +1,5 @@
+pr: 89895
+summary: Initial code to support binary expression scripts
+area: "Infra/Scripting"
+type: enhancement
+issues: []

+ 88 - 0
modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionDoubleValuesScript.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.expression;
+
+import org.apache.lucene.expressions.Bindings;
+import org.apache.lucene.expressions.Expression;
+import org.apache.lucene.search.DoubleValues;
+import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.Rescorer;
+import org.apache.lucene.search.SortField;
+import org.elasticsearch.script.DoubleValuesScript;
+
+import java.util.function.Function;
+
+/**
+ * A factory for a custom compiled {@link Expression} scripts
+ * <p>
+ * Instead of an execution result, we return a wrapper to an {@link Expression} object, which
+ * can be used for all supported double values operations.
+ */
+public class ExpressionDoubleValuesScript implements DoubleValuesScript.Factory {
+    private final Expression exprScript;
+
+    ExpressionDoubleValuesScript(Expression e) {
+        this.exprScript = e;
+    }
+
+    @Override
+    public DoubleValuesScript newInstance() {
+        return new DoubleValuesScript() {
+            @Override
+            public double execute() {
+                return exprScript.evaluate(new DoubleValues[0]);
+            }
+
+            @Override
+            public double evaluate(DoubleValues[] functionValues) {
+                return exprScript.evaluate(functionValues);
+            }
+
+            @Override
+            public DoubleValuesSource getDoubleValuesSource(Function<String, DoubleValuesSource> sourceProvider) {
+                return exprScript.getDoubleValuesSource(new Bindings() {
+                    @Override
+                    public DoubleValuesSource getDoubleValuesSource(String name) {
+                        return sourceProvider.apply(name);
+                    }
+                });
+            }
+
+            @Override
+            public SortField getSortField(Function<String, DoubleValuesSource> sourceProvider, boolean reverse) {
+                return exprScript.getSortField(new Bindings() {
+                    @Override
+                    public DoubleValuesSource getDoubleValuesSource(String name) {
+                        return sourceProvider.apply(name);
+                    }
+                }, reverse);
+            }
+
+            @Override
+            public Rescorer getRescorer(Function<String, DoubleValuesSource> sourceProvider) {
+                return exprScript.getRescorer(new Bindings() {
+                    @Override
+                    public DoubleValuesSource getDoubleValuesSource(String name) {
+                        return sourceProvider.apply(name);
+                    }
+                });
+            }
+
+            @Override
+            public String sourceText() {
+                return exprScript.sourceText;
+            }
+
+            @Override
+            public String[] variables() {
+                return exprScript.variables;
+            }
+        };
+    }
+}

+ 9 - 0
modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScriptEngine.java

@@ -24,6 +24,7 @@ import org.elasticsearch.script.AggregationScript;
 import org.elasticsearch.script.BucketAggregationScript;
 import org.elasticsearch.script.BucketAggregationSelectorScript;
 import org.elasticsearch.script.ClassPermission;
+import org.elasticsearch.script.DoubleValuesScript;
 import org.elasticsearch.script.FieldScript;
 import org.elasticsearch.script.FilterScript;
 import org.elasticsearch.script.NumberSortScript;
@@ -132,6 +133,14 @@ public class ExpressionScriptEngine implements ScriptEngine {
                 return newFieldScript(expr, lookup, params);
             }
 
+            @Override
+            public boolean isResultDeterministic() {
+                return true;
+            }
+        },
+
+        DoubleValuesScript.CONTEXT,
+        (Expression expr) -> new ExpressionDoubleValuesScript(expr) {
             @Override
             public boolean isResultDeterministic() {
                 return true;

+ 87 - 0
modules/lang-expression/src/test/java/org/elasticsearch/script/expression/ExpressionDoubleValuesScriptTests.java

@@ -0,0 +1,87 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.expression;
+
+import org.apache.lucene.expressions.SimpleBindings;
+import org.apache.lucene.search.DoubleValues;
+import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.SortField;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.script.DoubleValuesScript;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptException;
+import org.elasticsearch.script.ScriptModule;
+import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.script.ScriptType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.text.ParseException;
+import java.util.Collections;
+import java.util.Map;
+
+/**
+ * Tests {@link ExpressionDoubleValuesScript} through the {@link ScriptService}
+ */
+public class ExpressionDoubleValuesScriptTests extends ESTestCase {
+    private ExpressionScriptEngine engine;
+    private ScriptService scriptService;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+
+        engine = new ExpressionScriptEngine();
+        scriptService = new ScriptService(Settings.EMPTY, Map.of("expression", engine), ScriptModule.CORE_CONTEXTS, () -> 1L);
+    }
+
+    @SuppressWarnings("unchecked")
+    private DoubleValuesScript compile(String expression) {
+        var script = new Script(ScriptType.INLINE, "expression", expression, Collections.emptyMap());
+        return scriptService.compile(script, DoubleValuesScript.CONTEXT).newInstance();
+    }
+
+    public void testCompileError() {
+        ScriptException e = expectThrows(ScriptException.class, () -> compile("10 * log(10)"));
+        assertTrue(e.getCause() instanceof ParseException);
+        assertEquals("Invalid expression '10 * log(10)': Unrecognized function call (log).", e.getCause().getMessage());
+    }
+
+    public void testEvaluate() {
+        var expression = compile("10 * log10(10)");
+        assertEquals("10 * log10(10)", expression.sourceText());
+        assertEquals(10.0, expression.evaluate(new DoubleValues[0]), 0.00001);
+        assertEquals(10.0, expression.execute(), 0.00001);
+
+        expression = compile("20 * log10(a)");
+        assertEquals("20 * log10(a)", expression.sourceText());
+        assertEquals(20.0, expression.evaluate(new DoubleValues[] { DoubleValues.withDefault(DoubleValues.EMPTY, 10.0) }), 0.00001);
+    }
+
+    public void testDoubleValuesSource() throws IOException {
+        SimpleBindings bindings = new SimpleBindings();
+        bindings.add("popularity", DoubleValuesSource.constant(5));
+
+        var expression = compile("10 * log10(popularity)");
+        var doubleValues = expression.getDoubleValuesSource((name) -> bindings.getDoubleValuesSource(name));
+        assertEquals("expr(10 * log10(popularity))", doubleValues.toString());
+        var values = doubleValues.getValues(null, null);
+        assertTrue(values.advanceExact(0));
+        assertEquals(6, (int) values.doubleValue());
+
+        var sortField = expression.getSortField((name) -> bindings.getDoubleValuesSource(name), false);
+        assertEquals("expr(10 * log10(popularity))", sortField.getField());
+        assertEquals(SortField.Type.CUSTOM, sortField.getType());
+        assertFalse(sortField.getReverse());
+
+        var rescorer = expression.getRescorer((name) -> bindings.getDoubleValuesSource(name));
+        assertNotNull(rescorer);
+    }
+
+}

+ 46 - 0
server/src/main/java/org/elasticsearch/script/DoubleValuesScript.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.apache.lucene.search.DoubleValues;
+import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.Rescorer;
+import org.apache.lucene.search.SortField;
+
+import java.util.function.Function;
+
+/**
+ * A custom script that can be used for various DoubleValue Lucene operations.
+ */
+public abstract class DoubleValuesScript {
+
+    public DoubleValuesScript() {}
+
+    public abstract double execute();
+
+    public abstract double evaluate(DoubleValues[] functionValues);
+
+    public abstract DoubleValuesSource getDoubleValuesSource(Function<String, DoubleValuesSource> sourceProvider);
+
+    public abstract SortField getSortField(Function<String, DoubleValuesSource> sourceProvider, boolean reverse);
+
+    public abstract Rescorer getRescorer(Function<String, DoubleValuesSource> sourceProvider);
+
+    public abstract String sourceText();
+
+    public abstract String[] variables();
+
+    /** A factory to construct {@link DoubleValuesScript} instances. */
+    public interface Factory extends ScriptFactory {
+        DoubleValuesScript newInstance();
+    }
+
+    @SuppressWarnings("rawtypes")
+    public static final ScriptContext<Factory> CONTEXT = new ScriptContext<>("double_values", Factory.class);
+}

+ 2 - 1
server/src/main/java/org/elasticsearch/script/ScriptModule.java

@@ -68,7 +68,8 @@ public class ScriptModule {
                 ScriptedMetricAggContexts.MapScript.CONTEXT,
                 ScriptedMetricAggContexts.CombineScript.CONTEXT,
                 ScriptedMetricAggContexts.ReduceScript.CONTEXT,
-                IntervalFilterScript.CONTEXT
+                IntervalFilterScript.CONTEXT,
+                DoubleValuesScript.CONTEXT
             ),
             RUNTIME_FIELDS_CONTEXTS.stream()
         ).collect(Collectors.toMap(c -> c.name, Function.identity()));

+ 16 - 0
server/src/test/java/org/elasticsearch/plugins/PluginsServiceTests.java

@@ -33,6 +33,7 @@ import java.net.URLClassLoader;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.security.AccessControlException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -766,6 +767,21 @@ public class PluginsServiceTests extends ESTestCase {
         }
     }
 
+    public void testCanCreateAClassLoader() {
+        assertEquals(
+            "access denied (\"java.lang.RuntimePermission\" \"createClassLoader\")",
+            expectThrows(AccessControlException.class, () -> new Loader(this.getClass().getClassLoader())).getMessage()
+        );
+        var loader = PrivilegedOperations.supplierWithCreateClassLoader(() -> new Loader(this.getClass().getClassLoader()));
+        assertEquals(this.getClass().getClassLoader(), loader.getParent());
+    }
+
+    static final class Loader extends ClassLoader {
+        Loader(ClassLoader parent) {
+            super(parent);
+        }
+    }
+
     // Closes the URLClassLoaders of plugins loaded by the given plugin service.
     static void closePluginLoaders(PluginsService pluginService) {
         for (var lp : pluginService.plugins()) {

+ 9 - 0
test/framework/src/main/java/org/elasticsearch/test/PrivilegedOperations.java

@@ -59,6 +59,15 @@ public final class PrivilegedOperations {
         );
     }
 
+    public static <T> T supplierWithCreateClassLoader(Supplier<T> supplier) {
+        return AccessController.doPrivileged(
+            (PrivilegedAction<T>) () -> supplier.get(),
+            context,
+            new RuntimePermission("createClassLoader"),
+            new RuntimePermission("closeClassLoader")
+        );
+    }
+
     @SuppressForbidden(reason = "need to create file permission")
     private static FilePermission newAllFilesReadPermission() {
         return new FilePermission("<<ALL FILES>>", "read");