Skip to content

Commit

Permalink
remove circular dependency
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Aug 7, 2024
1 parent 234eb44 commit 9040f6f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

/**
* This interface provides a bridge between an aggregator and the optimization context, allowing
Expand All @@ -31,18 +32,17 @@
*/
public abstract class AggregatorBridge {

/**
* The optimization context associated with this aggregator bridge.
*/
FilterRewriteOptimizationContext filterRewriteOptimizationContext;

/**
* The field type associated with this aggregator bridge.
*/
MappedFieldType fieldType;

void setOptimizationContext(FilterRewriteOptimizationContext context) {
this.filterRewriteOptimizationContext = context;
Consumer<Ranges> setRanges;
BiConsumer<Integer, Ranges> setRangesFromSegment;

void setRangesConsumer(Consumer<Ranges> setRanges, BiConsumer<Integer, Ranges> setRangesFromSegment) {
this.setRanges = setRanges;
this.setRangesFromSegment = setRangesFromSegment;
}

/**
Expand Down Expand Up @@ -72,7 +72,13 @@ void setOptimizationContext(FilterRewriteOptimizationContext context) {
*
* @param values the point values (index structure for numeric values) for a segment
* @param incrementDocCount a consumer to increment the document count for a range bucket. The First parameter is document count, the second is the key of the bucket
* @param leafOrd
* @param consumeDebugInfo
* @param ranges
*/
protected abstract void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, int leafOrd) throws IOException;
abstract void tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Consumer<FilterRewriteOptimizationContext.DebugInfo> consumeDebugInfo,
Ranges ranges
) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.io.IOException;
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;
Expand All @@ -32,6 +33,8 @@
*/
public abstract class DateHistogramAggregatorBridge extends AggregatorBridge {

int maxRewriteFilters;

protected boolean canOptimize(ValuesSourceConfig config) {
if (config.script() == null && config.missing() == null) {
MappedFieldType fieldType = config.fieldType();
Expand All @@ -47,16 +50,17 @@ protected boolean canOptimize(ValuesSourceConfig config) {

protected void buildRanges(SearchContext context) throws IOException {
long[] bounds = Helper.getDateHistoAggBounds(context, fieldType.name());
filterRewriteOptimizationContext.setRanges(buildRanges(bounds));
this.maxRewriteFilters = context.maxAggRewriteFilters();
setRanges.accept(buildRanges(bounds, maxRewriteFilters));
}

@Override
protected void prepareFromSegment(LeafReaderContext leaf) throws IOException {
long[] bounds = Helper.getSegmentBounds(leaf, fieldType.name());
filterRewriteOptimizationContext.setRangesFromSegment(leaf.ord, buildRanges(bounds));
setRangesFromSegment.accept(leaf.ord, buildRanges(bounds, maxRewriteFilters));
}

private Ranges buildRanges(long[] bounds) {
private Ranges buildRanges(long[] bounds, int maxRewriteFilters) {
bounds = processHardBounds(bounds);
if (bounds == null) {
return null;
Expand All @@ -79,7 +83,7 @@ private Ranges buildRanges(long[] bounds) {
getRoundingPrepared(),
bounds[0],
bounds[1],
filterRewriteOptimizationContext.maxAggRewriteFilters
maxRewriteFilters
);
}

Expand Down Expand Up @@ -123,21 +127,23 @@ protected int getSize() {
}

@Override
protected final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, int leafOrd) throws IOException {
final void tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Consumer<FilterRewriteOptimizationContext.DebugInfo> consumeDebugInfo,
Ranges ranges
) throws IOException {
int size = getSize();

DateFieldMapper.DateFieldType fieldType = getFieldType();
Ranges ranges = filterRewriteOptimizationContext.getRanges(leafOrd);
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long bucketOrd = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(bucketOrd, (long) docCount);
};

filterRewriteOptimizationContext.consumeDebugInfo(
multiRangesTraverse(values.getPointTree(), filterRewriteOptimizationContext.getRanges(leafOrd), incrementFunc, size)
);
consumeDebugInfo.accept(multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size));
}

private static long getBucketOrd(long bucketOrd) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public final class FilterRewriteOptimizationContext {
private boolean preparedAtShardLevel = false;

private final AggregatorBridge aggregatorBridge;
int maxAggRewriteFilters;
private String shardId;

private Ranges ranges;
Expand Down Expand Up @@ -72,8 +71,8 @@ private boolean canOptimize(final Object parent, final int subAggLength, SearchC

boolean canOptimize = aggregatorBridge.canOptimize();
if (canOptimize) {
aggregatorBridge.setOptimizationContext(this);
this.maxAggRewriteFilters = context.maxAggRewriteFilters();
aggregatorBridge.setRangesConsumer(this::setRanges, this::setRangesFromSegment);

this.shardId = context.indexShard().shardId().toString();

assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
Expand Down Expand Up @@ -139,7 +138,7 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
Ranges ranges = tryBuildRangesFromSegment(leafCtx, segmentMatchAll);
if (ranges == null) return false;

aggregatorBridge.tryOptimize(values, incrementDocCount, leafCtx.ord);
aggregatorBridge.tryOptimize(values, incrementDocCount, this::consumeDebugInfo, getRanges(leafCtx.ord));

optimizedSegments++;
logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
Expand Down Expand Up @@ -190,8 +189,8 @@ public void populateDebugInfo(BiConsumer<String, Object> add) {
if (optimizedSegments > 0) {
add.accept("optimized_segments", optimizedSegments);
add.accept("unoptimized_segments", segments - optimizedSegments);
add.accept("leaf_node_visited", leafNodeVisited);
add.accept("inner_node_visited", innerNodeVisited);
add.accept("leaf_visited", leafNodeVisited);
add.accept("inner_visited", innerNodeVisited);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;
Expand Down Expand Up @@ -65,7 +66,7 @@ protected void buildRanges(RangeAggregator.Range[] ranges) {
uppers[i] = upper;
}

filterRewriteOptimizationContext.setRanges(new Ranges(lowers, uppers));
setRanges.accept(new Ranges(lowers, uppers));
}

@Override
Expand All @@ -74,17 +75,20 @@ public void prepareFromSegment(LeafReaderContext leaf) {
}

@Override
protected final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, int leafOrd) throws IOException {
final void tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Consumer<FilterRewriteOptimizationContext.DebugInfo> consumeDebugInfo,
Ranges ranges
) throws IOException {
int size = Integer.MAX_VALUE;

BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long bucketOrd = bucketOrdProducer().apply(activeIndex);
incrementDocCount.accept(bucketOrd, (long) docCount);
};

filterRewriteOptimizationContext.consumeDebugInfo(
multiRangesTraverse(values.getPointTree(), filterRewriteOptimizationContext.getRanges(leafOrd), incrementFunc, size)
);
consumeDebugInfo.accept(multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size));
}

/**
Expand Down

0 comments on commit 9040f6f

Please sign in to comment.