diff --git a/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala index 52ccd552..91857fbf 100644 --- a/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala +++ b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala @@ -6,9 +6,9 @@ package no.nrk.bigquery -import cats.syntax.all._ +import cats.syntax.all.* import cats.effect.Sync -import no.nrk.bigquery.syntax._ +import no.nrk.bigquery.syntax.* import com.google.zetasql.{ AnalyzerOptions, ParseLocationRange, @@ -29,8 +29,9 @@ import com.google.zetasql.parser.{ASTNodes, ParseTreeVisitor} import com.google.zetasql.toolkit.{AnalysisException, AnalyzedStatement, ZetaSQLToolkitAnalyzer} import scala.collection.mutable.ListBuffer -import scala.jdk.CollectionConverters._ -import scala.jdk.OptionConverters._ +import scala.collection.immutable +import scala.jdk.CollectionConverters.* +import scala.jdk.OptionConverters.* class ZetaSql[F[_]](implicit F: Sync[F]) { import ZetaSql._ @@ -49,25 +50,25 @@ class ZetaSql[F[_]](implicit F: Sync[F]) { def parseAndBuildAnalysableFragment( query: String, - allTables: List[BQTableLike[Any]], + allTables: immutable.Seq[BQTableLike[Any]], toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow, eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] = { def evalFragments( parsedTables: List[(BQTableId, ParseLocationRange)] ): BQSqlFrag = { - val asString = query val found = allTables .flatMap(table => parsedTables.flatMap { case (id, range) => if (eqv(table.tableId, id)) List(table -> range) else Nil }) .distinct - val (rest, aggregate) = found.foldLeft((asString, BQSqlFrag.Empty)) { case ((input, agg), (t, loc)) => - val frag = agg ++ BQSqlFrag.Frag(input.substring(0, loc.start() - 1)) ++ toFragment(t) - val rest = input.substring(loc.end()) - rest -> frag + val (rest, aggregate) = found.foldLeft((0, Vector.empty[BQSqlFrag])) { case ((offset, agg), (t, loc)) => + val frag = + agg ++ List(BQSqlFrag.Frag(query.substring(offset, loc.start() - 1) + " "), toFragment(t)) + loc.end() -> frag } + val str = query.substring(rest) - aggregate ++ BQSqlFrag.Frag(rest) + BQSqlFrag.Combined(if (str.nonEmpty) aggregate :+ BQSqlFrag.Frag(str) else aggregate) } parseScript(BQSqlFrag.Frag(query)) @@ -75,10 +76,10 @@ class ZetaSql[F[_]](implicit F: Sync[F]) { .flatMap { script => val list = script.getStatementListNode.getStatementList if (list.size() != 1) { - Sync[F].raiseError[List[(no.nrk.bigquery.BQTableId, com.google.zetasql.ParseLocationRange)]]( + F.raiseError[List[(BQTableId, ParseLocationRange)]]( new IllegalArgumentException("Expects only one statement")) } else - Sync[F].delay { + F.delay { val buffer = new ListBuffer[(BQTableId, ParseLocationRange)] list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor { override def visit(node: ASTNodes.ASTTablePathExpression): Unit = @@ -139,26 +140,7 @@ object ZetaSql { catalog } - def fromColumnNameAndType(name: String, typ: Type): BQField = { - val kind = typ.getKind match { - case TypeKind.TYPE_BOOL => BQField.Type.BOOL - case TypeKind.TYPE_DATE => BQField.Type.DATE - case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME - case TypeKind.TYPE_JSON => BQField.Type.JSON - case TypeKind.TYPE_BYTES => BQField.Type.BYTES - case TypeKind.TYPE_STRING => BQField.Type.STRING - case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC - case TypeKind.TYPE_INT64 => BQField.Type.INT64 - case TypeKind.TYPE_INT32 => BQField.Type.INT64 - case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64 - case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64 - case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP - case TypeKind.TYPE_TIME => BQField.Type.TIME - case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY - case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL - case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ") - } - + def fromColumnNameAndType(name: String, typ: Type): BQField = if (typ.isArray) { val elem = fromColumnNameAndType(name, typ.asArray().getElementType) elem.copy(mode = BQField.Mode.REPEATED) @@ -170,8 +152,27 @@ object ZetaSql { .asScala .map(subField => fromColumnNameAndType(subField.getName, subField.getType)) .toList: _*) - } else BQField(name, kind, BQField.Mode.NULLABLE) - } + } else { + val kind = typ.getKind match { + case TypeKind.TYPE_BOOL => BQField.Type.BOOL + case TypeKind.TYPE_DATE => BQField.Type.DATE + case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME + case TypeKind.TYPE_JSON => BQField.Type.JSON + case TypeKind.TYPE_BYTES => BQField.Type.BYTES + case TypeKind.TYPE_STRING => BQField.Type.STRING + case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC + case TypeKind.TYPE_INT64 => BQField.Type.INT64 + case TypeKind.TYPE_INT32 => BQField.Type.INT64 + case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64 + case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64 + case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP + case TypeKind.TYPE_TIME => BQField.Type.TIME + case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY + case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL + case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ") + } + BQField(name, kind, BQField.Mode.NULLABLE) + } def toSimpleTable(table: BQTableLike[Any]): SimpleTable = { def toType(field: BQField): Type = { diff --git a/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala index b9257863..64cdb52d 100644 --- a/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala +++ b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala @@ -6,6 +6,7 @@ package no.nrk.bigquery +import cats.syntax.all._ import cats.effect.IO import no.nrk.bigquery.syntax._ import com.google.zetasql.toolkit.AnalysisException @@ -27,6 +28,15 @@ class ZetaTest extends munit.CatsEffectSuite { BQPartitionType.DatePartitioned(Ident("partitionDate")) ) + private val table2 = BQTableDef.Table( + BQTableId.unsafeOf(BQDataset.unsafeOf(ProjectId("com-example"), "example"), "test2"), + BQSchema.of( + BQField("partitionDate", BQField.Type.DATE, BQField.Mode.REQUIRED), + BQField("name", BQField.Type.STRING, BQField.Mode.REQUIRED) + ), + BQPartitionType.DatePartitioned(Ident("partitionDate")) + ) + test("parses select 1") { zetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true) } @@ -94,6 +104,31 @@ class ZetaTest extends munit.CatsEffectSuite { .assertEquals(expected) } + test("parse then build analysis multiple tables") { + val query = + """|with data as ( + | select t1.partitionDate, t1.a, t1.b, t2.name + | from `com-example.example.test` t1 + | JOIN `com-example.example.test2` t2 using (partitionDate) + |), + | grouped as ( + | select partitionDate, a, b, COUNTIF(name = "foo") as countFoo from data + | group by 1, 2, 3 + | ) + |select * from grouped + |""".stripMargin + + val expected = + (table.schema.fields.dropRight(2) ++ List(BQField("countFoo", BQField.Type.INT64, BQField.Mode.NULLABLE))) + .map(_.recursivelyNullable.withoutDescription) + + val analysis = zetaSql + .parseAndBuildAnalysableFragment(query, List(table, table2)) + analysis + .flatMap(fragment => zetaSql.queryFields(fragment).tupleRight(fragment.allReferencedTables.map(_.tableId))) + .assertEquals(expected -> List(table, table2).map(_.tableId)) + } + override def munitTestTransforms: List[TestTransform] = super.munitTestTransforms ++ List( new TestTransform(