Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix string first/last aggregator comparator #12773

Merged
merged 2 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

package org.apache.druid.query.aggregation;

import org.apache.druid.collections.SerializablePair;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
import org.apache.druid.segment.GenericColumnSerializer;
import org.apache.druid.segment.column.ColumnBuilder;
import org.apache.druid.segment.data.GenericIndexed;
Expand All @@ -34,6 +34,7 @@

import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Comparator;

/**
* The SerializablePairLongStringSerde serializes a Long-String pair (SerializablePairLongString).
Expand All @@ -46,6 +47,12 @@ public class SerializablePairLongStringSerde extends ComplexMetricSerde
{

private static final String TYPE_NAME = "serializablePairLongString";
// Orders pairs by their long LHS first and, when the LHS values are equal, by their string RHS.
// Null SerializablePairLongString values are put first; among pairs with an equal LHS, a null RHS is put first.
private static final Comparator<SerializablePairLongString> COMPARATOR = Comparator.nullsFirst(
// comparingLong works on a primitive long, so this assumes that the LHS of the pair will never be null
Comparator.<SerializablePairLongString>comparingLong(SerializablePair::getLhs)
.thenComparing(SerializablePair::getRhs, Comparator.nullsFirst(Comparator.naturalOrder()))
);

@Override
public String getTypeName()
Expand Down Expand Up @@ -87,7 +94,7 @@ public ObjectStrategy getObjectStrategy()
/**
 * Compares two possibly-null pairs by delegating to a shared comparator.
 *
 * NOTE(review): this is a diff view, so the two return statements below appear to be the before/after of a single
 * changed line rather than two live statements; only the second one belongs to the resulting file.
 */
@Override
public int compare(@Nullable SerializablePairLongString o1, @Nullable SerializablePairLongString o2)
{
// Before the change: delegated to the comparator exposed by StringFirstAggregatorFactory
return StringFirstAggregatorFactory.VALUE_COMPARATOR.compare(o1, o2);
// After the change: delegates to this class's own COMPARATOR constant
return COMPARATOR.compare(o1, o2);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Longs;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator;
Expand Down Expand Up @@ -87,42 +88,11 @@ public void aggregate(ByteBuffer buf, int position)
((SerializablePairLongString) o2).lhs
);

// Orders pairs by timestamp (lhs) and breaks ties on the string (rhs). A null pair sorts before any non-null pair,
// and for equal timestamps a null string sorts before any non-null string.
public static final Comparator<SerializablePairLongString> VALUE_COMPARATOR = (o1, o2) -> {
  // Null pairs: equal to each other, smaller than everything else
  if (o1 == null) {
    return o2 == null ? 0 : -1;
  }
  if (o2 == null) {
    return 1;
  }

  // Differing timestamps decide the result on their own
  final int byTimestamp = o1.lhs.compareTo(o2.lhs);
  if (byTimestamp != 0) {
    return byTimestamp;
  }

  // Same timestamp: fall back to the strings, applying the same null-first rule
  if (o1.rhs == null) {
    return o2.rhs == null ? 0 : -1;
  }
  if (o2.rhs == null) {
    return 1;
  }

  // Note: this string comparison may not be meaningful for first/last aggregators
  return o1.rhs.compareTo(o2.rhs);
};
// Used when comparing aggregation results amongst distinct groups. Hence the comparison is done on the finalized
// result, which is the string/value part (RHS) of the result pair; the LHS of the pair is not consulted.
// Null SerializablePairLongString values are put first, and so are pairs whose string value is null.
public static final Comparator<SerializablePairLongString> VALUE_COMPARATOR = Comparator.nullsFirst(
Comparator.comparing(SerializablePair::getRhs, Comparator.nullsFirst(Comparator.naturalOrder()))
);

private final String fieldName;
private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.junit.Test;

import java.nio.ByteBuffer;
import java.util.Comparator;

public class StringFirstAggregationTest extends InitializedNullHandlingTest
{
Expand Down Expand Up @@ -215,6 +216,32 @@ public void testStringFirstAggregateCombiner()
Assert.assertEquals(pairs[1], stringFirstAggregateCombiner.getObject());
}

/**
 * Checks the ordering of the comparator returned by {@code stringFirstAggFactory}: pairs are ordered by their
 * string value, null strings come before non-null strings, and null pairs come before non-null pairs.
 *
 * NOTE(review): the method name says "Last" although it exercises stringFirstAggFactory - consider renaming.
 */
@Test
@SuppressWarnings("EqualsWithItself")
public void testStringLastAggregatorComparator()
{
  final Comparator<SerializablePairLongString> cmp =
      (Comparator<SerializablePairLongString>) stringFirstAggFactory.getComparator();

  // Timestamps ascend while the strings do not, so the expectations below only hold for a string-based ordering
  final SerializablePairLongString earlyWithZ = new SerializablePairLongString(1L, "Z");
  final SerializablePairLongString laterWithA = new SerializablePairLongString(2L, "A");
  final SerializablePairLongString latestWithNull = new SerializablePairLongString(3L, null);

  // two pairs that both carry a string
  Assert.assertEquals(0, cmp.compare(earlyWithZ, earlyWithZ));
  Assert.assertTrue(cmp.compare(earlyWithZ, laterWithA) > 0);
  Assert.assertTrue(cmp.compare(laterWithA, earlyWithZ) < 0);

  // a pair carrying a string against a pair carrying null (null strings sort first)
  Assert.assertEquals(0, cmp.compare(latestWithNull, latestWithNull));
  Assert.assertTrue(cmp.compare(earlyWithZ, latestWithNull) > 0);
  Assert.assertTrue(cmp.compare(latestWithNull, earlyWithZ) < 0);

  // a pair against a null pair (null pairs sort first)
  Assert.assertEquals(0, cmp.compare(null, null));
  Assert.assertTrue(cmp.compare(earlyWithZ, null) > 0);
  Assert.assertTrue(cmp.compare(null, earlyWithZ) < 0);
}

private void aggregate(
Aggregator agg
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.junit.Test;

import java.nio.ByteBuffer;
import java.util.Comparator;

public class StringLastAggregationTest
{
Expand Down Expand Up @@ -217,6 +218,32 @@ public void testStringLastAggregateCombiner()
Assert.assertEquals(pairs[1], stringFirstAggregateCombiner.getObject());
}

/**
 * Checks the ordering of the comparator returned by {@code stringLastAggFactory}: pairs are ordered by their
 * string value, null strings come before non-null strings, and null pairs come before non-null pairs.
 */
@Test
@SuppressWarnings("EqualsWithItself")
public void testStringLastAggregatorComparator()
{
  final Comparator<SerializablePairLongString> cmp =
      (Comparator<SerializablePairLongString>) stringLastAggFactory.getComparator();

  // Timestamps ascend while the strings do not, so the expectations below only hold for a string-based ordering
  final SerializablePairLongString earlyWithZ = new SerializablePairLongString(1L, "Z");
  final SerializablePairLongString laterWithA = new SerializablePairLongString(2L, "A");
  final SerializablePairLongString latestWithNull = new SerializablePairLongString(3L, null);

  // two pairs that both carry a string
  Assert.assertEquals(0, cmp.compare(earlyWithZ, earlyWithZ));
  Assert.assertTrue(cmp.compare(earlyWithZ, laterWithA) > 0);
  Assert.assertTrue(cmp.compare(laterWithA, earlyWithZ) < 0);

  // a pair carrying a string against a pair carrying null (null strings sort first)
  Assert.assertEquals(0, cmp.compare(latestWithNull, latestWithNull));
  Assert.assertTrue(cmp.compare(earlyWithZ, latestWithNull) > 0);
  Assert.assertTrue(cmp.compare(latestWithNull, earlyWithZ) < 0);

  // a pair against a null pair (null pairs sort first)
  Assert.assertEquals(0, cmp.compare(null, null));
  Assert.assertTrue(cmp.compare(earlyWithZ, null) > 0);
  Assert.assertTrue(cmp.compare(null, earlyWithZ) < 0);
}

private void aggregate(
Aggregator agg
)
Expand Down