Skip to content

Commit

Permalink
Initial sql merge operation (not tested or used by implementations)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrix committed Jun 2, 2024
1 parent 8b1017b commit 6eb36c9
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 0 deletions.
208 changes: 208 additions & 0 deletions common/src/main/scala/dataprism/platform/sql/SqlMergeOperations.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package dataprism.platform.sql

import cats.data.State
import cats.syntax.all.*
import dataprism.platform.MapRes
import dataprism.platform.sql.value.SqlDbValuesBase
import dataprism.sharedast.{MergeAst, SqlExpr}
import dataprism.sql.*
import perspective.*

trait SqlMergeOperations extends SqlOperationsBase, SqlDbValuesBase {

class SqlMergeCompanion:
def into[A[_[_]]](table: Table[Codec, A]): SqlMergeInto[A] = new SqlMergeInto(table)

class SqlMergeInto[A[_[_]]](protected val table: Table[Codec, A]):
def using[B[_[_]]](dataSource: Query[B]): SqlMergeIntoUsing[A, B] = new SqlMergeIntoUsing(table, dataSource)

class SqlMergeIntoUsing[A[_[_]], B[_[_]]](protected val table: Table[Codec, A], protected val dataSource: Query[B]):
def on(joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean]): SqlMergeIntoUsingOn[A, B] =
new SqlMergeIntoUsingOn(table, dataSource, joinCondition)

class SqlMergeIntoUsingOn[A[_[_]], B[_[_]]](
protected val table: Table[Codec, A],
protected val dataSource: Query[B],
protected val joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean]
):
def whenMatched: SqlMergeMatchedKeyword[A, B] =
new SqlMergeMatchedKeyword(table, dataSource, joinCondition, Nil, None)
def whenNotMatched: SqlMergeNotMatchedKeyword[A, B] =
new SqlMergeNotMatchedKeyword(table, dataSource, joinCondition, Nil, None)

enum SqlMergeMatch[A[_[_]], B[_[_]]]:
case MatchUpdate[A1[_[_]], B1[_[_]], C[_[_]]](
cond: Option[(A1[DbValue], B1[DbValue]) => DbValue[Boolean]],
columns: A1[[X] =>> Column[Codec, X]] => C[[X] =>> Column[Codec, X]],
setValues: (A1[DbValue], B1[DbValue]) => C[DbValue]
)(using val CA: ApplyKC[C], val CT: TraverseKC[C]) extends SqlMergeMatch[A1, B1]
case MatchDelete(cond: Option[(A[DbValue], B[DbValue]) => DbValue[Boolean]])
case NotMatchInsert[A1[_[_]], B1[_[_]], C[_[_]]](
cond: Option[B1[DbValue] => DbValue[Boolean]],
columns: A1[[X] =>> Column[Codec, X]] => C[[X] =>> Column[Codec, X]],
setValues: B1[DbValue] => C[DbValue]
)(using val CA: ApplyKC[C], val CT: TraverseKC[C]) extends SqlMergeMatch[A1, B1]

class SqlMergeMatchedKeyword[A[_[_]], B[_[_]]](
protected val table: Table[Codec, A],
protected val dataSource: Query[B],
protected val joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean],
protected val whens: Seq[SqlMergeMatch[A, B]],
protected val cond: Option[(A[DbValue], B[DbValue]) => DbValue[Boolean]]
):
def and(cond: (A[DbValue], B[DbValue]) => DbValue[Boolean]): SqlMergeMatchedKeyword[A, B] =
new SqlMergeMatchedKeyword(
table,
dataSource,
joinCondition,
whens,
Some(this.cond.fold(cond)(oldCond => (a, b) => oldCond(a, b) && cond(a, b)))
)

def thenUpdate: SqlMergeUpdateKeyword[A, B] =
new SqlMergeUpdateKeyword(table, dataSource, joinCondition, whens, cond)

def thenDelete: SqlMergeOperation[A, B] =
new SqlMergeOperation(table, dataSource, joinCondition, whens :+ SqlMergeMatch.MatchDelete(cond))

class SqlMergeUpdateKeyword[A[_[_]], B[_[_]]](
protected val table: Table[Codec, A],
protected val dataSource: Query[B],
protected val joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean],
protected val whens: Seq[SqlMergeMatch[A, B]],
protected val cond: Option[(A[DbValue], B[DbValue]) => DbValue[Boolean]]
):
def values(setValues: (A[DbValue], B[DbValue]) => A[DbValue]): SqlMergeOperation[A, B] =
import table.given
new SqlMergeOperation(
table,
dataSource,
joinCondition,
whens :+ SqlMergeMatch.MatchUpdate(cond, identity, setValues)
)

def valuesInColumnsK[C[_[_]]: ApplyKC: TraverseKC](
columns: A[[X] =>> Column[Codec, X]] => C[[X] =>> Column[Codec, X]]
)(
setValues: (A[DbValue], B[DbValue]) => C[DbValue]
): SqlMergeOperation[A, B] =
new SqlMergeOperation(
table,
dataSource,
joinCondition,
whens :+ SqlMergeMatch.MatchUpdate(cond, columns, setValues)
)

inline def valuesInColumns[T](columns: A[[X] =>> Column[Codec, X]] => T)(
using mr: MapRes[[X] =>> Column[Codec, X], T]
)(setValues: (A[DbValue], B[DbValue]) => mr.K[DbValue]): SqlMergeOperation[A, B] =
valuesInColumnsK(a => mr.toK(columns(a)))((a, b) => setValues(a, b))(using mr.applyKC, mr.traverseKC)

class SqlMergeNotMatchedKeyword[A[_[_]], B[_[_]]](
protected val table: Table[Codec, A],
protected val dataSource: Query[B],
protected val joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean],
protected val whens: Seq[SqlMergeMatch[A, B]],
protected val cond: Option[B[DbValue] => DbValue[Boolean]]
):
def and(cond: B[DbValue] => DbValue[Boolean]): SqlMergeNotMatchedKeyword[A, B] =
new SqlMergeNotMatchedKeyword(
table,
dataSource,
joinCondition,
whens,
Some(this.cond.fold(cond)(oldCond => a => oldCond(a) && cond(a)))
)

def thenInsert: SqlMergeInsertKeyword[A, B] =
new SqlMergeInsertKeyword(table, dataSource, joinCondition, whens, cond)

class SqlMergeInsertKeyword[A[_[_]], B[_[_]]](
protected val table: Table[Codec, A],
protected val dataSource: Query[B],
protected val joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean],
protected val whens: Seq[SqlMergeMatch[A, B]],
protected val cond: Option[B[DbValue] => DbValue[Boolean]]
):

def values(setValues: B[DbValue] => A[DbValue]): SqlMergeOperation[A, B] =
import table.given
new SqlMergeOperation(
table,
dataSource,
joinCondition,
whens :+ SqlMergeMatch.NotMatchInsert(cond, identity, setValues)
)

def valuesInColumnsK[C[_[_]]: ApplyKC: TraverseKC](
columns: A[[X] =>> Column[Codec, X]] => C[[X] =>> Column[Codec, X]]
)(
setValues: B[DbValue] => C[DbValue]
): SqlMergeOperation[A, B] =
new SqlMergeOperation(
table,
dataSource,
joinCondition,
whens :+ SqlMergeMatch.NotMatchInsert(cond, columns, setValues)
)

inline def valuesInColumns[T](columns: A[[X] =>> Column[Codec, X]] => T)(
using mr: MapRes[[X] =>> Column[Codec, X], T]
)(setValues: B[DbValue] => mr.K[DbValue]): SqlMergeOperation[A, B] =
valuesInColumnsK(a => mr.toK(columns(a)))(b => setValues(b))(using mr.applyKC, mr.traverseKC)

class SqlMergeOperation[A[_[_]], B[_[_]]](
protected val table: Table[Codec, A],
protected val dataSource: Query[B],
protected val joinCondition: (A[DbValue], B[DbValue]) => DbValue[Boolean],
protected val whens: Seq[SqlMergeMatch[A, B]]
) extends IntOperation:

def whenMatched: SqlMergeMatchedKeyword[A, B] =
new SqlMergeMatchedKeyword(table, dataSource, joinCondition, whens, None)

def whenNotMatched: SqlMergeNotMatchedKeyword[A, B] =
new SqlMergeNotMatchedKeyword(table, dataSource, joinCondition, whens, None)

override def sqlAndTypes: (SqlStr[Codec], Type[Int]) =
val valuesQuery = Query.from(table).join(dataSource)(joinCondition)
import valuesQuery.given

val st = for
astMetadata <- valuesQuery.selectAstAndValues
whenAsts <- whens.traverse:
case m: SqlMergeMatch.MatchUpdate[a, b, c] =>
import m.given
val usedColumns = m.CT.foldMapK[[X] =>> Column[Codec, X], Nothing](m.columns(table.columns))(
[X] => (col: Column[Codec, X]) => List(col.name)
)

val newValues = m.setValues.tupled(astMetadata.values)

for
condAst <- m.cond.traverse(f => f.tupled(astMetadata.values).ast)
values <- m.CT.traverseConst[DbValue, Nothing](newValues)([Z] => (v: DbValue[Z]) => v.ast).map(_.toListK)
yield MergeAst.When(not = false, condAst, MergeAst.WhenOperation.Update(usedColumns, values))

case SqlMergeMatch.MatchDelete(cond) =>
cond
.traverse(f => f.tupled(astMetadata.values).ast)
.map(condAst => MergeAst.When(not = false, condAst, MergeAst.WhenOperation.Delete()))

case m: SqlMergeMatch.NotMatchInsert[a, b, c] =>
import m.given
val usedColumns = m.CT.foldMapK[[X] =>> Column[Codec, X], Nothing](m.columns(table.columns))(
[X] => (col: Column[Codec, X]) => List(col.name)
)
val bValues = astMetadata.values._2

val newValues = m.setValues(bValues)

for
condAst <- m.cond.traverse(f => f(bValues).ast)
values <- m.CT.traverseConst[DbValue, Nothing](newValues)([Z] => (v: DbValue[Z]) => v.ast).map(_.toListK)
yield MergeAst.When(not = true, condAst, MergeAst.WhenOperation.Insert(usedColumns, values))
yield MergeAst(astMetadata.ast, whenAsts)

(sqlRenderer.renderMerge(st.runA(freshTaggedState).value), AnsiTypes.integer)
}
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ trait SqlDbValuesBase extends SqlQueryPlatformBase { platform =>

type SqlLogic[A] <: SqlLogicBase[A]

given booleanSqlLogic: SqlLogic[Boolean]
given booleanOptSqlLogic: SqlLogic[Option[Boolean]]

type Many[A]
val Many: ManyCompanion

Expand Down
58 changes: 58 additions & 0 deletions common/src/main/scala/dataprism/sharedast/AstRenderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,64 @@ class AstRenderer[Codec[_]](ansiTypes: AnsiTypes[Codec], getCodecTypeName: [A] =
)
end renderDelete

def renderMerge(ast: MergeAst[Codec]): SqlStr[Codec] =
val (table, alias, usingV, joinCondition) = ast.selectAst match {
case SelectAst.SelectFrom(
None,
_,
Some(SelectAst.From.InnerJoin(SelectAst.From.FromTable(table, alias), usingV, joinCondition)),
None,
None,
None,
None,
None,
None
) =>
(table, alias, usingV, joinCondition)

case _ =>
// TODO: Enforce statically in the API
throw new IllegalArgumentException("Can't use any other operator than a join with renderMerge")
}

spaceConcat(
sql"MERGE INTO",
SqlStr.const(quote(table)),
alias.fold(sql"")(a => sql"AS ${SqlStr.const(quote(a))}"),
sql"USING ${renderFrom(usingV)}",
sql"ON",
renderExpr(joinCondition),
ast.whens.map(renderMergeWhen).intercalate(sql"")
)
end renderMerge

protected def renderMergeWhen(when: MergeAst.When[Codec]): SqlStr[Codec] =
spaceConcat(
sql"WHEN",
if when.not then sql"NOT" else sql"",
sql"MATCHED",
when.cond.fold(sql"")(cond => sql"AND ${renderExpr(cond)}"),
sql"THEN",
when.operation match
case MergeAst.WhenOperation.Update(usedColumns, values) =>
spaceConcat(
sql"UPDATE SET",
usedColumns.zip(values).map((col, e) => sql"${quoteSql(col)} = ${renderExpr(e)}").intercalate(sql", ")
)
case MergeAst.WhenOperation.Delete() => sql"DELETE"
case MergeAst.WhenOperation.Insert(usedColumns, values) =>
spaceConcat(
sql"INSERT",
sql"(",
usedColumns.map(quoteSql).intercalate(sql", "),
sql")",
sql"VALUES",
sql"(",
values.map(renderExpr).intercalate(sql", "),
sql")"
)
)

def renderSelectStatement(data: SelectAst[Codec]): SqlStr[Codec] = renderSelect(data)

protected def renderSelect(data: SelectAst[Codec]): SqlStr[Codec] = data match
Expand Down
13 changes: 13 additions & 0 deletions common/src/main/scala/dataprism/sharedast/SharedSqlAst.scala
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,16 @@ object SelectAst {

case class Locks() // TODO
}

case class MergeAst[Codec[_]](
selectAst: SelectAst[Codec],
whens: Seq[MergeAst.When[Codec]]
)
object MergeAst {
case class When[Codec[_]](not: Boolean, cond: Option[SqlExpr[Codec]], operation: WhenOperation[Codec])

enum WhenOperation[Codec[_]]:
case Update(usedColumns: Seq[SqlStr[Codec]], values: Seq[SqlExpr[Codec]])
case Delete()
case Insert(usedColumns: Seq[SqlStr[Codec]], values: Seq[SqlExpr[Codec]])
}

0 comments on commit 6eb36c9

Please sign in to comment.