Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Zetasql analysis fragment #187

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

package no.nrk.bigquery

import cats.syntax.all._
import cats.syntax.all.*
import cats.effect.Sync
import no.nrk.bigquery.syntax._
import no.nrk.bigquery.syntax.*
import com.google.zetasql.{
AnalyzerOptions,
ParseLocationRange,
Expand All @@ -29,8 +29,9 @@ import com.google.zetasql.parser.{ASTNodes, ParseTreeVisitor}
import com.google.zetasql.toolkit.{AnalysisException, AnalyzedStatement, ZetaSQLToolkitAnalyzer}

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._
import scala.collection.immutable
import scala.jdk.CollectionConverters.*
import scala.jdk.OptionConverters.*

class ZetaSql[F[_]](implicit F: Sync[F]) {
import ZetaSql._
Expand All @@ -49,36 +50,36 @@ class ZetaSql[F[_]](implicit F: Sync[F]) {

def parseAndBuildAnalysableFragment(
query: String,
allTables: List[BQTableLike[Any]],
allTables: immutable.Seq[BQTableLike[Any]],
toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow,
eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] = {

def evalFragments(
parsedTables: List[(BQTableId, ParseLocationRange)]
): BQSqlFrag = {
val asString = query
val found = allTables
.flatMap(table =>
parsedTables.flatMap { case (id, range) => if (eqv(table.tableId, id)) List(table -> range) else Nil })
.distinct
val (rest, aggregate) = found.foldLeft((asString, BQSqlFrag.Empty)) { case ((input, agg), (t, loc)) =>
val frag = agg ++ BQSqlFrag.Frag(input.substring(0, loc.start() - 1)) ++ toFragment(t)
val rest = input.substring(loc.end())
rest -> frag
val (rest, aggregate) = found.foldLeft((0, Vector.empty[BQSqlFrag])) { case ((offset, agg), (t, loc)) =>
val frag =
agg ++ List(BQSqlFrag.Frag(query.substring(offset, loc.start() - 1) + " "), toFragment(t))
loc.end() -> frag
}
val str = query.substring(rest)

aggregate ++ BQSqlFrag.Frag(rest)
BQSqlFrag.Combined(if (str.nonEmpty) aggregate :+ BQSqlFrag.Frag(str) else aggregate)
}

parseScript(BQSqlFrag.Frag(query))
.flatMap(F.fromEither)
.flatMap { script =>
val list = script.getStatementListNode.getStatementList
if (list.size() != 1) {
Sync[F].raiseError[List[(no.nrk.bigquery.BQTableId, com.google.zetasql.ParseLocationRange)]](
F.raiseError[List[(BQTableId, ParseLocationRange)]](
new IllegalArgumentException("Expects only one statement"))
} else
Sync[F].delay {
F.delay {
val buffer = new ListBuffer[(BQTableId, ParseLocationRange)]
list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor {
override def visit(node: ASTNodes.ASTTablePathExpression): Unit =
Expand Down Expand Up @@ -139,26 +140,7 @@ object ZetaSql {
catalog
}

def fromColumnNameAndType(name: String, typ: Type): BQField = {
val kind = typ.getKind match {
case TypeKind.TYPE_BOOL => BQField.Type.BOOL
case TypeKind.TYPE_DATE => BQField.Type.DATE
case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME
case TypeKind.TYPE_JSON => BQField.Type.JSON
case TypeKind.TYPE_BYTES => BQField.Type.BYTES
case TypeKind.TYPE_STRING => BQField.Type.STRING
case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC
case TypeKind.TYPE_INT64 => BQField.Type.INT64
case TypeKind.TYPE_INT32 => BQField.Type.INT64
case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64
case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64
case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP
case TypeKind.TYPE_TIME => BQField.Type.TIME
case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY
case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL
case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ")
}

def fromColumnNameAndType(name: String, typ: Type): BQField =
if (typ.isArray) {
val elem = fromColumnNameAndType(name, typ.asArray().getElementType)
elem.copy(mode = BQField.Mode.REPEATED)
Expand All @@ -170,8 +152,27 @@ object ZetaSql {
.asScala
.map(subField => fromColumnNameAndType(subField.getName, subField.getType))
.toList: _*)
} else BQField(name, kind, BQField.Mode.NULLABLE)
}
} else {
val kind = typ.getKind match {
case TypeKind.TYPE_BOOL => BQField.Type.BOOL
case TypeKind.TYPE_DATE => BQField.Type.DATE
case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME
case TypeKind.TYPE_JSON => BQField.Type.JSON
case TypeKind.TYPE_BYTES => BQField.Type.BYTES
case TypeKind.TYPE_STRING => BQField.Type.STRING
case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC
case TypeKind.TYPE_INT64 => BQField.Type.INT64
case TypeKind.TYPE_INT32 => BQField.Type.INT64
case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64
case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64
case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP
case TypeKind.TYPE_TIME => BQField.Type.TIME
case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY
case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL
case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ")
}
BQField(name, kind, BQField.Mode.NULLABLE)
}

def toSimpleTable(table: BQTableLike[Any]): SimpleTable = {
def toType(field: BQField): Type = {
Expand Down
35 changes: 35 additions & 0 deletions zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

package no.nrk.bigquery

import cats.syntax.all._
import cats.effect.IO
import no.nrk.bigquery.syntax._
import com.google.zetasql.toolkit.AnalysisException
Expand All @@ -27,6 +28,15 @@ class ZetaTest extends munit.CatsEffectSuite {
BQPartitionType.DatePartitioned(Ident("partitionDate"))
)

// Second table fixture, `com-example.example.test2`, used by the multi-table test.
// Its schema has two REQUIRED fields (`partitionDate: DATE`, `name: STRING`) and it is
// date-partitioned on `partitionDate`.
private val table2 = BQTableDef.Table(
BQTableId.unsafeOf(BQDataset.unsafeOf(ProjectId("com-example"), "example"), "test2"),
BQSchema.of(
BQField("partitionDate", BQField.Type.DATE, BQField.Mode.REQUIRED),
BQField("name", BQField.Type.STRING, BQField.Mode.REQUIRED)
),
BQPartitionType.DatePartitioned(Ident("partitionDate"))
)

// Smoke test: analysing a constant query that references no tables must yield a Right.
test("parses select 1") {
zetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true)
}
Expand Down Expand Up @@ -94,6 +104,31 @@ class ZetaTest extends munit.CatsEffectSuite {
.assertEquals(expected)
}

// Checks that a raw SQL string referencing two known tables is turned into a fragment
// that (a) can be analysed for its output fields and (b) reports both tables as referenced.
test("parse then build analysis multiple tables") {
// Two CTEs: `data` joins `test` and `test2` on partitionDate; `grouped` aggregates it.
val query =
"""|with data as (
| select t1.partitionDate, t1.a, t1.b, t2.name
| from `com-example.example.test` t1
| JOIN `com-example.example.test2` t2 using (partitionDate)
|),
| grouped as (
| select partitionDate, a, b, COUNTIF(name = "foo") as countFoo from data
| group by 1, 2, 3
| )
|select * from grouped
|""".stripMargin

// Expected output schema: the fields of `table` minus its last two, plus the computed
// `countFoo` column, all made nullable and stripped of descriptions.
// NOTE(review): presumably the retained fields are partitionDate, a, b — confirm against the `table` fixture.
val expected =
(table.schema.fields.dropRight(2) ++ List(BQField("countFoo", BQField.Type.INT64, BQField.Mode.NULLABLE)))
.map(_.recursivelyNullable.withoutDescription)

val analysis = zetaSql
.parseAndBuildAnalysableFragment(query, List(table, table2))
// Pair the analysed fields with the table ids the fragment references, then compare both at once.
analysis
.flatMap(fragment => zetaSql.queryFields(fragment).tupleRight(fragment.allReferencedTables.map(_.tableId)))
.assertEquals(expected -> List(table, table2).map(_.tableId))
}

override def munitTestTransforms: List[TestTransform] =
super.munitTestTransforms ++ List(
new TestTransform(
Expand Down