
Guard against exponential increase of filters during CNF conversion (#12314)

Currently, the CNF conversion of a filter is unbounded, which means it can produce an arbitrarily large number of filters and thereby cause OOMs in the Historical heap. We should throw an error or disable CNF conversion if the filter count starts getting out of hand. There are algorithms that perform CNF conversion with only a linear increase in the number of filters, but they are left out of scope for this change since they introduce new variables into the predicate, which can be contentious.
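
For intuition (an illustrative sketch, not part of this patch and not Druid code): CNF conversion distributes OR over AND, so a disjunction of n two-literal conjunctions expands to 2^n OR-clauses, which is the exponential growth the new limit guards against.

// Standalone Java illustration; the class name and output format are made up for this example.
// CNF of (a1 AND b1) OR (a2 AND b2) OR ... OR (an AND bn) needs one OR-clause for every
// way of picking a single literal from each conjunction, i.e. 2^n clauses in total.
public class CnfBlowupDemo
{
  public static void main(String[] args)
  {
    for (int n = 1; n <= 30; n++) {
      long clauses = 1L << n; // 2^n
      System.out.printf("%2d OR'ed AND-pairs -> %,d CNF clauses%n", n, clauses);
    }
  }
}

At 30 such conjunctions that is already over a billion clauses, which is why an unbounded conversion can exhaust a Historical's heap.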
rohangarg authored Mar 9, 2022
1 parent 0600772 commit 56fbd2a
Showing 8 changed files with 180 additions and 26 deletions.
@@ -66,6 +66,7 @@
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.filter.OrFilter;
import org.apache.druid.segment.filter.SelectorFilter;
import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException;
import org.apache.druid.segment.generator.DataGenerator;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
@@ -370,7 +371,7 @@ public void readOrFilter(Blackhole blackhole)
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void readOrFilterCNF(Blackhole blackhole)
public void readOrFilterCNF(Blackhole blackhole) throws CNFFilterExplosionException
{
Filter filter = new NoBitmapSelectorFilter("dimSequential", "199");
Filter filter2 = new AndFilter(Arrays.asList(new SelectorFilter("dimMultivalEnumerated2", "Corundum"), new NoBitmapSelectorFilter("dimMultivalEnumerated", "Bar")));
@@ -421,7 +422,7 @@ public void readComplexOrFilter(Blackhole blackhole)
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void readComplexOrFilterCNF(Blackhole blackhole)
public void readComplexOrFilterCNF(Blackhole blackhole) throws CNFFilterExplosionException
{
DimFilter dimFilter1 = new OrDimFilter(Arrays.asList(
new SelectorDimFilter("dimSequential", "199", null),
@@ -45,6 +45,7 @@
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.data.CloseableIndexed;
import org.apache.druid.segment.data.Indexed;
import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException;
import org.apache.druid.segment.filter.cnf.CalciteCnfHelper;
import org.apache.druid.segment.filter.cnf.HiveCnfHelper;
import org.apache.druid.segment.join.filter.AllNullColumnSelectorFactory;
@@ -433,10 +434,15 @@ public static Filter convertToCNFFromQueryContext(Query query, @Nullable Filter
return null;
}
boolean useCNF = query.getContextBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF);
return useCNF ? Filters.toCnf(filter) : filter;
try {
return useCNF ? Filters.toCnf(filter) : filter;
}
catch (CNFFilterExplosionException cnfFilterExplosionException) {
return filter; // cannot convert to CNF, return the filter as is
}
}

public static Filter toCnf(Filter current)
public static Filter toCnf(Filter current) throws CNFFilterExplosionException
{
// Push down NOT filters to leaves if possible to remove NOT on NOT filters and reduce hierarchy.
// ex) ~(a OR ~b) => ~a AND b
@@ -578,7 +584,7 @@ public static Optional<Filter> maybeOr(final List<Filter> filters)
*
* @return The normalized or clauses for the provided filter.
*/
public static List<Filter> toNormalizedOrClauses(Filter filter)
public static List<Filter> toNormalizedOrClauses(Filter filter) throws CNFFilterExplosionException
{
Filter normalizedFilter = Filters.toCnf(filter);

@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.segment.filter.cnf;

import org.apache.druid.java.util.common.StringUtils;

public class CNFFilterExplosionException extends Exception
{
public CNFFilterExplosionException(String formatText, Object... arguments)
{
super(StringUtils.nonStrictFormat(formatText, arguments));
}
}
@@ -19,6 +19,7 @@

package org.apache.druid.segment.filter.cnf;

import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.query.filter.BooleanFilter;
import org.apache.druid.query.filter.Filter;
import org.apache.druid.segment.filter.AndFilter;
@@ -37,6 +38,8 @@
*/
public class HiveCnfHelper
{
private static final int CNF_MAX_FILTER_THRESHOLD = 10_000;

public static Filter pushDownNot(Filter current)
{
if (current instanceof NotFilter) {
@@ -78,17 +81,37 @@ public static Filter pushDownNot(Filter current)
return current;
}

public static Filter convertToCnf(Filter current)
public static Filter convertToCnf(Filter current) throws CNFFilterExplosionException
{
return convertToCnfWithLimit(current, CNF_MAX_FILTER_THRESHOLD).lhs;
}

/**
* Converts a filter to CNF form with a limit on filter count
* @param maxCNFFilterLimit the maximum number of filters allowed in CNF conversion
* @return a pair of the CNF converted filter and the new remaining filter limit
* @throws CNFFilterExplosionException is thrown if the filters in CNF representation go beyond maxCNFFilterLimit
*/
private static NonnullPair<Filter, Integer> convertToCnfWithLimit(
Filter current,
int maxCNFFilterLimit
) throws CNFFilterExplosionException
{
if (current instanceof NotFilter) {
return new NotFilter(convertToCnf(((NotFilter) current).getBaseFilter()));
NonnullPair<Filter, Integer> result = convertToCnfWithLimit(((NotFilter) current).getBaseFilter(), maxCNFFilterLimit);
return new NonnullPair<>(new NotFilter(result.lhs), result.rhs);
}
if (current instanceof AndFilter) {
List<Filter> children = new ArrayList<>();
for (Filter child : ((AndFilter) current).getFilters()) {
children.add(convertToCnf(child));
NonnullPair<Filter, Integer> result = convertToCnfWithLimit(child, maxCNFFilterLimit);
children.add(result.lhs);
maxCNFFilterLimit = result.rhs;
if (maxCNFFilterLimit < 0) {
throw new CNFFilterExplosionException("Exceeded maximum allowed filters for CNF (conjunctive normal form) conversion");
}
}
return Filters.and(children);
return new NonnullPair<>(Filters.and(children), maxCNFFilterLimit);
}
if (current instanceof OrFilter) {
// a list of leaves that weren't under AND expressions
@@ -107,11 +130,11 @@ public static Filter convertToCnf(Filter current)
}
if (!andList.isEmpty()) {
List<Filter> result = new ArrayList<>();
generateAllCombinations(result, andList, nonAndList);
return Filters.and(result);
generateAllCombinations(result, andList, nonAndList, maxCNFFilterLimit);
return new NonnullPair<>(Filters.and(result), maxCNFFilterLimit - result.size());
}
}
return current;
return new NonnullPair<>(current, maxCNFFilterLimit);
}

public static Filter flatten(Filter root)
@@ -158,8 +181,9 @@ public static Filter flatten(Filter root)
private static void generateAllCombinations(
List<Filter> result,
List<Filter> andList,
List<Filter> nonAndList
)
List<Filter> nonAndList,
int maxAllowedFilters
) throws CNFFilterExplosionException
{
List<Filter> children = new ArrayList<>(((AndFilter) andList.get(0)).getFilters());
if (result.isEmpty()) {
@@ -168,6 +192,9 @@ private static void generateAllCombinations(
a.add(child);
// Result must receive an actual OrFilter, so wrap if Filters.or managed to un-OR it.
result.add(idempotentOr(Filters.or(a)));
if (result.size() > maxAllowedFilters) {
throw new CNFFilterExplosionException("Exceeded maximum allowed filters for CNF (conjunctive normal form) conversion");
}
}
} else {
List<Filter> work = new ArrayList<>(result);
@@ -178,11 +205,14 @@
a.add(child);
// Result must receive an actual OrFilter.
result.add(idempotentOr(Filters.or(a)));
if (result.size() > maxAllowedFilters) {
throw new CNFFilterExplosionException("Exceeded maximum allowed filters for CNF (conjunctive normal form) conversion");
}
}
}
}
if (andList.size() > 1) {
generateAllCombinations(result, andList.subList(1, andList.size()), nonAndList);
generateAllCombinations(result, andList.subList(1, andList.size()), nonAndList, maxAllowedFilters);
}
}

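A minimal standalone sketch of the budget idea in convertToCnfWithLimit above: carry a remaining filter allowance through the expansion and fail fast once it is exceeded (hypothetical names and simplified string-based clauses, not Druid code).

import java.util.ArrayList;
import java.util.List;

// Simplified stand-in for CNFFilterExplosionException.
class BudgetExceededException extends Exception
{
}

public class BudgetedCnfSketch
{
  // Expands (g1[0] AND g1[1]) OR (g2[0] AND g2[1]) OR ... into CNF clauses,
  // throwing as soon as more than 'budget' clauses would be produced.
  static List<String> distribute(List<List<String>> andGroups, int budget) throws BudgetExceededException
  {
    List<String> clauses = new ArrayList<>();
    clauses.add("");
    for (List<String> group : andGroups) {
      List<String> next = new ArrayList<>();
      for (String prefix : clauses) {
        for (String literal : group) {
          next.add(prefix.isEmpty() ? literal : prefix + " OR " + literal);
          if (next.size() > budget) {
            throw new BudgetExceededException();
          }
        }
      }
      clauses = next;
    }
    return clauses;
  }

  public static void main(String[] args) throws BudgetExceededException
  {
    List<List<String>> groups = List.of(List.of("a1", "b1"), List.of("a2", "b2"), List.of("a3", "b3"));
    System.out.println(distribute(groups, 10_000)); // 2^3 = 8 OR-clauses
  }
}

Returning the leftover allowance alongside each converted node, as the NonnullPair<Filter, Integer> does in the hunk above, is what lets sibling subtrees share one global limit instead of each consuming the full threshold.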
@@ -31,6 +31,7 @@
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.filter.OrFilter;
import org.apache.druid.segment.filter.SelectorFilter;
import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;

import javax.annotation.Nullable;
@@ -101,7 +102,13 @@ public static JoinFilterPreAnalysis computeJoinFilterPreAnalysis(final JoinFilte
return preAnalysisBuilder.build();
}

List<Filter> normalizedOrClauses = Filters.toNormalizedOrClauses(key.getFilter());
List<Filter> normalizedOrClauses;
try {
normalizedOrClauses = Filters.toNormalizedOrClauses(key.getFilter());
}
catch (CNFFilterExplosionException cnfFilterExplosionException) {
return preAnalysisBuilder.build(); // disable the filter pushdown and rewrite optimization
}

List<Filter> normalizedBaseTableClauses = new ArrayList<>();
List<Filter> normalizedJoinTableClauses = new ArrayList<>();
@@ -81,6 +81,7 @@
import org.apache.druid.segment.data.ConciseBitmapSerdeFactory;
import org.apache.druid.segment.data.IndexedInts;
import org.apache.druid.segment.data.RoaringBitmapSerdeFactory;
import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException;
import org.apache.druid.segment.incremental.IncrementalIndex;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter;
@@ -367,7 +368,12 @@ private Filter makeFilter(final DimFilter dimFilter)

final DimFilter maybeOptimized = optimize ? dimFilter.optimize() : dimFilter;
final Filter filter = maybeOptimized.toFilter();
return cnf ? Filters.toCnf(filter) : filter;
try {
return cnf ? Filters.toCnf(filter) : filter;
}
catch (CNFFilterExplosionException cnfFilterExplosionException) {
throw new RuntimeException(cnfFilterExplosionException);
}
}

private DimFilter maybeOptimize(final DimFilter dimFilter)
@@ -29,6 +29,7 @@
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.filter.cnf.CNFFilterExplosionException;
import org.apache.druid.segment.filter.cnf.CalciteCnfHelper;
import org.apache.druid.segment.filter.cnf.HiveCnfHelper;
import org.junit.Assert;
@@ -113,7 +114,7 @@ public void testFlattenUnflattenable()
}

@Test
public void testToCnfWithMuchReducibleFilter()
public void testToCnfWithMuchReducibleFilter() throws CNFFilterExplosionException
{
final Filter muchReducible = FilterTestUtils.and(
// should be flattened
@@ -158,7 +159,7 @@ public void testToCnfWithMuchReducibleFilter()
}

@Test
public void testToNormalizedOrClausesWithMuchReducibleFilter()
public void testToNormalizedOrClausesWithMuchReducibleFilter() throws CNFFilterExplosionException
{
final Filter muchReducible = FilterTestUtils.and(
// should be flattened
@@ -203,7 +204,7 @@ public void testToNormalizedOrClausesWithMuchReducibleFilter()
}

@Test
public void testToCnfWithComplexFilterIncludingNotAndOr()
public void testToCnfWithComplexFilterIncludingNotAndOr() throws CNFFilterExplosionException
{
final Filter filter = FilterTestUtils.and(
FilterTestUtils.or(
@@ -307,7 +308,7 @@ public void testToCnfWithComplexFilterIncludingNotAndOr()
}

@Test
public void testToNormalizedOrClausesWithComplexFilterIncludingNotAndOr()
public void testToNormalizedOrClausesWithComplexFilterIncludingNotAndOr() throws CNFFilterExplosionException
{
final Filter filter = FilterTestUtils.and(
FilterTestUtils.or(
@@ -411,7 +412,7 @@ public void testToNormalizedOrClausesWithComplexFilterIncludingNotAndOr()
}

@Test
public void testToCnfCollapsibleBigFilter()
public void testToCnfCollapsibleBigFilter() throws CNFFilterExplosionException
{
List<Filter> ands = new ArrayList<>();
List<Filter> ors = new ArrayList<>();
@@ -477,7 +478,7 @@ public void testPullNotPullableFilter()
}

@Test
public void testToCnfFilterThatPullCannotConvertToCnfProperly()
public void testToCnfFilterThatPullCannotConvertToCnfProperly() throws CNFFilterExplosionException
{
final Filter filter = FilterTestUtils.or(
FilterTestUtils.and(
@@ -507,7 +508,7 @@ public void testToCnfFilterThatPullCannotConvertToCnfProperly()
}

@Test
public void testToNormalizedOrClausesNonAndFilterShouldReturnSingleton()
public void testToNormalizedOrClausesNonAndFilterShouldReturnSingleton() throws CNFFilterExplosionException
{
Filter filter = FilterTestUtils.or(
FilterTestUtils.selector("col1", "val1"),
@@ -528,6 +529,78 @@ public void testTrueFalseFilterRequiredColumnRewrite()
Assert.assertEquals(FalseFilter.instance(), FalseFilter.instance().rewriteRequiredColumns(ImmutableMap.of()));
}

@Test(expected = CNFFilterExplosionException.class)
public void testExceptionOnCNFFilterExplosion() throws CNFFilterExplosionException
{
Filter filter = FilterTestUtils.or(
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val1"),
FilterTestUtils.selector("col2", "val2")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val3"),
FilterTestUtils.selector("col2", "val4")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val1"),
FilterTestUtils.selector("col2", "val3")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val3"),
FilterTestUtils.selector("col2", "val2")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val5"),
FilterTestUtils.selector("col2", "val6")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val5"),
FilterTestUtils.selector("col2", "val7")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val6"),
FilterTestUtils.selector("col2", "val7")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val6"),
FilterTestUtils.selector("col2", "val8")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val7"),
FilterTestUtils.selector("col2", "val9")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val8"),
FilterTestUtils.selector("col2", "val9")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val4"),
FilterTestUtils.selector("col2", "val9")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val4"),
FilterTestUtils.selector("col2", "val8")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val5"),
FilterTestUtils.selector("col2", "val2")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val5"),
FilterTestUtils.selector("col2", "val1")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val7"),
FilterTestUtils.selector("col2", "val0")
),
FilterTestUtils.and(
FilterTestUtils.selector("col1", "val9"),
FilterTestUtils.selector("col2", "val8")
)
);
Filters.toNormalizedOrClauses(filter);
}

private void assertFilter(Filter original, Filter expectedConverted, Filter actualConverted)
{
assertEquivalent(original, expectedConverted);