Skip to content

Commit

Permalink
feat: add support for parsing tvf with zetasql
Browse files Browse the repository at this point in the history
  • Loading branch information
ingarabr committed Nov 20, 2023
1 parent c9c7f56 commit 6c17825
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 53 deletions.
165 changes: 114 additions & 51 deletions zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ package no.nrk.bigquery

import cats.syntax.all.*
import cats.effect.Sync
import com.google.common.collect.ImmutableList
import com.google.zetasql.ZetaSQLFunctions.SignatureArgumentKind
import no.nrk.bigquery.syntax.*
import com.google.zetasql.{
AnalyzerOptions,
FunctionArgumentType,
FunctionSignature,
ParseLocationRange,
Parser,
SimpleColumn,
SimpleTable,
SqlException,
StructType,
TVFRelation,
Type,
TypeFactory
}
Expand All @@ -26,7 +31,9 @@ import com.google.zetasql.resolvedast.ResolvedNodes
import com.google.zetasql.toolkit.catalog.basic.BasicCatalogWrapper
import com.google.zetasql.toolkit.options.BigQueryLanguageOptions
import com.google.zetasql.parser.{ASTNodes, ParseTreeVisitor}
import com.google.zetasql.toolkit.catalog.TVFInfo
import com.google.zetasql.toolkit.{AnalysisException, AnalyzedStatement, ZetaSQLToolkitAnalyzer}
import no.nrk.bigquery.TVF.TVFId

import scala.collection.mutable.ListBuffer
import scala.collection.immutable
Expand All @@ -48,40 +55,61 @@ class ZetaSql[F[_]](implicit F: Sync[F]) {
}
}

@deprecated("for binary compatiblity, remove next version", "0.10.2")
private[bigquery] def parseAndBuildAnalysableFragment(
query: String,
allTables: List[BQTableLike[Any]],
toFragment: BQTableLike[Any] => BQSqlFrag,
eqv: (BQTableId, BQTableId) => Boolean
): F[BQSqlFrag] =
parseAndBuildAnalysableFragmentImpl(query, allTables, toFragment, eqv)

def parseAndBuildAnalysableFragment(
query: String,
allTables: immutable.Seq[BQTableLike[Any]],
toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow,
eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] =
parseAndBuildAnalysableFragmentImpl(query, allTables, toFragment, eqv)

private def parseAndBuildAnalysableFragmentImpl(
query: String,
allTables: immutable.Seq[BQTableLike[Any]],
toFragment: BQTableLike[Any] => BQSqlFrag,
eqv: (BQTableId, BQTableId) => Boolean): F[BQSqlFrag] = {
allTVF: immutable.Seq[TVF[Any, _]],
toTableFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow,
tableIdEqv: (BQTableId, BQTableId) => Boolean = _ == _,
tvfIdEqv: (TVFId, TVFId) => Boolean = _ == _
): F[BQSqlFrag] = {
// A reference found while parsing `query` that may be swapped for a BQSqlFrag.
// Every variant carries `loc`, the source range of the reference in the query text.
sealed trait FragLoc {
def loc: ParseLocationRange
}
object FragLoc {
// A table path identifier that parsed as a BQTableId; `loc` is the range of that identifier.
case class Table(id: BQTableId, loc: ParseLocationRange) extends FragLoc
// A table-valued function call: `argsLoc` holds the source range of each call argument and
// `loc` the range of the whole call node. Referred to as `FragLoc.TVF` to keep it apart from
// the `no.nrk.bigquery.TVF` type used for `allTVF`.
case class TVF(id: TVFId, argsLoc: List[ParseLocationRange], loc: ParseLocationRange) extends FragLoc
}

def evalFragments(
parsedTables: List[(BQTableId, ParseLocationRange)]
parsedTables: List[FragLoc]
): BQSqlFrag = {
val found = allTables
.flatMap(table =>
parsedTables.flatMap { case (id, range) => if (eqv(table.tableId, id)) List(table -> range) else Nil })
.distinct
val (rest, aggregate) = found.foldLeft((0, Vector.empty[BQSqlFrag])) { case ((offset, agg), (t, loc)) =>
val frag =
agg ++ List(BQSqlFrag.Frag(query.substring(offset, loc.start() - 1) + " "), toFragment(t))
loc.end() -> frag
}
val (rest, aggregate) = parsedTables
.filter {
case FragLoc.Table(id, _) => allTables.exists(t => tableIdEqv(id, t.tableId))
case FragLoc.TVF(id, _, _) => allTVF.exists(t => tvfIdEqv(id, t.name))
}
.sortBy(_.loc.start())
.foldLeft((0, Vector.empty[BQSqlFrag])) { case ((offset, agg), fragLoc) =>
val prefix = BQSqlFrag.Frag(query.substring(offset, fragLoc.loc.start() - 1) + " ")
val frag = fragLoc match {
case FragLoc.Table(id, _) =>
agg ++ allTables
.find(t => tableIdEqv(id, t.tableId))
.map(t => List(prefix, toTableFragment(t)))
.getOrElse(Nil)
case FragLoc.TVF(id, argsLoc, _) =>
val tvfFrag = allTVF
.find(t => tvfIdEqv(id, t.name))
.map(tvf =>
List(
prefix,
BQSqlFrag.TableRef(BQAppliedTableValuedFunction[Any](
tvf.name,
tvf.partitionType,
tvf.params.unsized.toList,
tvf.query,
tvf.schema,
tvf.description,
argsLoc.map(l => BQSqlFrag(query.substring(l.start(), l.end())))
))
))
.getOrElse(Nil)
agg ++ tvfFrag

}
fragLoc.loc.end() -> frag
}
val str = query.substring(rest)

BQSqlFrag.Combined(if (str.nonEmpty) aggregate :+ BQSqlFrag.Frag(str) else aggregate)
Expand All @@ -92,19 +120,33 @@ class ZetaSql[F[_]](implicit F: Sync[F]) {
.flatMap { script =>
val list = script.getStatementListNode.getStatementList
if (list.size() != 1) {
F.raiseError[List[(BQTableId, ParseLocationRange)]](
new IllegalArgumentException("Expects only one statement"))
F.raiseError[List[FragLoc]](new IllegalArgumentException("Expects only one statement"))
} else
F.delay {
val buffer = new ListBuffer[(BQTableId, ParseLocationRange)]
list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor {
override def visit(node: ASTNodes.ASTTablePathExpression): Unit =
node.getPathExpr.getNames.forEach(ident =>
BQTableId
.fromString(ident.getIdString)
.toOption
.foreach(id => buffer += (id -> ident.getParseLocationRange)))
}))
val buffer = new ListBuffer[FragLoc]
list.asScala.headOption.foreach { rootNode =>
rootNode.accept(new ParseTreeVisitor {
override def visit(node: ASTNodes.ASTTablePathExpression): Unit =
node.getPathExpr.getNames.forEach(ident =>
BQTableId
.fromString(ident.getIdString)
.toOption
.foreach(id => buffer += FragLoc.Table(id, ident.getParseLocationRange)))

override def visit(node: ASTNodes.ASTTVF): Unit =
node.getName.getNames.forEach { ident =>
val argLoc = new ListBuffer[ParseLocationRange]
node.getArgumentEntries.asScala.foreach(_.accept(new ParseTreeVisitor {
override def visit(node: ASTNodes.ASTTVFArgument): Unit =
argLoc += node.getParseLocationRange
}))
BQTableId
.fromString(ident.getIdString)
.toOption
.foreach(id => buffer += FragLoc.TVF(TVFId(id), argLoc.toList, node.getParseLocationRange))
}
})
}
buffer.toList
}
}
Expand Down Expand Up @@ -162,8 +204,12 @@ object ZetaSql {
toSimpleTable(tableRef.tableId, None),
CreateMode.CREATE_IF_NOT_EXISTS,
CreateScope.CREATE_DEFAULT_SCOPE)
case BQAppliedTableValuedFunction(_, _) =>
// todo: add support for TVF
case atvf: BQAppliedTableValuedFunction[Any] =>
catalog.register(
toTableValuedFunction(atvf),
CreateMode.CREATE_IF_NOT_EXISTS,
CreateScope.CREATE_DEFAULT_SCOPE
)
}

catalog
Expand Down Expand Up @@ -203,10 +249,9 @@ object ZetaSql {
BQField(name, kind, BQField.Mode.NULLABLE)
}

private def toType(field: BQField): Type = {
val isArray = field.mode == BQField.Mode.REPEATED

val elemType = field.tpe match {
private def toType(bqType: BQType): Type = {
val isArray = bqType.mode == BQField.Mode.REPEATED
val elemType = bqType.tpe match {
case BQField.Type.BOOL => TypeFactory.createSimpleType(TypeKind.TYPE_BOOL)
case BQField.Type.INT64 => TypeFactory.createSimpleType(TypeKind.TYPE_INT64)
case BQField.Type.FLOAT64 => TypeFactory.createSimpleType(TypeKind.TYPE_FLOAT)
Expand All @@ -216,12 +261,9 @@ object ZetaSql {
case BQField.Type.BYTES => TypeFactory.createSimpleType(TypeKind.TYPE_BYTES)
case BQField.Type.STRUCT =>
TypeFactory.createStructType(
field.subFields
.map(sub => new StructType.StructField(sub.name, toType(sub)))
.asJavaCollection
bqType.subFields.map(f => new StructType.StructField(f._1, toType(f._2))).asJavaCollection
)
case BQField.Type.ARRAY =>
TypeFactory.createArrayType(toType(field.subFields.head))
case BQField.Type.ARRAY => TypeFactory.createArrayType(toType(bqType.subFields.head._2))
case BQField.Type.TIMESTAMP => TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)
case BQField.Type.DATE => TypeFactory.createSimpleType(TypeKind.TYPE_DATE)
case BQField.Type.TIME => TypeFactory.createSimpleType(TypeKind.TYPE_TIME)
Expand All @@ -233,10 +275,31 @@ object ZetaSql {
if (isArray) TypeFactory.createArrayType(elemType) else elemType
}

/** Builds the ZetaSQL catalog entry ([[TVFInfo]]) for an applied table-valued function.
  *
  * The signature returns a relation and takes one argument per entry in `atvf.params`; a
  * parameter without a declared type is registered as `TYPE_UNKNOWN`. The output relation has
  * one column per field of `atvf.schema`, and the function is registered under the
  * single-element name path `atvf.name.asString`.
  */
def toTableValuedFunction(atvf: BQAppliedTableValuedFunction[Any]): TVFInfo = {
  // One ZetaSQL argument type per declared parameter, in declaration order.
  val argumentTypes = atvf.params.map { param =>
    val zetaType: Type =
      param.maybeType
        .map(declared => toType(declared))
        .getOrElse(TypeFactory.createSimpleType(TypeKind.TYPE_UNKNOWN))
    new FunctionArgumentType(zetaType)
  }

  val signature = new FunctionSignature(
    new FunctionArgumentType(SignatureArgumentKind.ARG_TYPE_RELATION),
    new java.util.ArrayList[FunctionArgumentType](argumentTypes.asJavaCollection),
    -1
  )

  // Output columns mirror the fields of the declared schema.
  val outputColumns = atvf.schema.fields.map { field =>
    TVFRelation.Column.create(field.name, toType(BQType.fromField(field)))
  }
  val outputSchema =
    TVFRelation.createColumnBased(new java.util.ArrayList[TVFRelation.Column](outputColumns.asJavaCollection))

  TVFInfo
    .newBuilder()
    .setSignature(signature)
    .setOutputSchema(outputSchema)
    .setNamePath(ImmutableList.of(atvf.name.asString))
    .build()
}

def toSimpleTable(tableId: BQTableId, schema: => Option[BQSchema]): SimpleTable = {

def toSimpleField(field: BQField) =
new SimpleColumn(tableId.tableName, field.name, toType(field), false, true)
new SimpleColumn(tableId.tableName, field.name, toType(BQType.fromField(field)), false, true)

val simple = schema match {
case None =>
Expand Down
38 changes: 36 additions & 2 deletions zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ class ZetaTest extends munit.CatsEffectSuite {
),
BQPartitionType.DatePartitioned(Ident("partitionDate"))
)
// Table-valued function fixture `com-example.example.tvftest`: not partitioned, one DATE
// parameter `d`, body selecting `a` from `table` where partitionDate = d, and a declared output
// schema of a single REQUIRED STRING field `a`.
private val tvf = TVF(
TVF.TVFId(BQDataset.unsafeOf(ProjectId("com-example"), "example"), ident"tvftest"),
BQPartitionType.NotPartitioned,
BQRoutine.Params(BQRoutine.Param("d", BQType.DATE)),
bqfr"select a from ${table.unpartitioned} where partitionDate = d",
BQSchema.of(BQField("a", BQField.Type.STRING, BQField.Mode.REQUIRED))
)

test("parses select 1") {
zetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true)
Expand Down Expand Up @@ -99,7 +106,7 @@ class ZetaTest extends munit.CatsEffectSuite {
.map(_.recursivelyNullable.withoutDescription)

zetaSql
.parseAndBuildAnalysableFragment(query, List(table))
.parseAndBuildAnalysableFragment(query, List(table), Nil)
.flatMap(zetaSql.queryFields)
.assertEquals(expected)
}
Expand All @@ -123,12 +130,39 @@ class ZetaTest extends munit.CatsEffectSuite {
.map(_.recursivelyNullable.withoutDescription)

val analysis = zetaSql
.parseAndBuildAnalysableFragment(query, List(table, table2))
.parseAndBuildAnalysableFragment(query, List(table, table2), Nil)
analysis
.flatMap(fragment => zetaSql.queryFields(fragment).tupleRight(fragment.allReferencedTables.map(_.tableId)))
.assertEquals(expected -> List(table, table2).map(_.tableId))
}

// A TVF call in the FROM clause is replaced by the registered `tvf`: the analysed fields match
// the TVF schema and the TVF shows up as the only referenced table id.
test("parse then build analysis with tvf") {
  val sql = """select a from `com-example.example.tvftest`(current_date)"""

  val expectedFields = tvf.schema.fields.map(_.recursivelyNullable.withoutDescription)
  val expectedTableIds = List(BQTableId(tvf.name.dataset, tvf.name.name.value))

  zetaSql
    .parseAndBuildAnalysableFragment(sql, List(table), List(tvf))
    .flatMap { fragment =>
      val referencedIds = fragment.allReferencedTables.map(_.tableId)
      zetaSql.queryFields(fragment).map(fields => fields -> referencedIds)
    }
    .assertEquals(expectedFields -> expectedTableIds)
}

// Same as the previous TVF test but with a call-expression argument: the applied TVF found in
// the fragment must carry the argument text `current_date()` verbatim.
test("parse then build analysis with tvf 2") {
  val sql = """select a from `com-example.example.tvftest`(current_date())"""

  val expectedFields = tvf.schema.fields.map(_.recursivelyNullable.withoutDescription)
  val expectedApplications = List((tvf.name, List(bqfr"current_date()")))

  zetaSql
    .parseAndBuildAnalysableFragment(sql, List(table), List(tvf))
    .flatMap { fragment =>
      // Name and arguments of every applied TVF referenced by the fragment.
      val applications = fragment.collect {
        case BQSqlFrag.TableRef(atvf: BQAppliedTableValuedFunction[Any]) => atvf.name -> atvf.args
      }
      zetaSql.queryFields(fragment).map(fields => fields -> applications)
    }
    .assertEquals(expectedFields -> expectedApplications)
}

override def munitTestTransforms: List[TestTransform] =
super.munitTestTransforms ++ List(
new TestTransform(
Expand Down

0 comments on commit 6c17825

Please sign in to comment.