Skip to content

Commit

Permalink
feat: add support for inline struct creation (#4120)
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Dec 19, 2019
1 parent 6c6695c commit 6e558da
Show file tree
Hide file tree
Showing 22 changed files with 602 additions and 31 deletions.
7 changes: 7 additions & 0 deletions docs/developer-guide/syntax-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ encapsulate a street address and a postal code:
orderId BIGINT,
address STRUCT<street VARCHAR, zip INTEGER>) WITH (...);
You can create a struct in a query by specifying the names of the columns
and expressions that construct the values, separated by ``,`` and wrapped with
curly braces. For example: ``SELECT STRUCT(name := col0, ageInDogYears := col1*7) AS dogs FROM animals``
creates a schema ``col0 STRUCT<name VARCHAR, ageInDogYears INTEGER>``, assuming ``col0`` was a string and
``col1`` was an integer.

Access the fields in a ``STRUCT`` by using the dereference operator (``->``):

.. code:: sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
package io.confluent.ksql.engine.rewrite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field;
import io.confluent.ksql.execution.expression.tree.DecimalLiteral;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
Expand Down Expand Up @@ -185,6 +188,15 @@ public Expression visitSubscriptExpression(
return new SubscriptExpression(node.getLocation(), base, index);
}

@Override
public Expression visitStructExpression(CreateStructExpression node, C context) {
final Builder<Field> fields = ImmutableList.builder();
for (Field field : node.getFields()) {
fields.add(new Field(field.getName(), rewriter.apply(field.getValue(), context)));
}
return new CreateStructExpression(node.getLocation(), fields.build());
}

@Override
public Expression visitComparisonExpression(
final ComparisonExpression node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field;
import io.confluent.ksql.execution.expression.tree.DecimalLiteral;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
Expand All @@ -48,9 +51,6 @@
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression;
import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
Expand All @@ -61,11 +61,13 @@
import io.confluent.ksql.execution.expression.tree.WhenClause;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.KsqlParserTestUtil;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.SelectItem;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType;
import io.confluent.ksql.util.MetaStoreFixture;
import java.util.List;
Expand Down Expand Up @@ -530,6 +532,22 @@ public void shouldRewriteSubscriptExpression() {
assertThat(rewritten, equalTo(new SubscriptExpression(parsed.getLocation(), expr1, expr2)));
}

@Test
public void shouldRewriteStructExpression() {
// Given:
final CreateStructExpression parsed = parseExpression("STRUCT(FOO := 'foo', BAR := col4[1])");
final Expression fooVal = parsed.getFields().stream().filter(f -> f.getName().equals("FOO")).findFirst().get().getValue();
final Expression barVal = parsed.getFields().stream().filter(f -> f.getName().equals("BAR")).findFirst().get().getValue();
when(processor.apply(fooVal, context)).thenReturn(expr1);
when(processor.apply(barVal, context)).thenReturn(expr2);

// When:
final Expression rewritten = expressionRewriter.rewrite(parsed, context);

// Then:
assertThat(rewritten, equalTo(new CreateStructExpression(ImmutableList.of(new Field("FOO", expr1), new Field("BAR", expr2)))));
}

@Test
public void shouldRewriteSubscriptExpressionUsingPlugin() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.confluent.ksql.execution.codegen;

import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
Expand All @@ -40,6 +41,8 @@
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.kafka.connect.data.Schema;
import org.codehaus.commons.compiler.CompileException;
import org.codehaus.commons.compiler.CompilerFactoryFactory;
import org.codehaus.commons.compiler.IExpressionEvaluator;
Expand Down Expand Up @@ -175,6 +178,17 @@ public Void visitSubscriptExpression(SubscriptExpression node, Void context) {
return null;
}

@Override
public Void visitStructExpression(CreateStructExpression exp, @Nullable Void context) {
exp.getFields().forEach(val -> process(val.getValue(), context));
final Schema schema = SchemaConverters
.sqlToConnectConverter()
.toConnectSchema(expressionTypeManager.getExpressionSqlType(exp));

spec.addStructSchema(exp, schema);
return null;
}

@Override
public Void visitColumnReference(ColumnReferenceExp node, Void context) {
addRequiredColumn(node.getReference());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.expression.formatter.ExpressionFormatter;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.ColumnRef;
Expand All @@ -30,21 +32,26 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.kafka.connect.data.Schema;

@Immutable
public final class CodeGenSpec {

private final ImmutableList<ArgumentSpec> arguments;
private final ImmutableMap<ColumnRef, String> columnToCodeName;
private final ImmutableListMultimap<FunctionName, String> functionToCodeName;
private final ImmutableMap<CreateStructExpression, String> structToCodeName;

private CodeGenSpec(
ImmutableList<ArgumentSpec> arguments, ImmutableMap<ColumnRef, String> columnToCodeName,
ImmutableListMultimap<FunctionName, String> functionToCodeName
ImmutableList<ArgumentSpec> arguments,
ImmutableMap<ColumnRef, String> columnToCodeName,
ImmutableListMultimap<FunctionName, String> functionToCodeName,
ImmutableMap<CreateStructExpression, String> structToCodeName
) {
this.arguments = arguments;
this.columnToCodeName = columnToCodeName;
this.functionToCodeName = functionToCodeName;
this.structToCodeName = structToCodeName;
}

public String[] argumentNames() {
Expand Down Expand Up @@ -77,14 +84,27 @@ public void resolve(GenericRow row, Object[] parameters) {
}
}

public String getStructSchemaName(CreateStructExpression createStructExpression) {
final String schemaName = structToCodeName.get(createStructExpression);
if (schemaName == null) {
throw new KsqlException(
"Cannot get name for " + ExpressionFormatter.formatExpression(createStructExpression)
);
}
return schemaName;
}

static class Builder {

private final ImmutableList.Builder<ArgumentSpec> argumentBuilder = ImmutableList.builder();
private final Map<ColumnRef, String> columnRefToName = new HashMap<>();
private final ImmutableListMultimap.Builder<FunctionName, String> functionNameBuilder =
ImmutableListMultimap.builder();
private final ImmutableMap.Builder<CreateStructExpression, String> structToSchemaName =
ImmutableMap.builder();

private int argumentCount = 0;
private int structSchemaCount = 0;

void addParameter(
final ColumnRef columnRef,
Expand All @@ -102,11 +122,18 @@ void addFunction(FunctionName functionName, Kudf function) {
argumentBuilder.add(new FunctionArgumentSpec(codeName, function.getClass(), function));
}

void addStructSchema(CreateStructExpression struct, Schema schema) {
final String structSchemaName = CodeGenUtil.schemaName(structSchemaCount++);
structToSchemaName.put(struct, structSchemaName);
argumentBuilder.add(new SchemaArgumentSpec(structSchemaName, schema));
}

CodeGenSpec build() {
return new CodeGenSpec(
argumentBuilder.build(),
ImmutableMap.copyOf(columnRefToName),
functionNameBuilder.build()
functionNameBuilder.build(),
structToSchemaName.build()
);
}
}
Expand Down Expand Up @@ -208,4 +235,32 @@ public String toString() {
+ '}';
}
}

@Immutable
public static final class SchemaArgumentSpec extends BaseArgumentSpec {

private final Schema schema;

SchemaArgumentSpec(
String name,
Schema schema
) {
super(name, Schema.class);
this.schema = requireNonNull(schema, "schema");
}

@Override
public Object resolve(GenericRow value) {
return schema;
}

@Override
public String toString() {
return "StructSchemaArgumentSpec{"
+ "name='" + name() + '\''
+ ", type=" + type()
+ ", schema=" + schema
+ '}';
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
public final class CodeGenUtil {

private static final String PARAM_NAME_PREFIX = "var";
private static final String SCHEMA_NAME_PREFIX = "schema";

private CodeGenUtil() {
}
Expand All @@ -28,6 +29,10 @@ public static String paramName(int index) {
return PARAM_NAME_PREFIX + index;
}

public static String schemaName(int index) {
return SCHEMA_NAME_PREFIX + index;
}

public static String functionName(FunctionName fun, int index) {
return fun.name() + "_" + index;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field;
import io.confluent.ksql.execution.expression.tree.DecimalLiteral;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
Expand Down Expand Up @@ -87,6 +89,8 @@
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.text.StrSubstitutor;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;

public class SqlToJavaVisitor {

Expand All @@ -104,7 +108,9 @@ public class SqlToJavaVisitor {
DecimalUtil.class.getCanonicalName(),
BigDecimal.class.getCanonicalName(),
MathContext.class.getCanonicalName(),
RoundingMode.class.getCanonicalName()
RoundingMode.class.getCanonicalName(),
SchemaBuilder.class.getCanonicalName(),
Struct.class.getCanonicalName()
);

private static final Map<Operator, String> DECIMAL_OPERATOR_NAME = ImmutableMap
Expand Down Expand Up @@ -133,6 +139,7 @@ public class SqlToJavaVisitor {
private final ExpressionTypeManager expressionTypeManager;
private final Function<FunctionName, String> funNameToCodeName;
private final Function<ColumnRef, String> colRefToCodeName;
private final Function<CreateStructExpression, String> structToCodeName;

public static SqlToJavaVisitor of(
LogicalSchema schema, FunctionRegistry functionRegistry, CodeGenSpec spec
Expand All @@ -145,21 +152,24 @@ public static SqlToJavaVisitor of(
name -> {
int index = nameCounts.add(name, 1);
return spec.getUniqueNameForFunction(name, index);
}
);
},
spec::getStructSchemaName);
}

@VisibleForTesting
SqlToJavaVisitor(
LogicalSchema schema, FunctionRegistry functionRegistry,
Function<ColumnRef, String> colRefToCodeName, Function<FunctionName, String> funNameToCodeName
Function<ColumnRef, String> colRefToCodeName,
Function<FunctionName, String> funNameToCodeName,
Function<CreateStructExpression, String> structToCodeName
) {
this.expressionTypeManager =
new ExpressionTypeManager(schema, functionRegistry);
this.schema = Objects.requireNonNull(schema, "schema");
this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
this.colRefToCodeName = Objects.requireNonNull(colRefToCodeName, "colRefToCodeName");
this.funNameToCodeName = Objects.requireNonNull(funNameToCodeName, "funNameToCodeName");
this.structToCodeName = Objects.requireNonNull(structToCodeName, "structToCodeName");
}

public String process(Expression expression) {
Expand Down Expand Up @@ -712,6 +722,25 @@ public Pair<String, SqlType> visitSubscriptExpression(SubscriptExpression node,
}
}

@Override
public Pair<String, SqlType> visitStructExpression(CreateStructExpression node, Void context) {
final String schemaName = structToCodeName.apply(node);
final StringBuilder struct = new StringBuilder("new Struct(").append(schemaName).append(")");
for (Field field : node.getFields()) {
struct.append(".put(")
.append('"')
.append(field.getName())
.append('"')
.append(",")
.append(process(field.getValue(), context).getLeft())
.append(")");
}
return new Pair<>(
"((Struct)" + struct.toString() + ")",
expressionTypeManager.getExpressionSqlType(node)
);
}

@Override
public Pair<String, SqlType> visitBetweenPredicate(BetweenPredicate node, Void context) {
Pair<String, SqlType> value = process(node.getValue(), context);
Expand Down
Loading

0 comments on commit 6e558da

Please sign in to comment.