瀏覽代碼

Script: Def encoding parser (#74840)

Parse and validate the def encoding string.

Broken out of e26fa4e
Stuart Tettemer 4 年之前
父節點
當前提交
e4df4d7205

+ 106 - 32
modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java

@@ -24,6 +24,7 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -219,11 +220,10 @@ public final class Def {
          int upTo = 1;
          for (int i = 1; i < numArguments; i++) {
              if (lambdaArgs.get(i - 1)) {
-                 String signature = (String) args[upTo++];
-                 int numCaptures = Integer.parseInt(signature.substring(signature.indexOf(',')+1));
-                 arity -= numCaptures;
+                 Def.Encoding signature = new Def.Encoding((String) args[upTo++]);
+                 arity -= signature.numCaptures;
                  // arity in painlessLookup does not include 'this' reference
-                 if (signature.charAt(1) == 't') {
+                 if (signature.needsInstance) {
                      arity--;
                  }
              }
@@ -251,17 +251,10 @@ public final class Def {
          for (int i = 1; i < numArguments; i++) {
              // its a functional reference, replace the argument with an impl
              if (lambdaArgs.get(i - 1)) {
-                 // decode signature of form 'type.call,2'
-                 String signature = (String) args[upTo++];
-                 int separator = signature.lastIndexOf('.');
-                 int separator2 = signature.indexOf(',');
-                 String type = signature.substring(2, separator);
-                 boolean needsScriptInstance = signature.charAt(1) == 't';
-                 String call = signature.substring(separator+1, separator2);
-                 int numCaptures = Integer.parseInt(signature.substring(separator2+1));
+                 Def.Encoding defEncoding = new Encoding((String) args[upTo++]);
                  MethodHandle filter;
-                 Class<?> interfaceType = method.typeParameters.get(i - 1 - replaced - (needsScriptInstance ? 1 : 0));
-                 if (signature.charAt(0) == 'S') {
+                 Class<?> interfaceType = method.typeParameters.get(i - 1 - replaced - (defEncoding.needsInstance ? 1 : 0));
+                 if (defEncoding.isStatic) {
                      // the implementation is strongly typed, now that we know the interface type,
                      // we have everything.
                      filter = lookupReferenceInternal(painlessLookup,
@@ -269,16 +262,16 @@ public final class Def {
                                                       constants,
                                                       methodHandlesLookup,
                                                       interfaceType,
-                                                      type,
-                                                      call,
-                                                      numCaptures,
-                                                      needsScriptInstance
+                                                      defEncoding.symbol,
+                                                      defEncoding.methodName,
+                                                      defEncoding.numCaptures,
+                                                      defEncoding.needsInstance
                      );
-                 } else if (signature.charAt(0) == 'D') {
+                } else {
                      // the interface type is now known, but we need to get the implementation.
                      // this is dynamically based on the receiver type (and cached separately, underneath
                      // this cache). It won't blow up since we never nest here (just references)
-                     Class<?>[] captures = new Class<?>[numCaptures];
+                     Class<?>[] captures = new Class<?>[defEncoding.numCaptures];
                      for (int capture = 0; capture < captures.length; capture++) {
                          captures[capture] = callSiteType.parameterType(i + 1 + capture);
                      }
@@ -287,20 +280,18 @@ public final class Def {
                                                               functions,
                                                               constants,
                                                               methodHandlesLookup,
-                                                              call,
+                                                              defEncoding.methodName,
                                                               nestedType,
                                                               0,
                                                               DefBootstrap.REFERENCE,
                                                               PainlessLookupUtility.typeToCanonicalTypeName(interfaceType));
                      filter = nested.dynamicInvoker();
-                 } else {
-                     throw new AssertionError();
-                 }
+                }
                  // the filter now ignores the signature (placeholder) on the stack
                  filter = MethodHandles.dropArguments(filter, 0, String.class);
-                 handle = MethodHandles.collectArguments(handle, i - (needsScriptInstance ? 1 : 0), filter);
-                 i += numCaptures;
-                 replaced += numCaptures;
+                 handle = MethodHandles.collectArguments(handle, i - (defEncoding.needsInstance ? 1 : 0), filter);
+                 i += defEncoding.numCaptures;
+                 replaced += defEncoding.numCaptures;
              }
          }
 
@@ -1278,29 +1269,112 @@ public final class Def {
         private ArrayIndexNormalizeHelper() {}
     }
 
+
     public static class Encoding {
         public final boolean isStatic;
         public final boolean needsInstance;
         public final String symbol;
         public final String methodName;
         public final int numCaptures;
+
+        /**
+         * Encoding is passed to invokedynamic to help DefBootstrap find the method.  invokedynamic can only take
+         * "Class, java.lang.invoke.MethodHandle, java.lang.invoke.MethodType, String, int, long, float, or double" types to
+         * help find the callsite, which is why this object is encoded as a String for indy.
+         * See: https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-6.html#jvms-6.5.invokedynamic
+         * */
         public final String encoding;
 
+        private static final String FORMAT = "[SD][tf]symbol.methodName,numCaptures";
+
         public Encoding(boolean isStatic, boolean needsInstance, String symbol, String methodName, int numCaptures) {
             this.isStatic = isStatic;
             this.needsInstance = needsInstance;
-            this.symbol = symbol;
-            this.methodName = methodName;
+            this.symbol = Objects.requireNonNull(symbol);
+            this.methodName = Objects.requireNonNull(methodName);
             this.numCaptures = numCaptures;
             this.encoding = (isStatic ? "S" : "D") + (needsInstance ? "t" : "f") +
-                            symbol + "." +
-                            methodName + "," +
-                            numCaptures;
+                    symbol + "." +
+                    methodName + "," +
+                    numCaptures;
+
+
+            if ("this".equals(symbol)) {
+                if (isStatic == false) {
+                    throw new IllegalArgumentException("Def.Encoding must be static if symbol is 'this', encoding [" + encoding + "]");
+                }
+            } else {
+                if (needsInstance) {
+                    throw new IllegalArgumentException("Def.Encoding symbol must be 'this', not [" + symbol + "] if needsInstance," +
+                        " encoding [" + encoding + "]");
+                }
+            }
+
+            if (methodName.isEmpty()) {
+                throw new IllegalArgumentException("methodName must be non-empty, encoding [" + encoding + "]");
+            }
+            if (numCaptures < 0) {
+                throw new IllegalArgumentException("numCaptures must be non-negative, not [" + numCaptures + "]," +
+                    " encoding: [" + encoding + "]");
+            }
+        }
+
+        // Parsing constructor, does minimal validation to avoid extra work during runtime
+        public Encoding(String encoding) {
+            this.encoding = Objects.requireNonNull(encoding);
+            if (encoding.length() < 6) {
+                throw new IllegalArgumentException("Encoding too short. Minimum 6, given [" + encoding.length() + "]," +
+                    " encoding: [" + encoding + "], format: " + FORMAT + "");
+            }
+
+            // 'S' or 'D'
+            this.isStatic = encoding.charAt(0) == 'S';
+
+            // 't' or 'f'
+            this.needsInstance = encoding.charAt(1) == 't';
+
+            int dotIndex = encoding.lastIndexOf('.');
+            if (dotIndex < 2) {
+                throw new IllegalArgumentException("Invalid symbol, could not find '.' at expected position after index 1, instead found" +
+                    " index [" + dotIndex + "], encoding: [" + encoding + "], format: " + FORMAT);
+            }
+
+            this.symbol = encoding.substring(2, dotIndex);
+
+            int commaIndex = encoding.indexOf(',');
+            if (commaIndex <= dotIndex) {
+                throw new IllegalArgumentException("Invalid symbol, could not find ',' at expected position after '.' at" +
+                    " [" + dotIndex + "], instead found index [" + commaIndex + "], encoding: [" + encoding + "], format: " + FORMAT);
+            }
+
+            this.methodName = encoding.substring(dotIndex + 1, commaIndex);
+
+            if (commaIndex == encoding.length() - 1) {
+                throw new IllegalArgumentException("Invalid symbol, could not find ',' at expected position, instead found" +
+                    " index [" + commaIndex + "], encoding: [" + encoding + "], format: " + FORMAT);
+            }
+
+            this.numCaptures = Integer.parseUnsignedInt(encoding.substring(commaIndex + 1));
         }
 
         @Override
         public String toString() {
             return encoding;
         }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if ((o instanceof Encoding) == false) return false;
+            Encoding encoding1 = (Encoding) o;
+            return isStatic == encoding1.isStatic && needsInstance == encoding1.needsInstance && numCaptures == encoding1.numCaptures
+                && Objects.equals(symbol, encoding1.symbol) && Objects.equals(methodName, encoding1.methodName)
+                && Objects.equals(encoding, encoding1.encoding);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(isStatic, needsInstance, symbol, methodName, numCaptures, encoding);
+        }
     }
 }

+ 62 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/DefEncodingTests.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.painless;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.startsWith;
+
+public class DefEncodingTests extends ESTestCase {
+
+    public void testParse() {
+        assertEquals(new Def.Encoding(true, false, "java.util.Comparator", "thenComparing", 1),
+            new Def.Encoding("Sfjava.util.Comparator.thenComparing,1"));
+
+        assertEquals(new Def.Encoding(false, false, "ft0", "augmentInjectMultiTimesX", 1),
+            new Def.Encoding("Dfft0.augmentInjectMultiTimesX,1"));
+
+        assertEquals(new Def.Encoding(false, false, "x", "concat", 1),
+            new Def.Encoding("Dfx.concat,1"));
+
+        assertEquals(new Def.Encoding(true, false, "java.lang.StringBuilder", "setLength", 1),
+            new Def.Encoding("Sfjava.lang.StringBuilder.setLength,1"));
+
+        assertEquals(new Def.Encoding(true, false, "org.elasticsearch.painless.FeatureTestObject", "overloadedStatic", 0),
+            new Def.Encoding("Sforg.elasticsearch.painless.FeatureTestObject.overloadedStatic,0"));
+
+        assertEquals(new Def.Encoding(true, false, "this", "lambda$synthetic$0", 1),
+            new Def.Encoding("Sfthis.lambda$synthetic$0,1"));
+
+        assertEquals(new Def.Encoding(true, true, "this", "lambda$synthetic$0", 2),
+            new Def.Encoding("Stthis.lambda$synthetic$0,2"));
+
+        assertEquals(new Def.Encoding(true, true, "this", "mycompare", 0),
+            new Def.Encoding("Stthis.mycompare,0"));
+    }
+
+    public void testValidate() {
+        IllegalArgumentException expected = expectThrows(IllegalArgumentException.class,
+            () -> new Def.Encoding(false, false, "this", "myMethod", 0));
+
+        assertThat(expected.getMessage(),
+            startsWith("Def.Encoding must be static if symbol is 'this', encoding [Dfthis.myMethod,0]"));
+
+        expected = expectThrows(IllegalArgumentException.class,
+            () -> new Def.Encoding(true, true, "org.elasticsearch.painless.FeatureTestObject", "overloadedStatic", 0));
+
+        assertThat(expected.getMessage(),
+            startsWith("Def.Encoding symbol must be 'this', not [org.elasticsearch.painless.FeatureTestObject] if needsInstance"));
+
+        expected = expectThrows(IllegalArgumentException.class,
+            () -> new Def.Encoding(false, false, "x", "", 1));
+
+        assertThat(expected.getMessage(),
+            startsWith("methodName must be non-empty, encoding [Dfx.,1]"));
+    }
+}