Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add native 'array contains element' filter #15366

Merged
merged 10 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,24 @@ public class SqlBenchmark
"SELECT APPROX_COUNT_DISTINCT_BUILTIN(dimZipf) FROM foo",
"SELECT APPROX_COUNT_DISTINCT_DS_HLL(dimZipf) FROM foo",
"SELECT APPROX_COUNT_DISTINCT_DS_HLL_UTF8(dimZipf) FROM foo",
"SELECT APPROX_COUNT_DISTINCT_DS_THETA(dimZipf) FROM foo"
"SELECT APPROX_COUNT_DISTINCT_DS_THETA(dimZipf) FROM foo",
// 32: LATEST aggregator long
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for moving these from expression benchmarks to here

"SELECT LATEST(long1) FROM foo",
// 33: LATEST aggregator double
"SELECT LATEST(double4) FROM foo",
// 34: LATEST aggregator double
"SELECT LATEST(float3) FROM foo",
// 35: LATEST aggregator double
"SELECT LATEST(float3), LATEST(long1), LATEST(double4) FROM foo",
// 36,37: filter numeric nulls
"SELECT SUM(long5) FROM foo WHERE long5 IS NOT NULL",
"SELECT string2, SUM(long5) FROM foo WHERE long5 IS NOT NULL GROUP BY 1",
// 38: EARLIEST aggregator long
"SELECT EARLIEST(long1) FROM foo",
// 39: EARLIEST aggregator double
"SELECT EARLIEST(double4) FROM foo",
// 40: EARLIEST aggregator float
"SELECT EARLIEST(float3) FROM foo"
);

@Param({"5000000"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
Expand All @@ -31,10 +32,12 @@
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
import org.apache.druid.segment.generator.SegmentGenerator;
import org.apache.druid.segment.transform.TransformSpec;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.server.security.AuthConfig;
Expand Down Expand Up @@ -197,23 +200,8 @@ public String getFormatString()
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long4), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 37: time shift + expr agg (group by), uniform distribution high cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long5), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 38: LATEST aggregator long
"SELECT LATEST(long1) FROM foo",
// 39: LATEST aggregator double
"SELECT LATEST(double4) FROM foo",
// 40: LATEST aggregator double
"SELECT LATEST(float3) FROM foo",
// 41: LATEST aggregator double
"SELECT LATEST(float3), LATEST(long1), LATEST(double4) FROM foo",
// 42,43: filter numeric nulls
"SELECT SUM(long5) FROM foo WHERE long5 IS NOT NULL",
"SELECT string2, SUM(long5) FROM foo WHERE long5 IS NOT NULL GROUP BY 1",
// 44: EARLIEST aggregator long
"SELECT EARLIEST(long1) FROM foo",
// 45: EARLIEST aggregator double
"SELECT EARLIEST(double4) FROM foo",
// 46: EARLIEST aggregator float
"SELECT EARLIEST(float3) FROM foo"
// 38: array filtering
"SELECT string1, long1 FROM foo WHERE ARRAY_CONTAINS(\"multi-string3\", 100) GROUP BY 1,2"
);

@Param({"5000000"})
Expand All @@ -225,6 +213,12 @@ public String getFormatString()
})
private String vectorize;

@Param({
"explicit",
"auto"
})
private String schema;

@Param({
// non-expression reference
"0",
Expand Down Expand Up @@ -266,16 +260,7 @@ public String getFormatString()
"35",
"36",
"37",
"38",
"39",
"40",
"41",
"42",
"43",
"44",
"45",
"46",
"47"
"38"
})
private String query;

Expand All @@ -300,8 +285,21 @@ public void setup()
final PlannerConfig plannerConfig = new PlannerConfig();

final SegmentGenerator segmentGenerator = closer.register(new SegmentGenerator());
log.info("Starting benchmark setup using cacheDir[%s], rows[%,d].", segmentGenerator.getCacheDir(), rowsPerSegment);
final QueryableIndex index = segmentGenerator.generate(dataSegment, schemaInfo, Granularities.NONE, rowsPerSegment);
log.info("Starting benchmark setup using cacheDir[%s], rows[%,d], schema[%s].", segmentGenerator.getCacheDir(), rowsPerSegment, schema);
final QueryableIndex index;
if ("auto".equals(schema)) {
index = segmentGenerator.generate(
dataSegment,
schemaInfo,
DimensionsSpec.builder().useSchemaDiscovery(true).build(),
TransformSpec.NONE,
IndexSpec.DEFAULT,
Granularities.NONE,
rowsPerSegment
);
} else {
index = segmentGenerator.generate(dataSegment, schemaInfo, Granularities.NONE, rowsPerSegment);
}

final QueryRunnerFactoryConglomerate conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(
closer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.druid.data.input.InputStats;
import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.data.input.SplitHintSpec;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.SplittableInputSource;
import org.apache.druid.guice.IndexingServiceInputSourceModule;
import org.apache.druid.java.util.common.CloseableIterators;
Expand Down Expand Up @@ -179,7 +180,10 @@ public boolean hasNext()
public InputRow next()
{
rowCount++;
return generator.nextRow();
return MapInputRowParser.parse(
inputRowSchema,
generator.nextRaw(inputRowSchema.getTimestampSpec().getTimestampColumn())
);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.InputRowSchema;
import org.apache.druid.data.input.InputSourceReader;
import org.apache.druid.data.input.InputSplit;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.guice.IndexingServiceInputSourceModule;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.DateTimes;
Expand Down Expand Up @@ -128,11 +132,20 @@ public void testReader() throws IOException
timestampIncrement
);

InputSourceReader reader = inputSource.fixedFormatReader(null, null);
InputRowSchema rowSchema = new InputRowSchema(
new TimestampSpec(null, null, null),
DimensionsSpec.builder().useSchemaDiscovery(true).build(),
null
);

InputSourceReader reader = inputSource.fixedFormatReader(
rowSchema,
null
);
CloseableIterator<InputRow> iterator = reader.read();

InputRow first = iterator.next();
InputRow generatorFirst = generator.nextRow();
InputRow generatorFirst = MapInputRowParser.parse(rowSchema, generator.nextRaw(rowSchema.getTimestampSpec().getTimestampColumn()));
Assert.assertEquals(generatorFirst, first);
Assert.assertTrue(iterator.hasNext());
int i;
Expand All @@ -157,7 +170,7 @@ public void testSplits()
);

Assert.assertEquals(2, inputSource.estimateNumSplits(null, null));
Assert.assertEquals(false, inputSource.needsFormat());
Assert.assertFalse(inputSource.needsFormat());
Assert.assertEquals(2, inputSource.createSplits(null, null).count());
Assert.assertEquals(
new Long(2048L),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,13 @@ public static Number computeNumber(@Nullable String value)
@Nullable
public static ExprEval<?> castForEqualityComparison(ExprEval<?> valueToCompare, ExpressionType typeToCompareWith)
{
if (valueToCompare.isArray() && !typeToCompareWith.isArray()) {
final Object[] array = valueToCompare.asArray();
// cannot cast array to scalar if array length is greater than 1
if (array != null && array.length > 1) {
return null;
}
}
ExprEval<?> cast = valueToCompare.castTo(typeToCompareWith);
if (ExpressionType.LONG.equals(typeToCompareWith) && valueToCompare.asDouble() != cast.asDouble()) {
// make sure the DOUBLE value when cast to LONG is the same before and after the cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3327,11 +3327,11 @@ public void validateArguments(List<Expr> args)
@Override
public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
ExpressionType type = ExpressionType.LONG;
ExpressionType type = null;
for (Expr arg : args) {
type = ExpressionTypeConversion.function(type, arg.getOutputType(inspector));
type = ExpressionTypeConversion.leastRestrictiveType(type, arg.getOutputType(inspector));
}
return ExpressionType.asArrayType(type);
return type == null ? null : ExpressionTypeFactory.getInstance().ofArray(type);
}
Comment on lines -3330 to 3335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this changed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was wrong in some cases and written before leastRestrictiveType existed, for example if you try to use the array constructor function on an array type input, you end up with

org.apache.druid.segment.column.Types$IncompatibleTypeException: Cannot implicitly cast [LONG] to [ARRAY<STRING>]

which doesn't happen after the change. Will add a test.


/**
Expand Down
Loading
Loading