Преглед изворни кода

Catch up DLS with recent Lucene changes (#133966)

Adrien Grand пре 2 недеља
родитељ
комит
8e461545a9

+ 85 - 0
server/src/main/java/org/elasticsearch/common/lucene/search/BitsIterator.java

@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.common.lucene.search;
+
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
+
+import java.util.Objects;
+
+/**
+ * A {@link DocIdSetIterator} over set bits of a {@link Bits} instance.
+ */
+public final class BitsIterator extends DocIdSetIterator {
+
+    private static final int WINDOW_SIZE = 1024;
+
+    private final Bits bits;
+
+    private int doc = -1;
+    private final FixedBitSet bitSet;
+    private int from = 0;
+    private int to = 0;
+
+    public BitsIterator(Bits bits) {
+        this.bits = Objects.requireNonNull(bits);
+        // 1024 bits may sound heavy at first sight but it's only a long[16] under the hood
+        bitSet = new FixedBitSet(WINDOW_SIZE);
+    }
+
+    @Override
+    public int docID() {
+        return doc;
+    }
+
+    @Override
+    public int nextDoc() {
+        return advance(docID() + 1);
+    }
+
+    @Override
+    public int advance(int target) {
+        for (;;) {
+            if (target >= to) {
+                if (target >= bits.length()) {
+                    return doc = NO_MORE_DOCS;
+                }
+                refill(target);
+            }
+
+            int next = bitSet.nextSetBit(target - from);
+            if (next != NO_MORE_DOCS) {
+                return doc = from + next;
+            } else {
+                target = to;
+            }
+        }
+    }
+
+    private void refill(int target) {
+        assert target >= to;
+        from = target;
+        bitSet.set(0, WINDOW_SIZE);
+        if (bits.length() - from < WINDOW_SIZE) {
+            to = bits.length();
+            bitSet.clear(to - from, WINDOW_SIZE);
+        } else {
+            to = from + WINDOW_SIZE;
+        }
+        bits.applyMask(bitSet, from);
+    }
+
+    @Override
+    public long cost() {
+        // We have no better estimate
+        return bits.length();
+    }
+}

+ 0 - 127
server/src/main/java/org/elasticsearch/lucene/util/CombinedBitSet.java

@@ -1,127 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */
-
-package org.elasticsearch.lucene.util;
-
-import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.util.BitSet;
-import org.apache.lucene.util.Bits;
-
-/**
- * A {@link BitSet} implementation that combines two instances of {@link BitSet} and {@link Bits}
- * to provide a single merged view.
- */
-public final class CombinedBitSet extends BitSet implements Bits {
-    private final BitSet first;
-    private final Bits second;
-    private final int length;
-
-    public CombinedBitSet(BitSet first, Bits second) {
-        this.first = first;
-        this.second = second;
-        this.length = first.length();
-    }
-
-    public BitSet getFirst() {
-        return first;
-    }
-
-    /**
-     * This implementation is slow and requires to iterate over all bits to compute
-     * the intersection. Use {@link #approximateCardinality()} for
-     * a fast approximation.
-     */
-    @Override
-    public int cardinality() {
-        int card = 0;
-        for (int i = 0; i < length; i++) {
-            card += get(i) ? 1 : 0;
-        }
-        return card;
-    }
-
-    @Override
-    public int approximateCardinality() {
-        return first.cardinality();
-    }
-
-    @Override
-    public int prevSetBit(int index) {
-        assert index >= 0 && index < length : "index=" + index + ", numBits=" + length();
-        int prev = first.prevSetBit(index);
-        while (prev != -1 && second.get(prev) == false) {
-            if (prev == 0) {
-                return -1;
-            }
-            prev = first.prevSetBit(prev - 1);
-        }
-        return prev;
-    }
-
-    @Override
-    public int nextSetBit(int index) {
-        assert index >= 0 && index < length : "index=" + index + " numBits=" + length();
-        int next = first.nextSetBit(index);
-        while (next != DocIdSetIterator.NO_MORE_DOCS && second.get(next) == false) {
-            if (next == length() - 1) {
-                return DocIdSetIterator.NO_MORE_DOCS;
-            }
-            next = first.nextSetBit(next + 1);
-        }
-        return next;
-    }
-
-    @Override
-    public int nextSetBit(int index, int upperBound) {
-        assert index >= 0 && index < length : "index=" + index + " numBits=" + length();
-        int next = first.nextSetBit(index, upperBound);
-        while (next != DocIdSetIterator.NO_MORE_DOCS && second.get(next) == false) {
-            if (next == length() - 1) {
-                return DocIdSetIterator.NO_MORE_DOCS;
-            }
-            next = first.nextSetBit(next + 1, upperBound);
-        }
-        return next;
-    }
-
-    @Override
-    public long ramBytesUsed() {
-        return first.ramBytesUsed();
-    }
-
-    @Override
-    public boolean get(int index) {
-        return first.get(index) && second.get(index);
-    }
-
-    @Override
-    public int length() {
-        return length;
-    }
-
-    @Override
-    public void set(int i) {
-        throw new UnsupportedOperationException("not implemented");
-    }
-
-    @Override
-    public void clear(int i) {
-        throw new UnsupportedOperationException("not implemented");
-    }
-
-    @Override
-    public void clear(int startIndex, int endIndex) {
-        throw new UnsupportedOperationException("not implemented");
-    }
-
-    @Override
-    public boolean getAndSet(int i) {
-        throw new UnsupportedOperationException("not implemented");
-    }
-}

+ 47 - 0
server/src/main/java/org/elasticsearch/lucene/util/CombinedBits.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.lucene.util;
+
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
+
+/**
+ * A {@link Bits} implementation that combines two  {@link Bits} instances by and-ing them to provide a single merged view.
+ */
+public final class CombinedBits implements Bits {
+    private final Bits first;
+    private final Bits second;
+    private final int length;
+
+    public CombinedBits(Bits first, Bits second) {
+        if (first.length() != second.length()) {
+            throw new IllegalArgumentException("Provided bits have different lengths: " + first.length() + " != " + second.length());
+        }
+        this.first = first;
+        this.second = second;
+        this.length = first.length();
+    }
+
+    @Override
+    public boolean get(int index) {
+        return first.get(index) && second.get(index);
+    }
+
+    @Override
+    public int length() {
+        return length;
+    }
+
+    @Override
+    public void applyMask(FixedBitSet bitSet, int offset) {
+        first.applyMask(bitSet, offset);
+        second.applyMask(bitSet, offset);
+    }
+}

+ 9 - 26
server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java

@@ -32,12 +32,9 @@ import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.TermStatistics;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.search.similarities.Similarity;
-import org.apache.lucene.util.BitSet;
-import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.SparseFixedBitSet;
+import org.elasticsearch.common.lucene.search.BitsIterator;
 import org.elasticsearch.core.Releasable;
-import org.elasticsearch.lucene.util.CombinedBitSet;
 import org.elasticsearch.search.dfs.AggregatedDfs;
 import org.elasticsearch.search.profile.Timer;
 import org.elasticsearch.search.profile.query.ProfileWeight;
@@ -454,8 +451,11 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
             return;
         }
         Bits liveDocs = ctx.reader().getLiveDocs();
-        BitSet liveDocsBitSet = getSparseBitSetOrNull(liveDocs);
-        if (liveDocsBitSet == null) {
+        int numDocs = ctx.reader().numDocs();
+        // This threshold comes from the previous heuristic that checked whether the BitSet was a SparseFixedBitSet, which uses this
+        // threshold at creation time. But a higher threshold would likely perform better?
+        int threshold = ctx.reader().maxDoc() >> 7;
+        if (numDocs >= threshold) {
             BulkScorer bulkScorer = weight.bulkScorer(ctx);
             if (bulkScorer != null) {
                 if (cancellable.isEnabled()) {
@@ -475,7 +475,7 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
                 try {
                     intersectScorerAndBitSet(
                         scorer,
-                        liveDocsBitSet,
+                        liveDocs,
                         leafCollector,
                         this.cancellable.isEnabled() ? cancellable::checkCancelled : () -> {}
                     );
@@ -490,27 +490,10 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
         leafCollector.finish();
     }
 
-    private static BitSet getSparseBitSetOrNull(Bits liveDocs) {
-        if (liveDocs instanceof SparseFixedBitSet) {
-            return (BitSet) liveDocs;
-        } else if (liveDocs instanceof CombinedBitSet
-            // if the underlying role bitset is sparse
-            && ((CombinedBitSet) liveDocs).getFirst() instanceof SparseFixedBitSet) {
-                return (BitSet) liveDocs;
-            } else {
-                return null;
-            }
-
-    }
-
-    static void intersectScorerAndBitSet(Scorer scorer, BitSet acceptDocs, LeafCollector collector, Runnable checkCancelled)
+    static void intersectScorerAndBitSet(Scorer scorer, Bits acceptDocs, LeafCollector collector, Runnable checkCancelled)
         throws IOException {
         collector.setScorer(scorer);
-        // ConjunctionDISI uses the DocIdSetIterator#cost() to order the iterators, so if roleBits has the lowest cardinality it should
-        // be used first:
-        DocIdSetIterator iterator = ConjunctionUtils.intersectIterators(
-            Arrays.asList(new BitSetIterator(acceptDocs, acceptDocs.approximateCardinality()), scorer.iterator())
-        );
+        DocIdSetIterator iterator = ConjunctionUtils.intersectIterators(Arrays.asList(new BitsIterator(acceptDocs), scorer.iterator()));
         int seen = 0;
         checkCancelled.run();
         for (int docId = iterator.nextDoc(); docId < DocIdSetIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) {

+ 51 - 0
server/src/test/java/org/elasticsearch/common/lucene/search/BitsIteratorTests.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.common.lucene.search;
+
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
+import org.elasticsearch.test.ESTestCase;
+
+public class BitsIteratorTests extends ESTestCase {
+
+    public void testEmpty() {
+        Bits bits = new Bits.MatchNoBits(10_000);
+        BitsIterator iterator = new BitsIterator((bits));
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
+    }
+
+    public void testSingleBit() {
+        FixedBitSet bits = new FixedBitSet(10_000);
+        bits.set(5000);
+
+        BitsIterator iterator = new BitsIterator((bits));
+        assertEquals(5000, iterator.nextDoc());
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
+
+        iterator = new BitsIterator((bits));
+        assertEquals(5000, iterator.advance(5000));
+
+        iterator = new BitsIterator((bits));
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.advance(5001));
+    }
+
+    public void testEverySecondBit() {
+        FixedBitSet bits = new FixedBitSet(10_000);
+        for (int i = 0; i < bits.length(); i += 2) {
+            bits.set(i);
+        }
+        BitsIterator iterator = new BitsIterator((bits));
+        for (int i = 0; i < bits.length(); i += 2) {
+            assertEquals(i, iterator.nextDoc());
+        }
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
+    }
+}

+ 18 - 8
server/src/test/java/org/elasticsearch/lucene/util/CombinedBitSetTests.java → server/src/test/java/org/elasticsearch/lucene/util/CombinedBitsTests.java

@@ -11,11 +11,12 @@ package org.elasticsearch.lucene.util;
 
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.util.BitSet;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.SparseFixedBitSet;
 import org.elasticsearch.test.ESTestCase;
 
-public class CombinedBitSetTests extends ESTestCase {
+public class CombinedBitsTests extends ESTestCase {
     public void testEmpty() {
         for (float percent : new float[] { 0f, 0.1f, 0.5f, 0.9f, 1f }) {
             testCase(randomIntBetween(1, 10000), 0f, percent);
@@ -47,16 +48,11 @@ public class CombinedBitSetTests extends ESTestCase {
     private void testCase(int numBits, float percent1, float percent2) {
         BitSet first = randomSet(numBits, percent1);
         BitSet second = randomSet(numBits, percent2);
-        CombinedBitSet actual = new CombinedBitSet(first, second);
+        CombinedBits actual = new CombinedBits(first, second);
         FixedBitSet expected = new FixedBitSet(numBits);
         or(expected, first);
         and(expected, second);
-        assertEquals(expected.cardinality(), actual.cardinality());
         assertEquals(expected, actual, numBits);
-        for (int i = 0; i < numBits; ++i) {
-            assertEquals(expected.nextSetBit(i), actual.nextSetBit(i));
-            assertEquals(Integer.toString(i), expected.prevSetBit(i), actual.prevSetBit(i));
-        }
     }
 
     private void or(BitSet set1, BitSet set2) {
@@ -77,10 +73,24 @@ public class CombinedBitSetTests extends ESTestCase {
         }
     }
 
-    private void assertEquals(BitSet set1, BitSet set2, int maxDoc) {
+    private void assertEquals(Bits set1, Bits set2, int maxDoc) {
         for (int i = 0; i < maxDoc; ++i) {
             assertEquals("Different at " + i, set1.get(i), set2.get(i));
         }
+
+        FixedBitSet bitSet1 = new FixedBitSet(100);
+        FixedBitSet bitSet2 = new FixedBitSet(100);
+        for (int from = 0; from < maxDoc; from += bitSet1.length()) {
+            bitSet1.set(0, bitSet1.length());
+            bitSet2.set(0, bitSet1.length());
+            if (from + bitSet1.length() > maxDoc) {
+                bitSet1.clear(maxDoc - from, bitSet1.length());
+                bitSet2.clear(maxDoc - from, bitSet1.length());
+            }
+            set1.applyMask(bitSet1, from);
+            set2.applyMask(bitSet2, from);
+            assertEquals(bitSet1, bitSet2);
+        }
     }
 
     private BitSet randomSet(int numBits, float percentSet) {

+ 6 - 6
server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java

@@ -66,7 +66,7 @@ import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
 import org.elasticsearch.index.shard.ShardId;
-import org.elasticsearch.lucene.util.CombinedBitSet;
+import org.elasticsearch.lucene.util.CombinedBits;
 import org.elasticsearch.lucene.util.MatchAllBitSet;
 import org.elasticsearch.search.aggregations.BucketCollector;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
@@ -131,7 +131,7 @@ public class ContextIndexSearcherTests extends ESTestCase {
 
         LeafReaderContext leaf = searcher.getIndexReader().leaves().get(0);
 
-        CombinedBitSet bitSet = new CombinedBitSet(query(leaf, "field1", "value1"), leaf.reader().getLiveDocs());
+        CombinedBits bitSet = new CombinedBits(query(leaf, "field1", "value1"), leaf.reader().getLiveDocs());
         LeafCollector leafCollector = new LeafBucketCollector() {
             Scorable scorer;
 
@@ -148,7 +148,7 @@ public class ContextIndexSearcherTests extends ESTestCase {
         };
         intersectScorerAndBitSet(weight.scorer(leaf), bitSet, leafCollector, () -> {});
 
-        bitSet = new CombinedBitSet(query(leaf, "field1", "value2"), leaf.reader().getLiveDocs());
+        bitSet = new CombinedBits(query(leaf, "field1", "value2"), leaf.reader().getLiveDocs());
         leafCollector = new LeafBucketCollector() {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -157,7 +157,7 @@ public class ContextIndexSearcherTests extends ESTestCase {
         };
         intersectScorerAndBitSet(weight.scorer(leaf), bitSet, leafCollector, () -> {});
 
-        bitSet = new CombinedBitSet(query(leaf, "field1", "value3"), leaf.reader().getLiveDocs());
+        bitSet = new CombinedBits(query(leaf, "field1", "value3"), leaf.reader().getLiveDocs());
         leafCollector = new LeafBucketCollector() {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -166,7 +166,7 @@ public class ContextIndexSearcherTests extends ESTestCase {
         };
         intersectScorerAndBitSet(weight.scorer(leaf), bitSet, leafCollector, () -> {});
 
-        bitSet = new CombinedBitSet(query(leaf, "field1", "value4"), leaf.reader().getLiveDocs());
+        bitSet = new CombinedBits(query(leaf, "field1", "value4"), leaf.reader().getLiveDocs());
         leafCollector = new LeafBucketCollector() {
             @Override
             public void collect(int doc, long bucket) throws IOException {
@@ -715,7 +715,7 @@ public class ContextIndexSearcherTests extends ESTestCase {
                 return roleQueryBits;
             } else {
                 // apply deletes when needed:
-                return new CombinedBitSet(roleQueryBits, actualLiveDocs);
+                return new CombinedBits(roleQueryBits, actualLiveDocs);
             }
         }
 

+ 17 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/accesscontrol/DocumentSubsetReader.java

@@ -17,13 +17,14 @@ import org.apache.lucene.store.AlreadyClosedException;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.common.cache.Cache;
 import org.elasticsearch.common.cache.CacheBuilder;
 import org.elasticsearch.common.logging.LoggerMessageFormat;
 import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
-import org.elasticsearch.lucene.util.CombinedBitSet;
+import org.elasticsearch.lucene.util.CombinedBits;
 import org.elasticsearch.lucene.util.MatchAllBitSet;
 import org.elasticsearch.transport.Transports;
 
@@ -66,14 +67,22 @@ public final class DocumentSubsetReader extends SequentialStoredFieldsLeafReader
             // slow
             return roleQueryBits.cardinality();
         } else {
-            // very slow, but necessary in order to be correct
+            // slower, but necessary in order to be correct
             int numDocs = 0;
-            DocIdSetIterator it = new BitSetIterator(roleQueryBits, 0L); // we don't use the cost
+            // Temporary bit set, just to do the counting
+            FixedBitSet bitSet = new FixedBitSet(1024);
+            DocIdSetIterator roleBitsIterator = new BitSetIterator(roleQueryBits, 0L); // we don't use the cost
             try {
-                for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
-                    if (liveDocs.get(doc)) {
-                        numDocs++;
-                    }
+                for (int from = roleBitsIterator.nextDoc(); from != DocIdSetIterator.NO_MORE_DOCS; from = roleBitsIterator.docID()) {
+                    bitSet.clear();
+
+                    // OR role bits into `bitSet`
+                    int upTo = (int) Math.min((long) from + bitSet.length(), Integer.MAX_VALUE);
+                    roleBitsIterator.intoBitSet(upTo, bitSet, from);
+
+                    // And then AND live doc bits into `bitSet`
+                    liveDocs.applyMask(bitSet, from);
+                    numDocs += bitSet.cardinality();
                 }
                 return numDocs;
             } catch (IOException e) {
@@ -204,7 +213,7 @@ public final class DocumentSubsetReader extends SequentialStoredFieldsLeafReader
             return roleQueryBits;
         } else {
             // apply deletes when needed:
-            return new CombinedBitSet(roleQueryBits, actualLiveDocs);
+            return new CombinedBits(roleQueryBits, actualLiveDocs);
         }
     }