Skip to content

Commit

Permalink
feat: model tvf as tableLike
Browse files Browse the repository at this point in the history
  • Loading branch information
ingarabr committed Nov 14, 2023
1 parent e5ea400 commit 48a6946
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 110 deletions.
11 changes: 8 additions & 3 deletions core/src/main/scala/no/nrk/bigquery/BQRoutine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ import no.nrk.bigquery.util.{Nat, Sized}

sealed trait BQRoutine[N <: Nat] {
def params: BQRoutine.Params[N]
def call(args: Sized[IndexedSeq[BQSqlFrag.Magnet], N]): BQSqlFrag.Call =
BQSqlFrag.Call(this, args.unsized.toList.map(_.frag))
def call(args: Sized[IndexedSeq[BQSqlFrag.Magnet], N]): BQSqlFrag
}

sealed trait BQPersistentRoutine[N <: Nat] extends BQRoutine[N] {
Expand Down Expand Up @@ -65,7 +64,10 @@ case class TVF[+P, N <: Nat](
query: BQSqlFrag,
schema: BQSchema,
description: Option[String] = None
) extends BQPersistentRoutine[N]
) extends BQPersistentRoutine[N] {
def call(args: Sized[IndexedSeq[BQSqlFrag.Magnet], N]): BQSqlFrag =
BQSqlFrag.TableRef(BQAppliedTableValuedFunction(this, args.map(_.frag)))
}

object TVF {
case class TVFId(dataset: BQDataset, name: Ident) extends BQPersistentRoutine.PersistentRoutineId {
Expand Down Expand Up @@ -98,6 +100,9 @@ sealed trait UDF[+A <: UDFId, N <: Nat] extends BQRoutine[N] {
def name: A
def params: BQRoutine.Params[N]
def returnType: Option[BQType]

def call(args: Sized[IndexedSeq[BQSqlFrag.Magnet], N]): BQSqlFrag =
BQSqlFrag.Call(this, args.unsized.toList.map(_.frag))
}

object UDF {
Expand Down
25 changes: 9 additions & 16 deletions core/src/main/scala/no/nrk/bigquery/BQSqlFrag.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ sealed trait BQSqlFrag {
this match {
case BQSqlFrag.Frag(string) =>
string
case BQSqlFrag.Call(routine, args) =>
val name = routine match {
case tvf: TVF[_, _] => tvf.name.asFragment
case udf: UDF[_, _] => udf.name.asFragment
}
args.mkFragment(bqfr"${name}(", bqfr", ", bqfr")").asString
case BQSqlFrag.Call(udf, args) =>
args.mkFragment(bqfr"${udf.name}(", bqfr", ", bqfr")").asString

case BQSqlFrag.Combined(values) =>
val partitions = values.collect {
Expand All @@ -66,7 +62,7 @@ sealed trait BQSqlFrag {
.PartitionRef(fill.tableDef.unpartitioned.assertPartition)
.asString

case BQSqlFrag.TableRef(table) => table.tableId.asFragment.asString
case BQSqlFrag.TableRef(table) => table.asFragment.asString

case BQSqlFrag.PartitionRef(partitionId) =>
partitionId match {
Expand Down Expand Up @@ -134,6 +130,8 @@ sealed trait BQSqlFrag {
partitionRef.wholeTable match {
case tableDef: BQTableDef.View[_] if expandAndExcludeViews =>
tableDef.query.collect(pf(Some(partitionRef))).flatten
case tvf: BQAppliedTableValuedFunction[_, _] if expandAndExcludeViews =>
tvf.tvf.query.collect(pf(Some(partitionRef))).flatten
case _ => List(partitionRef)
}

Expand Down Expand Up @@ -166,7 +164,7 @@ sealed trait BQSqlFrag {
.filterNot(pid => pid.wholeTable.isInstanceOf[BQTableDef.View[_]])

final def allReferencedUDFs: Seq[UDF[UDF.UDFId, _]] =
this.collect { case BQSqlFrag.Call(udf: UDF[_, _], _) => udf }.distinct
this.collect { case BQSqlFrag.Call(udf, _) => udf }.distinct

override def toString: String = asString
}
Expand All @@ -176,15 +174,10 @@ object BQSqlFrag {
def backticks(string: String): BQSqlFrag = Frag("`" + string + "`")

case class Frag(string: String) extends BQSqlFrag
case class Call(routine: BQRoutine[_], args: List[BQSqlFrag]) extends BQSqlFrag {
case class Call(udf: UDF[UDF.UDFId, _], args: List[BQSqlFrag]) extends BQSqlFrag {
require(
routine.params.length == args.length, {
val name = routine match {
case tvf: TVF[_, _] => tvf.name.asString
case udf: UDF[_, _] => udf.name.asString
}
show"Routine ${name}: Expected ${routine.params.length} arguments, got ${args.length}"
}
udf.params.length == args.length,
show"UDF ${udf.name}: Expected ${udf.params.length} arguments, got ${args.length}"
)
}
case class Combined(values: Seq[BQSqlFrag]) extends BQSqlFrag
Expand Down
26 changes: 26 additions & 0 deletions core/src/main/scala/no/nrk/bigquery/BQTableLike.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

package no.nrk.bigquery

import no.nrk.bigquery.syntax._
import cats.effect.Concurrent
import cats.syntax.all._
import no.nrk.bigquery.util.{Nat, Sized}

/** @tparam P
* partition specifier. typically [[java.time.LocalDate]] or [[scala.Unit]]
Expand All @@ -18,6 +20,7 @@ sealed trait BQTableLike[+P] {
def withTableType[PP](tpe: BQPartitionType[PP]): BQTableLike[PP]
def unpartitioned: BQTableLike[Unit]
def wholeTable: WholeTable[P] = WholeTable(this)
def asFragment: BQSqlFrag
}

object BQTableLike {
Expand Down Expand Up @@ -90,6 +93,28 @@ case class BQTableRef[+P](

override def withTableType[PP](tpe: BQPartitionType[PP]): BQTableRef[PP] =
BQTableRef(tableId, tpe)

def asFragment: BQSqlFrag = tableId.asFragment
}

case class BQAppliedTableValuedFunction[+P, N <: Nat](
tvf: TVF[P, N],
args: Sized[IndexedSeq[BQSqlFrag], N]
) extends BQTableLike[P] {
override def tableId: BQTableId =
BQTableId(tvf.name.dataset, tvf.name.name.value)

override def partitionType: BQPartitionType[P] = tvf.partitionType

override def unpartitioned: BQTableLike[Unit] =
withTableType(BQPartitionType.ignoredPartitioning(partitionType))

override def withTableType[PP](tpe: BQPartitionType[PP]): BQTableLike[PP] =
BQAppliedTableValuedFunction(
TVF(tvf.name, tpe, tvf.params, tvf.query, tvf.schema, tvf.description),
args
)
def asFragment: BQSqlFrag = tableId.asFragment ++ args.unsized.mkFragment("(", ", ", ")")
}

/** Our version of a description of what a BQ table/view should look like.
Expand All @@ -99,6 +124,7 @@ sealed trait BQTableDef[+P] extends BQTableLike[P] {
def description: Option[String]
def schema: BQSchema
def labels: TableLabels
def asFragment: BQSqlFrag = tableId.asFragment

labels.verify(tableId)
}
Expand Down
47 changes: 24 additions & 23 deletions core/src/main/scala/no/nrk/bigquery/internal/RoutineSyntax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,29 @@ object RoutineSyntax {
}

class RoutineOps0(routine: BQRoutine[_0]) {
def apply(): BQSqlFrag.Call = BQSqlFrag.Call(routine, List.empty)
def apply(): BQSqlFrag =
routine.call(RoutineSyntax.builder.empty)
}

class RoutineOps1(routine: BQRoutine[_1]) {
def apply(
m1: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1))
}

class RoutineOps2(routine: BQRoutine[_2]) {
def apply(
m1: BQSqlFrag.Magnet,
m2: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2))
}

class RoutineOps3(routine: BQRoutine[_3]) {
def apply(
m1: BQSqlFrag.Magnet,
m2: BQSqlFrag.Magnet,
m3: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3))
}

class RoutineOps4(routine: BQRoutine[_4]) {
Expand All @@ -71,7 +72,7 @@ class RoutineOps4(routine: BQRoutine[_4]) {
m2: BQSqlFrag.Magnet,
m3: BQSqlFrag.Magnet,
m4: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4))
}

class RoutineOps5(routine: BQRoutine[_5]) {
Expand All @@ -81,7 +82,7 @@ class RoutineOps5(routine: BQRoutine[_5]) {
m3: BQSqlFrag.Magnet,
m4: BQSqlFrag.Magnet,
m5: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5))
}

class RoutineOps6(routine: BQRoutine[_6]) {
Expand All @@ -92,7 +93,7 @@ class RoutineOps6(routine: BQRoutine[_6]) {
m4: BQSqlFrag.Magnet,
m5: BQSqlFrag.Magnet,
m6: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6))
}

class RoutineOps7(routine: BQRoutine[_7]) {
Expand All @@ -104,7 +105,7 @@ class RoutineOps7(routine: BQRoutine[_7]) {
m5: BQSqlFrag.Magnet,
m6: BQSqlFrag.Magnet,
m7: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7))
}

class RoutineOps8(routine: BQRoutine[_8]) {
Expand All @@ -117,7 +118,7 @@ class RoutineOps8(routine: BQRoutine[_8]) {
m6: BQSqlFrag.Magnet,
m7: BQSqlFrag.Magnet,
m8: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8))
}

class RoutineOps9(routine: BQRoutine[_9]) {
Expand All @@ -131,7 +132,7 @@ class RoutineOps9(routine: BQRoutine[_9]) {
m7: BQSqlFrag.Magnet,
m8: BQSqlFrag.Magnet,
m9: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9))
}

class RoutineOps10(routine: BQRoutine[_10]) {
Expand All @@ -146,7 +147,7 @@ class RoutineOps10(routine: BQRoutine[_10]) {
m8: BQSqlFrag.Magnet,
m9: BQSqlFrag.Magnet,
m10: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10))
}

class RoutineOps11(routine: BQRoutine[_11]) {
Expand All @@ -162,7 +163,7 @@ class RoutineOps11(routine: BQRoutine[_11]) {
m9: BQSqlFrag.Magnet,
m10: BQSqlFrag.Magnet,
m11: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11))
}

class RoutineOps12(routine: BQRoutine[_12]) {
Expand All @@ -179,7 +180,7 @@ class RoutineOps12(routine: BQRoutine[_12]) {
m10: BQSqlFrag.Magnet,
m11: BQSqlFrag.Magnet,
m12: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12))
}

class RoutineOps13(routine: BQRoutine[_13]) {
Expand All @@ -197,7 +198,7 @@ class RoutineOps13(routine: BQRoutine[_13]) {
m11: BQSqlFrag.Magnet,
m12: BQSqlFrag.Magnet,
m13: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13))
): BQSqlFrag = routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13))
}

class RoutineOps14(routine: BQRoutine[_14]) {
Expand All @@ -216,7 +217,7 @@ class RoutineOps14(routine: BQRoutine[_14]) {
m12: BQSqlFrag.Magnet,
m13: BQSqlFrag.Magnet,
m14: BQSqlFrag.Magnet
): BQSqlFrag.Call =
): BQSqlFrag =
routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14))
}

Expand All @@ -237,7 +238,7 @@ class RoutineOps15(routine: BQRoutine[_15]) {
m13: BQSqlFrag.Magnet,
m14: BQSqlFrag.Magnet,
m15: BQSqlFrag.Magnet
): BQSqlFrag.Call =
): BQSqlFrag =
routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15))
}

Expand All @@ -259,7 +260,7 @@ class RoutineOps16(routine: BQRoutine[_16]) {
m14: BQSqlFrag.Magnet,
m15: BQSqlFrag.Magnet,
m16: BQSqlFrag.Magnet
): BQSqlFrag.Call =
): BQSqlFrag =
routine.call(RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16))
}

Expand All @@ -282,7 +283,7 @@ class RoutineOps17(routine: BQRoutine[_17]) {
m15: BQSqlFrag.Magnet,
m16: BQSqlFrag.Magnet,
m17: BQSqlFrag.Magnet
): BQSqlFrag.Call =
): BQSqlFrag =
routine.call(
RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16, m17))
}
Expand All @@ -307,7 +308,7 @@ class RoutineOps18(routine: BQRoutine[_18]) {
m16: BQSqlFrag.Magnet,
m17: BQSqlFrag.Magnet,
m18: BQSqlFrag.Magnet
): BQSqlFrag.Call =
): BQSqlFrag =
routine.call(
RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16, m17, m18))
}
Expand All @@ -333,7 +334,7 @@ class RoutineOps19(routine: BQRoutine[_19]) {
m17: BQSqlFrag.Magnet,
m18: BQSqlFrag.Magnet,
m19: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(
): BQSqlFrag = routine.call(
RoutineSyntax.builder.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16, m17, m18, m19))
}

Expand All @@ -359,7 +360,7 @@ class RoutineOps20(routine: BQRoutine[_20]) {
m18: BQSqlFrag.Magnet,
m19: BQSqlFrag.Magnet,
m20: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(
): BQSqlFrag = routine.call(
RoutineSyntax.builder
.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16, m17, m18, m19, m20))
}
Expand Down Expand Up @@ -387,7 +388,7 @@ class RoutineOps21(routine: BQRoutine[_21]) {
m19: BQSqlFrag.Magnet,
m20: BQSqlFrag.Magnet,
m21: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(
): BQSqlFrag = routine.call(
RoutineSyntax.builder
.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16, m17, m18, m19, m20, m21))
}
Expand Down Expand Up @@ -416,7 +417,7 @@ class RoutineOps22(routine: BQRoutine[_22]) {
m20: BQSqlFrag.Magnet,
m21: BQSqlFrag.Magnet,
m22: BQSqlFrag.Magnet
): BQSqlFrag.Call = routine.call(
): BQSqlFrag = routine.call(
RoutineSyntax.builder
.apply(m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, m16, m17, m18, m19, m20, m21, m22))
}
6 changes: 3 additions & 3 deletions core/src/test/scala/no/nrk/bigquery/BQSqlFragTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BQSqlFragTest extends FunSuite {

test("collect nested UDFs") {
val udfIdents = bqfr"select ${udfToString(udfAddOne(bqfr"1"))}"
.collect { case BQSqlFrag.Call(udf: UDF[_, _], _) => udf }
.collect { case BQSqlFrag.Call(udf, _) => udf }
.map(_.name)
.sortBy(_.show)

Expand All @@ -58,7 +58,7 @@ class BQSqlFragTest extends FunSuite {
)

val udfIdents = bqsql"select ${outerUdf(1)}"
.collect { case BQSqlFrag.Call(udf: UDF[_, _], _) => udf }
.collect { case BQSqlFrag.Call(udf, _) => udf }
.map(_.name)

assertEquals(udfIdents, innerUdf1.name :: innerUdf2.name :: outerUdf.name :: Nil)
Expand Down Expand Up @@ -110,7 +110,7 @@ class BQSqlFragTest extends FunSuite {
)

val udfIdents = fill1.query
.collect { case BQSqlFrag.Call(udf: UDF[_, _], _) => udf }
.collect { case BQSqlFrag.Call(udf, _) => udf }
.map(_.name)

assertEquals(udfIdents, outerUdf1.name :: Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ object BQSmokeTest {
val schemaOpt: Option[BQSchema] =
pid.wholeTable match {
case BQTableRef(_, _, _) => None
case BQAppliedTableValuedFunction(_, _) => None
case x: BQTableDef[Any] => Some(x.schema)
}

Expand Down
Loading

0 comments on commit 48a6946

Please sign in to comment.