diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/FilterPartitionBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/FilterPartitionBenchmark.java index 6908b72909dc..405a33a21842 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/FilterPartitionBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/FilterPartitionBenchmark.java @@ -66,6 +66,7 @@ import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.filter.OrFilter; import org.apache.druid.segment.filter.SelectorFilter; +import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException; import org.apache.druid.segment.generator.DataGenerator; import org.apache.druid.segment.generator.GeneratorBasicSchemas; import org.apache.druid.segment.generator.GeneratorSchemaInfo; @@ -370,7 +371,7 @@ public void readOrFilter(Blackhole blackhole) @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MICROSECONDS) - public void readOrFilterCNF(Blackhole blackhole) + public void readOrFilterCNF(Blackhole blackhole) throws CNFFilterExplosionException { Filter filter = new NoBitmapSelectorFilter("dimSequential", "199"); Filter filter2 = new AndFilter(Arrays.asList(new SelectorFilter("dimMultivalEnumerated2", "Corundum"), new NoBitmapSelectorFilter("dimMultivalEnumerated", "Bar"))); @@ -421,7 +422,7 @@ public void readComplexOrFilter(Blackhole blackhole) @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MICROSECONDS) - public void readComplexOrFilterCNF(Blackhole blackhole) + public void readComplexOrFilterCNF(Blackhole blackhole) throws CNFFilterExplosionException { DimFilter dimFilter1 = new OrDimFilter(Arrays.asList( new SelectorDimFilter("dimSequential", "199", null), diff --git a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java index abecb0d48a5b..98a2c3f4f6ba 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java @@ -45,6 +45,7 @@ import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.data.CloseableIndexed; import org.apache.druid.segment.data.Indexed; +import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException; import org.apache.druid.segment.filter.cnf.CalciteCnfHelper; import org.apache.druid.segment.filter.cnf.HiveCnfHelper; import org.apache.druid.segment.join.filter.AllNullColumnSelectorFactory; @@ -433,10 +434,15 @@ public static Filter convertToCNFFromQueryContext(Query query, @Nullable Filter return null; } boolean useCNF = query.getContextBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF); - return useCNF ? Filters.toCnf(filter) : filter; + try { + return useCNF ? Filters.toCnf(filter) : filter; + } + catch (CNFFilterExplosionException cnfFilterExplosionException) { + return filter; // cannot convert to CNF, return the filter as is + } } - public static Filter toCnf(Filter current) + public static Filter toCnf(Filter current) throws CNFFilterExplosionException { // Push down NOT filters to leaves if possible to remove NOT on NOT filters and reduce hierarchy. // ex) ~(a OR ~b) => ~a AND b @@ -578,7 +584,7 @@ public static Optional maybeOr(final List filters) * * @return The normalized or clauses for the provided filter. */ - public static List toNormalizedOrClauses(Filter filter) + public static List toNormalizedOrClauses(Filter filter) throws CNFFilterExplosionException { Filter normalizedFilter = Filters.toCnf(filter); diff --git a/processing/src/main/java/org/apache/druid/segment/filter/cnf/CNFFilterExplosionException.java b/processing/src/main/java/org/apache/druid/segment/filter/cnf/CNFFilterExplosionException.java new file mode 100644 index 000000000000..45014ffd181d --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/filter/cnf/CNFFilterExplosionException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.filter.cnf; + +import org.apache.druid.java.util.common.StringUtils; + +public class CNFFilterExplosionException extends Exception +{ + public CNFFilterExplosionException(String formatText, Object... arguments) + { + super(StringUtils.nonStrictFormat(formatText, arguments)); + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/filter/cnf/HiveCnfHelper.java b/processing/src/main/java/org/apache/druid/segment/filter/cnf/HiveCnfHelper.java index 5c5ca1010c44..9e2992e0f5ea 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/cnf/HiveCnfHelper.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/cnf/HiveCnfHelper.java @@ -19,6 +19,7 @@ package org.apache.druid.segment.filter.cnf; +import org.apache.druid.java.util.common.NonnullPair; import org.apache.druid.query.filter.BooleanFilter; import org.apache.druid.query.filter.Filter; import org.apache.druid.segment.filter.AndFilter; @@ -37,6 +38,8 @@ */ public class HiveCnfHelper { + private static final int CNF_MAX_FILTER_THRESHOLD = 10_000; + public static Filter pushDownNot(Filter current) { if (current instanceof NotFilter) { @@ -78,17 +81,37 @@ public static Filter pushDownNot(Filter current) return current; } - public static Filter convertToCnf(Filter current) + public static Filter convertToCnf(Filter current) throws CNFFilterExplosionException + { + return convertToCnfWithLimit(current, CNF_MAX_FILTER_THRESHOLD).lhs; + } + + /** + * Converts a filter to CNF form with a limit on filter count + * @param maxCNFFilterLimit the maximum number of filters allowed in CNF conversion + * @return a pair of the CNF converted filter and the new remaining filter limit + * @throws CNFFilterExplosionException is thrown if the filters in CNF representation go beyond maxCNFFilterLimit + */ + private static NonnullPair convertToCnfWithLimit( + Filter current, + int maxCNFFilterLimit + ) throws CNFFilterExplosionException { if (current instanceof NotFilter) { - return new NotFilter(convertToCnf(((NotFilter) current).getBaseFilter())); + NonnullPair result = convertToCnfWithLimit(((NotFilter) current).getBaseFilter(), maxCNFFilterLimit); + return new NonnullPair<>(new NotFilter(result.lhs), result.rhs); } if (current instanceof AndFilter) { List children = new ArrayList<>(); for (Filter child : ((AndFilter) current).getFilters()) { - children.add(convertToCnf(child)); + NonnullPair result = convertToCnfWithLimit(child, maxCNFFilterLimit); + children.add(result.lhs); + maxCNFFilterLimit = result.rhs; + if (maxCNFFilterLimit < 0) { + throw new CNFFilterExplosionException("Exceeded maximum allowed filters for CNF (conjunctive normal form) conversion"); + } } - return Filters.and(children); + return new NonnullPair<>(Filters.and(children), maxCNFFilterLimit); } if (current instanceof OrFilter) { // a list of leaves that weren't under AND expressions @@ -107,11 +130,11 @@ public static Filter convertToCnf(Filter current) } if (!andList.isEmpty()) { List result = new ArrayList<>(); - generateAllCombinations(result, andList, nonAndList); - return Filters.and(result); + generateAllCombinations(result, andList, nonAndList, maxCNFFilterLimit); + return new NonnullPair<>(Filters.and(result), maxCNFFilterLimit - result.size()); } } - return current; + return new NonnullPair<>(current, maxCNFFilterLimit); } public static Filter flatten(Filter root) @@ -158,8 +181,9 @@ public static Filter flatten(Filter root) private static void generateAllCombinations( List result, List andList, - List nonAndList - ) + List nonAndList, + int maxAllowedFilters + ) throws CNFFilterExplosionException { List children = new ArrayList<>(((AndFilter) andList.get(0)).getFilters()); if (result.isEmpty()) { @@ -168,6 +192,9 @@ private static void generateAllCombinations( a.add(child); // Result must receive an actual OrFilter, so wrap if Filters.or managed to un-OR it. result.add(idempotentOr(Filters.or(a))); + if (result.size() > maxAllowedFilters) { + throw new CNFFilterExplosionException("Exceeded maximum allowed filters for CNF (conjunctive normal form) conversion"); + } } } else { List work = new ArrayList<>(result); @@ -178,11 +205,14 @@ private static void generateAllCombinations( a.add(child); // Result must receive an actual OrFilter. result.add(idempotentOr(Filters.or(a))); + if (result.size() > maxAllowedFilters) { + throw new CNFFilterExplosionException("Exceeded maximum allowed filters for CNF (conjunctive normal form) conversion"); + } } } } if (andList.size() > 1) { - generateAllCombinations(result, andList.subList(1, andList.size()), nonAndList); + generateAllCombinations(result, andList.subList(1, andList.size()), nonAndList, maxAllowedFilters); } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java index 2420369d9e8d..5db5bb1b7c7b 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java @@ -31,6 +31,7 @@ import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.filter.OrFilter; import org.apache.druid.segment.filter.SelectorFilter; +import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import javax.annotation.Nullable; @@ -101,7 +102,13 @@ public static JoinFilterPreAnalysis computeJoinFilterPreAnalysis(final JoinFilte return preAnalysisBuilder.build(); } - List normalizedOrClauses = Filters.toNormalizedOrClauses(key.getFilter()); + List normalizedOrClauses; + try { + normalizedOrClauses = Filters.toNormalizedOrClauses(key.getFilter()); + } + catch (CNFFilterExplosionException cnfFilterExplosionException) { + return preAnalysisBuilder.build(); // disable the filter pushdown and rewrite optimization + } List normalizedBaseTableClauses = new ArrayList<>(); List normalizedJoinTableClauses = new ArrayList<>(); diff --git a/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java b/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java index c74c6eff2f92..ab4cca94fbb4 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java @@ -81,6 +81,7 @@ import org.apache.druid.segment.data.ConciseBitmapSerdeFactory; import org.apache.druid.segment.data.IndexedInts; import org.apache.druid.segment.data.RoaringBitmapSerdeFactory; +import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException; import org.apache.druid.segment.incremental.IncrementalIndex; import org.apache.druid.segment.incremental.IncrementalIndexSchema; import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter; @@ -367,7 +368,12 @@ private Filter makeFilter(final DimFilter dimFilter) final DimFilter maybeOptimized = optimize ? dimFilter.optimize() : dimFilter; final Filter filter = maybeOptimized.toFilter(); - return cnf ? Filters.toCnf(filter) : filter; + try { + return cnf ? Filters.toCnf(filter) : filter; + } + catch (CNFFilterExplosionException cnfFilterExplosionException) { + throw new RuntimeException(cnfFilterExplosionException); + } } private DimFilter maybeOptimize(final DimFilter dimFilter) diff --git a/processing/src/test/java/org/apache/druid/segment/filter/FilterCnfConversionTest.java b/processing/src/test/java/org/apache/druid/segment/filter/FilterCnfConversionTest.java index e9b5304b25db..5867920b4bdc 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/FilterCnfConversionTest.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/FilterCnfConversionTest.java @@ -29,6 +29,7 @@ import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.DimensionSelector; import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException; import org.apache.druid.segment.filter.cnf.CalciteCnfHelper; import org.apache.druid.segment.filter.cnf.HiveCnfHelper; import org.junit.Assert; @@ -113,7 +114,7 @@ public void testFlattenUnflattenable() } @Test - public void testToCnfWithMuchReducibleFilter() + public void testToCnfWithMuchReducibleFilter() throws CNFFilterExplosionException { final Filter muchReducible = FilterTestUtils.and( // should be flattened @@ -158,7 +159,7 @@ public void testToCnfWithMuchReducibleFilter() } @Test - public void testToNormalizedOrClausesWithMuchReducibleFilter() + public void testToNormalizedOrClausesWithMuchReducibleFilter() throws CNFFilterExplosionException { final Filter muchReducible = FilterTestUtils.and( // should be flattened @@ -203,7 +204,7 @@ public void testToNormalizedOrClausesWithMuchReducibleFilter() } @Test - public void testToCnfWithComplexFilterIncludingNotAndOr() + public void testToCnfWithComplexFilterIncludingNotAndOr() throws CNFFilterExplosionException { final Filter filter = FilterTestUtils.and( FilterTestUtils.or( @@ -307,7 +308,7 @@ public void testToCnfWithComplexFilterIncludingNotAndOr() } @Test - public void testToNormalizedOrClausesWithComplexFilterIncludingNotAndOr() + public void testToNormalizedOrClausesWithComplexFilterIncludingNotAndOr() throws CNFFilterExplosionException { final Filter filter = FilterTestUtils.and( FilterTestUtils.or( @@ -411,7 +412,7 @@ public void testToNormalizedOrClausesWithComplexFilterIncludingNotAndOr() } @Test - public void testToCnfCollapsibleBigFilter() + public void testToCnfCollapsibleBigFilter() throws CNFFilterExplosionException { List ands = new ArrayList<>(); List ors = new ArrayList<>(); @@ -477,7 +478,7 @@ public void testPullNotPullableFilter() } @Test - public void testToCnfFilterThatPullCannotConvertToCnfProperly() + public void testToCnfFilterThatPullCannotConvertToCnfProperly() throws CNFFilterExplosionException { final Filter filter = FilterTestUtils.or( FilterTestUtils.and( @@ -507,7 +508,7 @@ public void testToCnfFilterThatPullCannotConvertToCnfProperly() } @Test - public void testToNormalizedOrClausesNonAndFilterShouldReturnSingleton() + public void testToNormalizedOrClausesNonAndFilterShouldReturnSingleton() throws CNFFilterExplosionException { Filter filter = FilterTestUtils.or( FilterTestUtils.selector("col1", "val1"), @@ -528,6 +529,78 @@ public void testTrueFalseFilterRequiredColumnRewrite() Assert.assertEquals(FalseFilter.instance(), FalseFilter.instance().rewriteRequiredColumns(ImmutableMap.of())); } + @Test(expected = CNFFilterExplosionException.class) + public void testExceptionOnCNFFilterExplosion() throws CNFFilterExplosionException + { + Filter filter = FilterTestUtils.or( + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val1"), + FilterTestUtils.selector("col2", "val2") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val3"), + FilterTestUtils.selector("col2", "val4") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val1"), + FilterTestUtils.selector("col2", "val3") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val3"), + FilterTestUtils.selector("col2", "val2") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val5"), + FilterTestUtils.selector("col2", "val6") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val5"), + FilterTestUtils.selector("col2", "val7") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val6"), + FilterTestUtils.selector("col2", "val7") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val6"), + FilterTestUtils.selector("col2", "val8") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val7"), + FilterTestUtils.selector("col2", "val9") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val8"), + FilterTestUtils.selector("col2", "val9") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val4"), + FilterTestUtils.selector("col2", "val9") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val4"), + FilterTestUtils.selector("col2", "val8") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val5"), + FilterTestUtils.selector("col2", "val2") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val5"), + FilterTestUtils.selector("col2", "val1") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val7"), + FilterTestUtils.selector("col2", "val0") + ), + FilterTestUtils.and( + FilterTestUtils.selector("col1", "val9"), + FilterTestUtils.selector("col2", "val8") + ) + ); + Filters.toNormalizedOrClauses(filter); + } + private void assertFilter(Filter original, Filter expectedConverted, Filter actualConverted) { assertEquivalent(original, expectedConverted); diff --git a/processing/src/test/java/org/apache/druid/segment/filter/FilterPartitionTest.java b/processing/src/test/java/org/apache/druid/segment/filter/FilterPartitionTest.java index 323b6f29d697..00752463c75e 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/FilterPartitionTest.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/FilterPartitionTest.java @@ -43,6 +43,7 @@ import org.apache.druid.segment.IndexBuilder; import org.apache.druid.segment.QueryableIndexStorageAdapter; import org.apache.druid.segment.StorageAdapter; +import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Test; @@ -609,7 +610,7 @@ public void testMissingColumnNotSpecifiedInDimensionList() } @Test - public void testDistributeOrCNF() + public void testDistributeOrCNF() throws CNFFilterExplosionException { DimFilter dimFilter1 = new OrDimFilter(Arrays.asList( new SelectorDimFilter("dim0", "6", null), @@ -663,7 +664,7 @@ public void testDistributeOrCNF() } @Test - public void testDistributeOrCNFExtractionFn() + public void testDistributeOrCNFExtractionFn() throws CNFFilterExplosionException { DimFilter dimFilter1 = new OrDimFilter(Arrays.asList( new SelectorDimFilter("dim0", "super-6", JS_EXTRACTION_FN),