Skip to content

Commit

Permalink
Show only intersecting buckets to the Adjacency matrix aggregation (#…
Browse files Browse the repository at this point in the history
…11733)


Signed-off-by: Ivan Brusic <[email protected]>
  • Loading branch information
brusic authored Jan 11, 2025
1 parent fccd6c5 commit 8d5e1a3
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,40 @@ setup:

- match: { aggregations.conns.buckets.3.doc_count: 1 }
- match: { aggregations.conns.buckets.3.key: "4" }


---
"Show only intersections":
- skip:
version: " - 2.99.99"
reason: "show_only_intersecting was added in 3.0.0"
features: node_selector
- do:
node_selector:
version: "3.0.0 - "
search:
index: test
rest_total_hits_as_int: true
body:
size: 0
aggs:
conns:
adjacency_matrix:
show_only_intersecting: true
filters:
1:
term:
num: 1
2:
term:
num: 2
4:
term:
num: 4

- match: { hits.total: 3 }

- length: { aggregations.conns.buckets: 1 }

- match: { aggregations.conns.buckets.0.doc_count: 1 }
- match: { aggregations.conns.buckets.0.key: "1&2" }
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

package org.opensearch.search.aggregations.bucket.adjacency;

import org.opensearch.Version;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -71,7 +72,10 @@ public class AdjacencyMatrixAggregationBuilder extends AbstractAggregationBuilde

private static final ParseField SEPARATOR_FIELD = new ParseField("separator");
private static final ParseField FILTERS_FIELD = new ParseField("filters");
private static final ParseField SHOW_ONLY_INTERSECTING = new ParseField("show_only_intersecting");

private List<KeyedFilter> filters;
private boolean showOnlyIntersecting = false;
private String separator = DEFAULT_SEPARATOR;

private static final ObjectParser<AdjacencyMatrixAggregationBuilder, String> PARSER = ObjectParser.fromBuilder(
Expand All @@ -81,6 +85,10 @@ public class AdjacencyMatrixAggregationBuilder extends AbstractAggregationBuilde
static {
PARSER.declareString(AdjacencyMatrixAggregationBuilder::separator, SEPARATOR_FIELD);
PARSER.declareNamedObjects(AdjacencyMatrixAggregationBuilder::setFiltersAsList, KeyedFilter.PARSER, FILTERS_FIELD);
PARSER.declareBoolean(
AdjacencyMatrixAggregationBuilder::setShowOnlyIntersecting,
AdjacencyMatrixAggregationBuilder.SHOW_ONLY_INTERSECTING
);
}

public static AggregationBuilder parse(XContentParser parser, String name) throws IOException {
Expand Down Expand Up @@ -115,6 +123,7 @@ protected AdjacencyMatrixAggregationBuilder(
super(clone, factoriesBuilder, metadata);
this.filters = new ArrayList<>(clone.filters);
this.separator = clone.separator;
this.showOnlyIntersecting = clone.showOnlyIntersecting;
}

@Override
Expand All @@ -138,13 +147,50 @@ public AdjacencyMatrixAggregationBuilder(String name, String separator, Map<Stri
setFiltersAsMap(filters);
}

/**
* @param name
* the name of this aggregation
* @param filters
* the filters and their key to use with this aggregation.
* @param showOnlyIntersecting
* show only the buckets that intersection multiple documents
*/
public AdjacencyMatrixAggregationBuilder(String name, Map<String, QueryBuilder> filters, boolean showOnlyIntersecting) {
this(name, DEFAULT_SEPARATOR, filters, showOnlyIntersecting);
}

/**
* @param name
* the name of this aggregation
* @param separator
* the string used to separate keys in intersections buckets e.g.
* &amp; character for keyed filters A and B would return an
* intersection bucket named A&amp;B
* @param filters
* the filters and their key to use with this aggregation.
* @param showOnlyIntersecting
* show only the buckets that intersection multiple documents
*/
public AdjacencyMatrixAggregationBuilder(
String name,
String separator,
Map<String, QueryBuilder> filters,
boolean showOnlyIntersecting
) {
this(name, separator, filters);
this.showOnlyIntersecting = showOnlyIntersecting;
}

/**
* Read from a stream.
*/
public AdjacencyMatrixAggregationBuilder(StreamInput in) throws IOException {
super(in);
int filtersSize = in.readVInt();
separator = in.readString();
if (in.getVersion().onOrAfter(Version.V_3_0_0)) {
showOnlyIntersecting = in.readBoolean();
}
filters = new ArrayList<>(filtersSize);
for (int i = 0; i < filtersSize; i++) {
filters.add(new KeyedFilter(in));
Expand All @@ -155,6 +201,9 @@ public AdjacencyMatrixAggregationBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeVInt(filters.size());
out.writeString(separator);
if (out.getVersion().onOrAfter(Version.V_3_0_0)) {
out.writeBoolean(showOnlyIntersecting);
}
for (KeyedFilter keyedFilter : filters) {
keyedFilter.writeTo(out);
}
Expand Down Expand Up @@ -185,6 +234,11 @@ private AdjacencyMatrixAggregationBuilder setFiltersAsList(List<KeyedFilter> fil
return this;
}

public AdjacencyMatrixAggregationBuilder setShowOnlyIntersecting(boolean showOnlyIntersecting) {
this.showOnlyIntersecting = showOnlyIntersecting;
return this;
}

/**
* Set the separator used to join pairs of bucket keys
*/
Expand Down Expand Up @@ -214,6 +268,10 @@ public Map<String, QueryBuilder> filters() {
return result;
}

public boolean isShowOnlyIntersecting() {
return showOnlyIntersecting;
}

@Override
protected AdjacencyMatrixAggregationBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException {
boolean modified = false;
Expand All @@ -224,7 +282,9 @@ protected AdjacencyMatrixAggregationBuilder doRewrite(QueryRewriteContext queryS
rewrittenFilters.add(new KeyedFilter(kf.key(), rewritten));
}
if (modified) {
return new AdjacencyMatrixAggregationBuilder(name).separator(separator).setFiltersAsList(rewrittenFilters);
return new AdjacencyMatrixAggregationBuilder(name).separator(separator)
.setFiltersAsList(rewrittenFilters)
.setShowOnlyIntersecting(showOnlyIntersecting);
}
return this;
}
Expand All @@ -245,7 +305,16 @@ protected AggregatorFactory doBuild(QueryShardContext queryShardContext, Aggrega
+ "] index level setting."
);
}
return new AdjacencyMatrixAggregatorFactory(name, filters, separator, queryShardContext, parent, subFactoriesBuilder, metadata);
return new AdjacencyMatrixAggregatorFactory(
name,
filters,
showOnlyIntersecting,
separator,
queryShardContext,
parent,
subFactoriesBuilder,
metadata
);
}

@Override
Expand All @@ -257,7 +326,8 @@ public BucketCardinality bucketCardinality() {
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SEPARATOR_FIELD.getPreferredName(), separator);
builder.startObject(AdjacencyMatrixAggregator.FILTERS_FIELD.getPreferredName());
builder.field(SHOW_ONLY_INTERSECTING.getPreferredName(), showOnlyIntersecting);
builder.startObject(FILTERS_FIELD.getPreferredName());
for (KeyedFilter keyedFilter : filters) {
builder.field(keyedFilter.key(), keyedFilter.filter());
}
Expand All @@ -268,7 +338,7 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), filters, separator);
return Objects.hash(super.hashCode(), filters, showOnlyIntersecting, separator);
}

@Override
Expand All @@ -277,7 +347,9 @@ public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) return false;
if (super.equals(obj) == false) return false;
AdjacencyMatrixAggregationBuilder other = (AdjacencyMatrixAggregationBuilder) obj;
return Objects.equals(filters, other.filters) && Objects.equals(separator, other.separator);
return Objects.equals(filters, other.filters)
&& Objects.equals(separator, other.separator)
&& Objects.equals(showOnlyIntersecting, other.showOnlyIntersecting);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -70,8 +69,6 @@
*/
public class AdjacencyMatrixAggregator extends BucketsAggregator {

public static final ParseField FILTERS_FIELD = new ParseField("filters");

/**
* A keyed filter
*
Expand Down Expand Up @@ -145,6 +142,8 @@ public boolean equals(Object obj) {

private final String[] keys;
private final Weight[] filters;

private final boolean showOnlyIntersecting;
private final int totalNumKeys;
private final int totalNumIntersections;
private final String separator;
Expand All @@ -155,6 +154,7 @@ public AdjacencyMatrixAggregator(
String separator,
String[] keys,
Weight[] filters,
boolean showOnlyIntersecting,
SearchContext context,
Aggregator parent,
Map<String, Object> metadata
Expand All @@ -163,6 +163,7 @@ public AdjacencyMatrixAggregator(
this.separator = separator;
this.keys = keys;
this.filters = filters;
this.showOnlyIntersecting = showOnlyIntersecting;
this.totalNumIntersections = ((keys.length * keys.length) - keys.length) / 2;
this.totalNumKeys = keys.length + totalNumIntersections;
}
Expand All @@ -177,10 +178,12 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
return new LeafBucketCollectorBase(sub, null) {
@Override
public void collect(int doc, long bucket) throws IOException {
// Check each of the provided filters
for (int i = 0; i < bits.length; i++) {
if (bits[i].get(doc)) {
collectBucket(sub, doc, bucketOrd(bucket, i));
if (!showOnlyIntersecting) {
// Check each of the provided filters
for (int i = 0; i < bits.length; i++) {
if (bits[i].get(doc)) {
collectBucket(sub, doc, bucketOrd(bucket, i));
}
}
}
// Check all the possible intersections of the provided filters
Expand Down Expand Up @@ -229,7 +232,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
for (int i = 0; i < keys.length; i++) {
long bucketOrd = bucketOrd(owningBucketOrds[owningBucketOrdIdx], i);
long docCount = bucketDocCount(bucketOrd);
// Empty buckets are not returned because this aggregation will commonly be used under a
// Empty buckets are not returned because this aggregation will commonly be used under
// a date-histogram where we will look for transactions over time and can expect many
// empty buckets.
if (docCount > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,14 @@ public class AdjacencyMatrixAggregatorFactory extends AggregatorFactory {

private final String[] keys;
private final Weight[] weights;

private final boolean showOnlyIntersecting;
private final String separator;

public AdjacencyMatrixAggregatorFactory(
String name,
List<KeyedFilter> filters,
boolean showOnlyIntersecting,
String separator,
QueryShardContext queryShardContext,
AggregatorFactory parent,
Expand All @@ -79,6 +82,7 @@ public AdjacencyMatrixAggregatorFactory(
Query filter = keyedFilter.filter().toQuery(queryShardContext);
weights[i] = contextSearcher.createWeight(contextSearcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1f);
}
this.showOnlyIntersecting = showOnlyIntersecting;
}

@Override
Expand All @@ -88,7 +92,17 @@ public Aggregator createInternal(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {
return new AdjacencyMatrixAggregator(name, factories, separator, keys, weights, searchContext, parent, metadata);
return new AdjacencyMatrixAggregator(
name,
factories,
separator,
keys,
weights,
showOnlyIntersecting,
searchContext,
parent,
metadata
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
public class AdjacencyMatrixAggregationBuilderTests extends OpenSearchTestCase {

public void testFilterSizeLimitation() throws Exception {
// filter size grater than max size should thrown a exception
// filter size grater than max size should throw an exception
QueryShardContext queryShardContext = mock(QueryShardContext.class);
IndexShard indexShard = mock(IndexShard.class);
Settings settings = Settings.builder()
Expand Down Expand Up @@ -94,7 +94,7 @@ public void testFilterSizeLimitation() throws Exception {
)
);

// filter size not grater than max size should return an instance of AdjacencyMatrixAggregatorFactory
// filter size not greater than max size should return an instance of AdjacencyMatrixAggregatorFactory
Map<String, QueryBuilder> emptyFilters = Collections.emptyMap();

AdjacencyMatrixAggregationBuilder aggregationBuilder = new AdjacencyMatrixAggregationBuilder("dummy", emptyFilters);
Expand All @@ -106,4 +106,21 @@ public void testFilterSizeLimitation() throws Exception {
+ "removed in a future release! See the breaking changes documentation for the next major version."
);
}

public void testShowOnlyIntersecting() throws Exception {
QueryShardContext queryShardContext = mock(QueryShardContext.class);

Map<String, QueryBuilder> filters = new HashMap<>(3);
for (int i = 0; i < 2; i++) {
QueryBuilder queryBuilder = mock(QueryBuilder.class);
// return builder itself to skip rewrite
when(queryBuilder.rewrite(queryShardContext)).thenReturn(queryBuilder);
filters.put("filter" + i, queryBuilder);
}
AdjacencyMatrixAggregationBuilder builder = new AdjacencyMatrixAggregationBuilder("dummy", filters, true);
assertTrue(builder.isShowOnlyIntersecting());

builder = new AdjacencyMatrixAggregationBuilder("dummy", filters, false);
assertFalse(builder.isShowOnlyIntersecting());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,22 @@ public void testFiltersSameMap() {
assertEquals(original, builder.filters());
assert original != builder.filters();
}

public void testShowOnlyIntersecting() {
Map<String, QueryBuilder> original = new HashMap<>();
original.put("bbb", new MatchNoneQueryBuilder());
original.put("aaa", new MatchNoneQueryBuilder());
AdjacencyMatrixAggregationBuilder builder;
builder = new AdjacencyMatrixAggregationBuilder("my-agg", "&", original, true);
assertTrue(builder.isShowOnlyIntersecting());
}

public void testShowOnlyIntersectingAsFalse() {
Map<String, QueryBuilder> original = new HashMap<>();
original.put("bbb", new MatchNoneQueryBuilder());
original.put("aaa", new MatchNoneQueryBuilder());
AdjacencyMatrixAggregationBuilder builder;
builder = new AdjacencyMatrixAggregationBuilder("my-agg", original, false);
assertFalse(builder.isShowOnlyIntersecting());
}
}

0 comments on commit 8d5e1a3

Please sign in to comment.