diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..78b256a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +default_stages: + - commit + - push +fail_fast: true +repos: + - repo: local + hooks: + - id: scalafmt-test + name: Scalafmt Fixes + pass_filenames: false + language: system + entry: scalafmt + always_run: true + - id: scalac-lint + name: Scala lint + language: system + always_run: true + pass_filenames: false + verbose: true + entry: sbt + args: [ '; clean ; set scalacOptions ++= Seq("-Xfatal-warnings", "-Wconf:any:warning-verbose") ; compile' ] \ No newline at end of file diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..6b1f7b2 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,42 @@ +version=2.7.5 +align { + preset = some +} +optIn.configStyleArguments = false +maxColumn = 90 +newlines { + neverInResultType = false + alwaysBeforeElseAfterCurlyIf = false + topLevelStatements = [before,after] + implicitParamListModifierPrefer = before + avoidForSimpleOverflow = [tooLong] +} +danglingParentheses { + defnSite = false + callSite = false + ctrlSite = false +} +align { + openParenCallSite = false + openParenDefnSite = true +} +docstrings { + blankFirstLine = yes + style = SpaceAsterisk +} +continuationIndent { + extendSite = 2 + withSiteRelativeToExtends = 2 + ctorSite = 4 + callSite = 2 + defnSite = 4 +} +runner { + optimizer { + forceConfigStyleOnOffset = 45 + forceConfigStyleMinArgCount = 5 + } +} +verticalMultiline { + newlineAfterOpenParen = false +} \ No newline at end of file diff --git a/build.sbt b/build.sbt index 6c32b55..ae6b050 100644 --- a/build.sbt +++ b/build.sbt @@ -11,16 +11,15 @@ val dependencies = Seq( "org.apache.spark" %% "spark-hive" % sparkVersion % Provided, "org.apache.spark" %% "spark-avro" % sparkVersion % Provided, "io.delta" %% "delta-core" % "0.7.0" % Provided, - "com.typesafe" % "config" % "1.3.2" -) + "com.typesafe" % "config" % "1.3.2") val testDependencies = Seq( "org.scalatest" %% "scalatest" % "3.0.5" % Test, "org.scalamock" %% "scalamock" % "4.1.0" % Test, - "com.holdenkarau" %% "spark-testing-base" % s"${sparkTestVersion}_0.14.0" % Test -) + "com.holdenkarau" %% "spark-testing-base" % s"${sparkTestVersion}_0.14.0" % Test) import xerial.sbt.Sonatype._ + val settings = Seq( organization := "com.damavis", version := "0.3.10", @@ -29,30 +28,27 @@ val settings = Seq( libraryDependencies ++= dependencies ++ testDependencies, fork in Test := true, parallelExecution in Test := false, - envVars in Test := Map( - "MASTER" -> "local[*]" - ), + envVars in Test := Map("MASTER" -> "local[*]"), test in assembly := {}, // Sonatype sonatypeProfileName := "com.damavis", sonatypeProjectHosting := Some( GitHubHosting("damavis", "damavis-spark", "info@damavis.com")), publishMavenStyle := true, - licenses := Seq( - "APL2" -> url("http://www.apache.org/licenses/LICENSE-2.0.txt")), + licenses := Seq("APL2" -> url("http://www.apache.org/licenses/LICENSE-2.0.txt")), developers := List( - Developer(id = "piffall", - name = "Cristòfol Torrens", - email = "piffall@gmail.com", - url = url("http://piffall.com")), - Developer(id = "priera", - name = "Pedro Riera", - email = "pedro.riera at damavis dot com", - url = url("http://github.com/priera")), - ), + Developer( + id = "piffall", + name = "Cristòfol Torrens", + email = "piffall@gmail.com", + url = url("http://piffall.com")), + Developer( + id = "priera", + name = "Pedro Riera", + email = "pedro.riera at damavis dot com", + url = 
url("http://github.com/priera"))), publishTo := sonatypePublishToBundle.value, - credentials += Publish.credentials -) + credentials += Publish.credentials) lazy val root = (project in file(".")) .settings(name := "damavis-spark") @@ -63,16 +59,12 @@ lazy val root = (project in file(".")) lazy val core = (project in file("damavis-spark-core")) .settings(settings) .settings(name := "damavis-spark-core") - .settings( - crossScalaVersions := supportedScalaVersions, - ) + .settings(crossScalaVersions := supportedScalaVersions) lazy val azure = (project in file("damavis-spark-azure")) .settings(settings) .settings(name := "damavis-spark-azure") - .settings( - crossScalaVersions := supportedScalaVersions, - ) + .settings(crossScalaVersions := supportedScalaVersions) .dependsOn(core) lazy val snowflake = (project in file("damavis-spark-snowflake")) @@ -80,8 +72,5 @@ lazy val snowflake = (project in file("damavis-spark-snowflake")) .settings(name := "damavis-spark-snowflake") .settings( crossScalaVersions := supportedScalaVersions, - libraryDependencies ++= Seq( - "net.snowflake" %% "spark-snowflake" % "2.8.2-spark_3.0" - ) - ) - .dependsOn(core) \ No newline at end of file + libraryDependencies ++= Seq("net.snowflake" %% "spark-snowflake" % "2.8.2-spark_3.0")) + .dependsOn(core) diff --git a/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseReader.scala b/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseReader.scala index d6948e9..467720f 100644 --- a/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseReader.scala +++ b/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseReader.scala @@ -6,9 +6,8 @@ import com.damavis.spark.resource.ResourceReader import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, SparkSession} -class SynapseReader(url: URL, query: String, tempDir: Path)( - implicit spark: SparkSession) - extends ResourceReader { +class SynapseReader(url: URL, query: String, tempDir: Path)(implicit spark: SparkSession) + extends ResourceReader { override def read(): DataFrame = { spark.read @@ -19,4 +18,5 @@ class SynapseReader(url: URL, query: String, tempDir: Path)( .option("query", query) .load() } + } diff --git a/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseWriter.scala b/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseWriter.scala index 6ba8aba..947db08 100644 --- a/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseWriter.scala +++ b/damavis-spark-azure/src/main/scala/com/damavis/spark/resource/datasource/azure/SynapseWriter.scala @@ -6,9 +6,8 @@ import com.damavis.spark.resource.ResourceWriter import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, SparkSession} -class SynapseWriter(url: URL, table: String, tempDir: Path)( - implicit spark: SparkSession) - extends ResourceWriter { +class SynapseWriter(url: URL, table: String, tempDir: Path)(implicit spark: SparkSession) + extends ResourceWriter { override def write(data: DataFrame): Unit = { data.write diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/SparkApp.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/SparkApp.scala index c904649..b02443e 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/SparkApp.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/SparkApp.scala @@ -35,9 +35,11 @@ trait 
SparkApp extends SparkConf { else { val master = spark.conf.get("spark.master") val localCores = "local\\[(\\d+|\\*)\\]".r.findAllIn(master) - if (localCores.hasNext) localCores.group(1) match { - case "*" => sys.runtime.availableProcessors() - case x => x.toInt + if (localCores.hasNext) { + localCores.group(1) match { + case "*" => sys.runtime.availableProcessors() + case x => x.toInt + } } else { 1 } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/database/Database.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/database/Database.scala index edd9d7e..80f41ca 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/database/Database.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/database/Database.scala @@ -16,10 +16,9 @@ import org.slf4j.LoggerFactory import scala.language.postfixOps import scala.util.{Failure, Success, Try} -class Database( - db: SparkDatabase, - fs: FileSystem, - protected[database] val catalog: Catalog)(implicit spark: SparkSession) { +class Database(db: SparkDatabase, + fs: FileSystem, + protected[database] val catalog: Catalog)(implicit spark: SparkSession) { private lazy val logger = LoggerFactory.getLogger(this.getClass) @@ -59,9 +58,7 @@ class Database( } } - def getUnmanagedTable(name: String, - path: String, - format: Format): Try[Table] = { + def getUnmanagedTable(name: String, path: String, format: Format): Try[Table] = { Try { val dbPath = parseAndCheckTableName(name) val actualName = dbPath._2 @@ -142,7 +139,7 @@ class Database( logger.info( s"Table partitioned by ${catalogTable.partitionColumnNames.mkString("[", ",", "]")}") - catalogTable.schema.printTreeString() + logger.info(catalogTable.schema.treeString) // This block of code is necessary because Databricks runtime do not // provide DeltaTableUtils. 
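Reviewer note: the scalac-lint hook added in .pre-commit-config.yaml at the top of this diff shells out to sbt with fatal warnings enabled. A minimal sketch of an equivalent sbt command alias, so the same check can also be run by hand or in CI as "sbt lint"; the alias name and its placement in build.sbt are assumptions for illustration, not part of this change:

```scala
// build.sbt (sketch, not in this PR): mirrors the pre-commit "scalac-lint" entry.
// `sbt lint` cleans, turns compiler warnings into errors, and recompiles every module.
addCommandAlias(
  "lint",
  """; clean ; set scalacOptions ++= Seq("-Xfatal-warnings", "-Wconf:any:warning-verbose") ; compile""")
```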
@@ -163,9 +160,7 @@ class Database( logger.warn("Keeping catalog only data") catalogTable case ue: Throwable => - logger.error( - "Could not combine catalog and delta meta, Unknown Cause: ", - ue) + logger.error("Could not combine catalog and delta meta, Unknown Cause: ", ue) logger.warn("Keeping catalog only data") catalogTable } @@ -178,10 +173,11 @@ class Database( val partitions = metadata.partitionColumnNames val columns = metadata.schema.map(field => { - Column(field.name, - field.dataType.simpleString, - partitions.contains(field.name), - field.nullable) + Column( + field.name, + field.dataType.simpleString, + partitions.contains(field.name), + field.nullable) }) columns diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/database/DbManager.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/database/DbManager.scala index 7fbea56..595e739 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/database/DbManager.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/database/DbManager.scala @@ -42,4 +42,5 @@ object DbManager { new Database(db, HadoopFS(), catalog) } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/database/Table.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/database/Table.scala index b0642e4..f7dfa81 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/database/Table.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/database/Table.scala @@ -2,10 +2,7 @@ package com.damavis.spark.database import com.damavis.spark.resource.Format.Format -case class Column(name: String, - dataType: String, - partitioned: Boolean, - nullable: Boolean) +case class Column(name: String, dataType: String, partitioned: Boolean, nullable: Boolean) sealed trait Table { def database: String @@ -24,7 +21,7 @@ case class RealTable(database: String, format: Format, managed: Boolean, columns: Seq[Column]) - extends Table + extends Table case class DummyTable(database: String, name: String) extends Table { override def path: String = ??? 
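Reviewer note: the Database.scala hunk above swaps catalogTable.schema.printTreeString(), which writes straight to stdout, for logger.info(catalogTable.schema.treeString), so the schema dump lands in the application logs. A standalone sketch of the same idea; the object name and table name are illustrative only, not taken from this diff:

```scala
import org.apache.spark.sql.SparkSession
import org.slf4j.LoggerFactory

object SchemaLogging {
  private val logger = LoggerFactory.getLogger(getClass)

  // treeString yields the same text printTreeString() would print to stdout,
  // but routing it through slf4j keeps it in the configured log sink.
  def logSchema(tableName: String)(implicit spark: SparkSession): Unit =
    logger.info(spark.read.table(tableName).schema.treeString)
}
```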
diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/DatabaseNotFoundException.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/DatabaseNotFoundException.scala index f5733e7..2b0ff11 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/DatabaseNotFoundException.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/DatabaseNotFoundException.scala @@ -1,4 +1,4 @@ package com.damavis.spark.database.exceptions class DatabaseNotFoundException(name: String) - extends Exception(s"""Database "$name" not found in catalog""") {} + extends Exception(s"""Database "$name" not found in catalog""") {} diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/InvalidDatabaseNameException.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/InvalidDatabaseNameException.scala index d438ce0..1dacdf0 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/InvalidDatabaseNameException.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/InvalidDatabaseNameException.scala @@ -1,4 +1,4 @@ package com.damavis.spark.database.exceptions class InvalidDatabaseNameException(name: String) - extends Exception(s""""$name" is not a valid database name""") {} + extends Exception(s""""$name" is not a valid database name""") {} diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/TableDefinitionException.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/TableDefinitionException.scala index 9462423..b09b655 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/TableDefinitionException.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/database/exceptions/TableDefinitionException.scala @@ -1,4 +1,3 @@ package com.damavis.spark.database.exceptions -class TableDefinitionException(val table: String, msg: String) - extends Exception(msg) {} +class TableDefinitionException(val table: String, msg: String) extends Exception(msg) {} diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlow.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlow.scala index f7225a9..dc4e789 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlow.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlow.scala @@ -1,6 +1,7 @@ package com.damavis.spark.dataflow class DataFlow(definition: DataFlowDefinition) { + def run(): Unit = { for (source <- definition.sources) source.compute() diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowSource.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowSource.scala index 080f839..e4fb7db 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowSource.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowSource.scala @@ -1,7 +1,6 @@ package com.damavis.spark.dataflow -class DataFlowSource(processor: SourceProcessor) - extends DataFlowStage(processor) { +class DataFlowSource(processor: SourceProcessor) extends DataFlowStage(processor) { override def ->(stage: StageSocket)( implicit definition: DataFlowDefinition): DataFlowStage = { diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowStage.scala 
b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowStage.scala index 8d83bee..4431156 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowStage.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/DataFlowStage.scala @@ -1,8 +1,10 @@ package com.damavis.spark.dataflow object DataFlowStage { + def apply(processor: Processor): DataFlowStage = new DataFlowStage(processor) + } class DataFlowStage(private val processor: Processor) { @@ -13,6 +15,7 @@ class DataFlowStage(private val processor: Processor) { protected val sockets: SocketSet = SocketSet(new StageSocket(this), new StageSocket(this)) + private var deliverySocket: StageSocket = _ protected def toRun: Boolean = _toRun @@ -38,8 +41,7 @@ class DataFlowStage(private val processor: Processor) { } - def ->(socket: StageSocket)( - implicit definition: DataFlowDefinition): DataFlowStage = { + def ->(socket: StageSocket)(implicit definition: DataFlowDefinition): DataFlowStage = { //TODO: disallow assignment to stages already connected if (this == socket.stage) diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/JoinProcessor.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/JoinProcessor.scala index b05ecc9..898d51a 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/JoinProcessor.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/JoinProcessor.scala @@ -8,4 +8,5 @@ abstract class JoinProcessor extends Processor { override def compute(sockets: SocketSet): DataFrame = computeImpl(sockets.left.get, sockets.right.get) + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/LinealProcessor.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/LinealProcessor.scala index a7d5014..8bb72c3 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/LinealProcessor.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/LinealProcessor.scala @@ -1,4 +1,5 @@ package com.damavis.spark.dataflow + import org.apache.spark.sql.DataFrame abstract class LinealProcessor extends Processor { @@ -8,4 +9,5 @@ abstract class LinealProcessor extends Processor { override def compute(sockets: SocketSet): DataFrame = { computeImpl(sockets.left.get) } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/SourceProcessor.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/SourceProcessor.scala index 0ef9a80..d948b9b 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/SourceProcessor.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/SourceProcessor.scala @@ -1,4 +1,5 @@ package com.damavis.spark.dataflow + import org.apache.spark.sql.DataFrame abstract class SourceProcessor extends Processor { diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/package.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/package.scala index e9883c7..340978d 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/package.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/dataflow/package.scala @@ -8,6 +8,7 @@ import scala.language.implicitConversions package object dataflow { object implicits { + implicit def defaultSocketOfStage(stage: DataFlowStage): StageSocket = stage.left @@ -32,6 +33,7 @@ package object dataflow { target.left } + } } diff --git 
a/damavis-spark-core/src/main/scala/com/damavis/spark/fs/HadoopFS.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/fs/HadoopFS.scala index 334c021..2e00714 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/fs/HadoopFS.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/fs/HadoopFS.scala @@ -5,8 +5,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession object HadoopFS { + def apply(root: String = "/")(implicit spark: SparkSession): HadoopFS = new HadoopFS(new Path(root)) + } class HadoopFS(root: Path)(implicit spark: SparkSession) extends FileSystem { @@ -27,8 +29,7 @@ class HadoopFS(root: Path)(implicit spark: SparkSession) extends FileSystem { val fs = root.getFileSystem(hadoopConf) if (!fs.isDirectory(pathToCheck)) - throw new IllegalArgumentException( - s"path: $path is not a directory in HDFS") + throw new IllegalArgumentException(s"path: $path is not a directory in HDFS") fs.listStatus(pathToCheck) .filter(_.isDirectory) diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/Pipeline.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/Pipeline.scala index db29e44..792d85b 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/Pipeline.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/Pipeline.scala @@ -10,6 +10,7 @@ object Pipeline { def apply(pipelines: Pipeline*): Pipeline = new Pipeline(pipelines.flatMap(x => x.getStages).toList) + } class Pipeline(stages: List[PipelineStage]) { @@ -24,4 +25,5 @@ class Pipeline(stages: List[PipelineStage]) { def ->(stage: PipelineStage): Pipeline = Pipeline(this.getStages ++ List(stage)) + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/implicits/package.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/implicits/package.scala index 9431aa3..e86d60f 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/implicits/package.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/pipeline/implicits/package.scala @@ -15,4 +15,5 @@ package object implicits { new PipelineTarget { override def put(data: DataFrame): Unit = resource.write(data) } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/BasicResourceRW.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/BasicResourceRW.scala index cbac38f..6e32758 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/BasicResourceRW.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/BasicResourceRW.scala @@ -1,8 +1,8 @@ package com.damavis.spark.resource + import org.apache.spark.sql.DataFrame -class BasicResourceRW(reader: ResourceReader, writer: ResourceWriter) - extends ResourceRW { +class BasicResourceRW(reader: ResourceReader, writer: ResourceWriter) extends ResourceRW { override def write(data: DataFrame): Unit = writer.write(data) override def read(): DataFrame = reader.read() diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/Format.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/Format.scala index 0aade03..280554d 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/Format.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/Format.scala @@ -9,4 +9,4 @@ object Format extends Enumeration { val Csv = Value("csv") val Orc = Value("orc") val Delta = Value("delta") -} \ No newline at end of file +} diff --git 
a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableRWBuilder.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableRWBuilder.scala index 1bf48e0..71ef9ae 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableRWBuilder.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableRWBuilder.scala @@ -5,11 +5,13 @@ import com.damavis.spark.resource.{BasicResourceRW, RWBuilder, ResourceRW} import org.apache.spark.sql.SparkSession class TableRWBuilder(table: Table)(implicit spark: SparkSession, db: Database) - extends RWBuilder { + extends RWBuilder { + override def build(): ResourceRW = { val reader = TableReaderBuilder(table).reader() val writer = TableWriterBuilder(table).writer() new BasicResourceRW(reader, writer) } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableReaderBuilder.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableReaderBuilder.scala index f37f067..26b9052 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableReaderBuilder.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableReaderBuilder.scala @@ -9,8 +9,7 @@ import scala.util.{Failure, Success, Try} object TableReaderBuilder { - def apply(tryTable: Try[Table])( - implicit spark: SparkSession): TableReaderBuilder = { + def apply(tryTable: Try[Table])(implicit spark: SparkSession): TableReaderBuilder = { tryTable match { case Success(table) => apply(table) case Failure(exception) => throw exception @@ -28,10 +27,13 @@ object TableReaderBuilder { new TableReaderBuilder(table, spark) } + } class TableReaderBuilder protected (table: Table, spark: SparkSession) - extends ReaderBuilder { + extends ReaderBuilder { + override def reader(): ResourceReader = new TableResourceReader(spark, table) + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceReader.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceReader.scala index c1b9ca2..a97cb24 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceReader.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceReader.scala @@ -4,9 +4,10 @@ import com.damavis.spark.database.Table import com.damavis.spark.resource.ResourceReader import org.apache.spark.sql.{DataFrame, SparkSession} -class TableResourceReader(spark: SparkSession, table: Table) - extends ResourceReader { +class TableResourceReader(spark: SparkSession, table: Table) extends ResourceReader { + override def read(): DataFrame = spark.read .table(s"${table.database}.${table.name}") + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceWriter.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceWriter.scala index b600cb4..e0fb9a5 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceWriter.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableResourceWriter.scala @@ -11,7 +11,7 @@ class TableResourceWriter(spark: SparkSession, table: Table, db: Database, params: TableWriterParameters) - extends ResourceWriter { + extends ResourceWriter { private var actualTable: Table = table @@ -20,8 +20,7 @@ class TableResourceWriter(spark: 
SparkSession, val format = params.storageFormat val partitionedBy = params.partitionedBy.getOrElse(Nil) - actualTable = - db.addTableIfNotExists(actualTable, schema, format, partitionedBy) + actualTable = db.addTableIfNotExists(actualTable, schema, format, partitionedBy) } private def mergeExpression(partitions: Seq[String]): String = { @@ -118,8 +117,7 @@ class TableResourceWriter(spark: SparkSession, case e: Throwable => throw e } finally { spark.conf - .set("spark.sql.sources.partitionOverwriteMode", - previousOverwriteConf) + .set("spark.sql.sources.partitionOverwriteMode", previousOverwriteConf) } } } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableWriterBuilder.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableWriterBuilder.scala index 388ea1f..04e245a 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableWriterBuilder.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/datasource/TableWriterBuilder.scala @@ -12,29 +12,30 @@ import scala.util.{Failure, Success, Try} object TableWriterBuilder { - def apply(tryTable: Try[Table])(implicit spark: SparkSession, - db: Database): BasicTableWriterBuilder = { + def apply(tryTable: Try[Table])( + implicit spark: SparkSession, + db: Database): BasicTableWriterBuilder = { tryTable match { case Success(table) => apply(table) case Failure(exception) => throw exception } } - def apply(table: Table)(implicit spark: SparkSession, - db: Database): BasicTableWriterBuilder = { + def apply(table: Table)( + implicit spark: SparkSession, + db: Database): BasicTableWriterBuilder = { val params = TableWriterParameters() table match { case _: DummyTable => new BasicTableWriterBuilder(table, db, params) case _: RealTable => new SealedTableWriterBuilder(table, db, params) } } + } -class BasicTableWriterBuilder( - table: Table, - db: Database, - params: TableWriterParameters)(implicit spark: SparkSession) - extends WriterBuilder { +class BasicTableWriterBuilder(table: Table, db: Database, params: TableWriterParameters)( + implicit spark: SparkSession) + extends WriterBuilder { private var myParams: TableWriterParameters = params @@ -59,8 +60,7 @@ class BasicTableWriterBuilder( datePartitioned(formatter) } - def datePartitioned( - formatter: DatePartitionFormatter): BasicTableWriterBuilder = { + def datePartitioned(formatter: DatePartitionFormatter): BasicTableWriterBuilder = { partitionedBy(formatter.columnNames: _*) } @@ -85,11 +85,9 @@ class BasicTableWriterBuilder( } -class SealedTableWriterBuilder( - table: Table, - db: Database, - params: TableWriterParameters)(implicit spark: SparkSession) - extends BasicTableWriterBuilder(table, db, params) { +class SealedTableWriterBuilder(table: Table, db: Database, params: TableWriterParameters)( + implicit spark: SparkSession) + extends BasicTableWriterBuilder(table, db, params) { override def withFormat(format: Format): BasicTableWriterBuilder = { if (table.format != format) { @@ -123,4 +121,5 @@ class SealedTableWriterBuilder( super.partitionedBy(columns: _*) } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileRWBuilder.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileRWBuilder.scala index d463f51..32dd55c 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileRWBuilder.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileRWBuilder.scala @@ -8,19 +8,20 @@ import 
com.damavis.spark.resource.{BasicResourceRW, RWBuilder, ResourceRW} import org.apache.spark.sql.SparkSession object FileRWBuilder { - def apply(path: String, format: Format)( - implicit spark: SparkSession): FileRWBuilder = { + + def apply(path: String, format: Format)(implicit spark: SparkSession): FileRWBuilder = { val readParams = FileReaderParameters(format, path) val writeParams = FileWriterParameters(format, path) new FileRWBuilder(readParams, writeParams) } + } -class FileRWBuilder( - readParams: FileReaderParameters, - writeParams: FileWriterParameters)(implicit spark: SparkSession) - extends RWBuilder { +class FileRWBuilder(readParams: FileReaderParameters, writeParams: FileWriterParameters)( + implicit spark: SparkSession) + extends RWBuilder { + override def build(): ResourceRW = { val reader = FileReaderBuilder(readParams).reader() val writer = FileWriterBuilder(writeParams).writer() @@ -38,8 +39,7 @@ class FileRWBuilder( readParams.copy(from = Some(from), to = Some(to)) val newWriteParams = - writeParams.copy( - columnNames = newReadParams.partitionFormatter.columnNames) + writeParams.copy(columnNames = newReadParams.partitionFormatter.columnNames) new FileRWBuilder(newReadParams, newWriteParams) } @@ -48,12 +48,12 @@ class FileRWBuilder( val newReadParams = readParams.copy(partitionFormatter = formatter) val newWriteParams = - writeParams.copy( - columnNames = newReadParams.partitionFormatter.columnNames) + writeParams.copy(columnNames = newReadParams.partitionFormatter.columnNames) new FileRWBuilder(newReadParams, newWriteParams) } def writeMode(mode: String): FileRWBuilder = new FileRWBuilder(readParams, writeParams.copy(mode = mode)) + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReader.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReader.scala index 0e6cdfe..d09c339 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReader.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReader.scala @@ -5,7 +5,8 @@ import com.damavis.spark.resource.ResourceReader import org.apache.spark.sql.{DataFrame, SparkSession} class FileReader(params: FileReaderParameters)(implicit spark: SparkSession) - extends ResourceReader { + extends ResourceReader { + override def read(): DataFrame = { val path = params.path @@ -31,11 +32,12 @@ class FileReader(params: FileReaderParameters)(implicit spark: SparkSession) .map(partition => s"$path/$partition") reader - .option("basePath", path) + .option("basePath", params.options.getOrElse("basePath", path)) .load(partitionsToLoad: _*) } else { reader.load(path) } } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderBuilder.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderBuilder.scala index ce4698d..20539f2 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderBuilder.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderBuilder.scala @@ -9,6 +9,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType object FileReaderBuilder { + def apply(format: Format, path: String)( implicit spark: SparkSession): FileReaderBuilder = { val params = FileReaderParameters(format, path) @@ -18,11 +19,11 @@ object FileReaderBuilder { def apply(params: FileReaderParameters)( implicit spark: SparkSession): FileReaderBuilder = new FileReaderBuilder(params) + } -class 
FileReaderBuilder(params: FileReaderParameters)( - implicit spark: SparkSession) - extends ReaderBuilder { +class FileReaderBuilder(params: FileReaderParameters)(implicit spark: SparkSession) + extends ReaderBuilder { override def reader(): ResourceReader = { if (params.datePartitioned) @@ -31,8 +32,7 @@ class FileReaderBuilder(params: FileReaderParameters)( new FileReader(params) } - def partitioning( - partitionFormatter: DatePartitionFormatter): FileReaderBuilder = { + def partitioning(partitionFormatter: DatePartitionFormatter): FileReaderBuilder = { val newParams = params.copy(partitionFormatter = partitionFormatter) new FileReaderBuilder(newParams) } @@ -42,6 +42,11 @@ class FileReaderBuilder(params: FileReaderParameters)( new FileReaderBuilder(newParams) } + def option(key: String, value: String): FileReaderBuilder = { + val newParams = params.copy(options = params.options + (key -> value)) + new FileReaderBuilder(newParams) + } + def schema(schema: StructType): FileReaderBuilder = { val newParams = params.copy(schema = Option(schema)) new FileReaderBuilder(newParams) @@ -52,16 +57,14 @@ class FileReaderBuilder(params: FileReaderParameters)( betweenDates(LocalDateTime.of(from, time), LocalDateTime.of(to, time)) } - def betweenDates(from: LocalDateTime, - to: LocalDateTime): FileReaderBuilder = { + def betweenDates(from: LocalDateTime, to: LocalDateTime): FileReaderBuilder = { val newParams = params.copy(from = Some(from), to = Some(to)) new FileReaderBuilder(newParams) } - def partitionDateFormat( - formatter: DatePartitionFormatter): FileReaderBuilder = { + def partitionDateFormat(formatter: DatePartitionFormatter): FileReaderBuilder = { val newParams = params.copy(partitionFormatter = formatter) new FileReaderBuilder(newParams) @@ -73,11 +76,9 @@ class FileReaderBuilder(params: FileReaderParameters)( if (from.isAfter(to)) { val errMsg = - s"""Invalid parameters defined for reading spark object with path: ${params.path}. - |"from" date is after "to" date. - |Dates are: from=$from to=$to - |""".stripMargin + s"""Invalid parameters. path: ${params.path}. 
"$from" date is after "$to" date.""" throw new RuntimeException(errMsg) } } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderParameters.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderParameters.scala index 2749ac8..888043d 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderParameters.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileReaderParameters.scala @@ -14,9 +14,7 @@ private[resource] object FileReaderParameters { apply(format, path, DatePartitionFormatter.standard) } - def apply(format: Format, - path: String, - partitionFormatter: DatePartitionFormatter)( + def apply(format: Format, path: String, partitionFormatter: DatePartitionFormatter)( implicit spark: SparkSession): FileReaderParameters = { apply(format, path, partitionFormatter, Map.empty[String, String]) } @@ -26,10 +24,7 @@ private[resource] object FileReaderParameters { partitionFormatter: DatePartitionFormatter, options: Map[String, String])( implicit spark: SparkSession): FileReaderParameters = { - new FileReaderParameters(format, - path, - partitionFormatter, - options = options) + new FileReaderParameters(format, path, partitionFormatter, options = options) } } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriter.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriter.scala index 8211b5a..b05fbfb 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriter.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriter.scala @@ -24,15 +24,14 @@ class FileWriter(params: FileWriterParameters) extends ResourceWriter { .save(params.path) } - private def validatePartitioning(data: DataFrame, - columnNames: Seq[String]): Unit = { + private def validatePartitioning(data: DataFrame, columnNames: Seq[String]): Unit = { val fields = data.schema.fieldNames val numberPartitionColumns = columnNames.length if (fields.length < numberPartitionColumns) { val msg = s"""Partitioned DataFrame does not have required partition columns - |These are: "${columnNames.mkString(",")}" + |Columns in existing data: "${columnNames.mkString(",")}" |Columns in DataFrame: "${fields.mkString(",")}" |""".stripMargin throw new FileResourceWriteException(msg) @@ -46,11 +45,12 @@ class FileWriter(params: FileWriterParameters) extends ResourceWriter { if (orderDoesNotMatch) { val msg = s"""Partitioned DataFrame does not have partition columns in the required order - |They should be: "${columnNames.mkString(",")}" + |Order in existing data: "${columnNames.mkString(",")}" |Order in DataFrame: "${partitionColsInSchema.mkString(",")}" |""".stripMargin throw new FileResourceWriteException(msg) } } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriterBuilder.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriterBuilder.scala index 3ad436a..17d4111 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriterBuilder.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/file/FileWriterBuilder.scala @@ -4,14 +4,17 @@ import com.damavis.spark.resource.Format.Format import com.damavis.spark.resource.{ResourceWriter, WriterBuilder} object FileWriterBuilder { + def apply(format: Format, path: String): FileWriterBuilder = new FileWriterBuilder(FileWriterParameters(format, path)) def 
apply(params: FileWriterParameters): FileWriterBuilder = new FileWriterBuilder(params) + } class FileWriterBuilder(params: FileWriterParameters) extends WriterBuilder { + override def writer(): ResourceWriter = new FileWriter(params) @@ -27,4 +30,5 @@ class FileWriterBuilder(params: FileWriterParameters) extends WriterBuilder { new FileWriterBuilder(params.copy(columnNames = columnNames)) } + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitionFormatter.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitionFormatter.scala index 7626586..667a9b5 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitionFormatter.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitionFormatter.scala @@ -1,4 +1,5 @@ package com.damavis.spark.resource.partitioning + import java.time.LocalDateTime import java.time.format.DateTimeFormatter import java.time.temporal.{ChronoField, TemporalUnit} @@ -6,6 +7,7 @@ import java.time.temporal.{ChronoField, TemporalUnit} import scala.collection.JavaConverters._ object DatePartitionFormatter { + protected case class ColumnFormatter(column: String, pattern: String, formatter: DateTimeFormatter) @@ -60,10 +62,12 @@ object DatePartitionFormatter { } new DatePartitionFormatter(cols) } + } class DatePartitionFormatter protected ( columns: Seq[DatePartitionFormatter.ColumnFormatter]) { + def dateToPath(date: LocalDateTime): String = { columns .map { part => @@ -92,4 +96,5 @@ class DatePartitionFormatter protected ( def isYearlyPartitioned: Boolean = columns.nonEmpty && columns.head.pattern == "yyyy" + } diff --git a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitions.scala b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitions.scala index 7323666..485c7f0 100644 --- a/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitions.scala +++ b/damavis-spark-core/src/main/scala/com/damavis/spark/resource/partitioning/DatePartitions.scala @@ -5,15 +5,16 @@ import com.damavis.spark.fs.{FileSystem, HadoopFS} import org.apache.spark.sql.SparkSession object DatePartitions { + def apply(root: String, pathGenerator: DatePartitionFormatter)( implicit spark: SparkSession): DatePartitions = { val fs = HadoopFS(root) DatePartitions(fs, pathGenerator) } - def apply(fs: FileSystem, - pathGenerator: DatePartitionFormatter): DatePartitions = + def apply(fs: FileSystem, pathGenerator: DatePartitionFormatter): DatePartitions = new DatePartitions(fs, pathGenerator) + } class DatePartitions(fs: FileSystem, pathGenerator: DatePartitionFormatter) { @@ -52,8 +53,7 @@ class DatePartitions(fs: FileSystem, pathGenerator: DatePartitionFormatter) { } } - private def datesGen(from: LocalDateTime, - to: LocalDateTime): List[LocalDateTime] = { + private def datesGen(from: LocalDateTime, to: LocalDateTime): List[LocalDateTime] = { val minimumTime = pathGenerator.minimumTemporalUnit() def datesGen(acc: List[LocalDateTime], @@ -64,4 +64,5 @@ class DatePartitions(fs: FileSystem, pathGenerator: DatePartitionFormatter) { } datesGen(List(), from, to) } + } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/SparkAppTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/SparkAppTest.scala index 22d72d3..57688a2 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/SparkAppTest.scala +++ 
b/damavis-spark-core/src/test/scala/com/damavis/spark/SparkAppTest.scala @@ -24,10 +24,10 @@ class SparkAppTest extends FlatSpec with SparkApp { "An SparkApp" should "run successfully" in { - import spark.implicits._ - val df = spark.sparkContext.parallelize(List(1, 2, 3)).toDF("number") - assert(df.count() === 3) - } + import spark.implicits._ + val df = spark.sparkContext.parallelize(List(1, 2, 3)).toDF("number") + assert(df.count() === 3) + } it should "create databases in defined warehouse path" in { import spark.implicits._ @@ -42,8 +42,7 @@ class SparkAppTest extends FlatSpec with SparkApp { s"hdfs://localhost:8020${warehouseDir}/test.db/dummy_going_real", Format.Parquet, managed = true, - Column("number", "int", partitioned = false, nullable = true) :: Nil - ) + Column("number", "int", partitioned = false, nullable = true) :: Nil) assert(obtained === expected) } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/database/DatabaseTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/database/DatabaseTest.scala index db4b59e..7a1cfc6 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/database/DatabaseTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/database/DatabaseTest.scala @@ -26,18 +26,18 @@ class DatabaseTest extends SparkTestBase { assert(db.catalog.listTables().isEmpty) - val tryTable = db.getUnmanagedTable("numbers", - s"/$name/numbers_external", - Format.Parquet) + val tryTable = + db.getUnmanagedTable("numbers", s"/$name/numbers_external", Format.Parquet) assert(tryTable.isSuccess) val table = tryTable.get - val expected = RealTable("test", - "numbers", - s"/$name/numbers_external", - Format.Parquet, - managed = false, - Nil) + val expected = RealTable( + "test", + "numbers", + s"/$name/numbers_external", + Format.Parquet, + managed = false, + Nil) assert(table === expected) assert(db.catalog.listTables().count() == 1) @@ -46,45 +46,47 @@ class DatabaseTest extends SparkTestBase { "fail to get an external table if there is no data" in { val tryTable = db.getUnmanagedTable("numbers", s"/$name/1234", Format.Parquet) - checkExceptionOfType(tryTable, - classOf[TableAccessException], - "Path not reachable") + checkExceptionOfType(tryTable, classOf[TableAccessException], "Path not reachable") } "fail to get an external table if validations do not succeed" in { val numbersDf = (1 :: 2 :: 3 :: 4 :: Nil).toDF("number") - db.catalog.createTable("numbers1", - "parquet", - numbersDf.schema, - Map[String, String]()) + db.catalog.createTable( + "numbers1", + "parquet", + numbersDf.schema, + Map[String, String]()) - val tryTable1 = db.getUnmanagedTable("numbers1", - s"/$name/numbers_external", - Format.Parquet) - checkExceptionOfType(tryTable1, - classOf[TableAccessException], - "already registered as MANAGED") + val tryTable1 = + db.getUnmanagedTable("numbers1", s"/$name/numbers_external", Format.Parquet) + checkExceptionOfType( + tryTable1, + classOf[TableAccessException], + "already registered as MANAGED") //Register an external table, and try to get it again but with wrong parameters numbersDf.write .parquet(s"/$name/numbers_external2") - db.getUnmanagedTable("numbers_wrong_path", - s"/$name/numbers_external2", - Format.Parquet) + db.getUnmanagedTable( + "numbers_wrong_path", + s"/$name/numbers_external2", + Format.Parquet) - val tryTable2 = db.getUnmanagedTable("numbers_wrong_path", - s"/$name/numbers_external", - Format.Parquet) + val tryTable2 = db.getUnmanagedTable( + "numbers_wrong_path", + s"/$name/numbers_external", + 
Format.Parquet) checkExceptionOfType( tryTable2, classOf[TableAccessException], "It is already registered in the catalog with a different path") - val tryTable3 = db.getUnmanagedTable("numbers_wrong_path", - s"/$name/numbers_external2", - Format.Avro) + val tryTable3 = db.getUnmanagedTable( + "numbers_wrong_path", + s"/$name/numbers_external2", + Format.Avro) checkExceptionOfType( tryTable3, classOf[TableAccessException], @@ -106,8 +108,7 @@ class DatabaseTest extends SparkTestBase { s"hdfs://localhost:8020${warehouseDir.get}/test.db/dummy_going_real", Format.Parquet, managed = true, - Column("number", "int", partitioned = false, nullable = true) :: Nil - ) + Column("number", "int", partitioned = false, nullable = true) :: Nil) assert(obtained === expected) diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/database/DbManagerTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/database/DbManagerTest.scala index 3537be4..7d801ea 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/database/DbManagerTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/database/DbManagerTest.scala @@ -1,6 +1,9 @@ package com.damavis.spark.database -import com.damavis.spark.database.exceptions.{DatabaseNotFoundException, InvalidDatabaseNameException} +import com.damavis.spark.database.exceptions.{ + DatabaseNotFoundException, + InvalidDatabaseNameException +} import com.damavis.spark.utils.{SparkTestBase} import org.apache.spark.sql.functions._ diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/DataFlowBuilderTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/DataFlowBuilderTest.scala index 658c8eb..7f30b82 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/DataFlowBuilderTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/DataFlowBuilderTest.scala @@ -5,12 +5,12 @@ import org.scalatest.WordSpec class DataFlowBuilderTest extends WordSpec { /*TODO - * - empty pipelines are not valid - * - pipelines with no sources are not valid - * - for every source, all executions paths must lead to a target. Otherwise, pipeline is not valid - * - pipelines with loops are not valid - * - pipelines with null pointers along any execution path are not valid - * - choose: pipelines with unreachable stages are valid? Throw an error or just log it instead? - * */ + * - empty pipelines are not valid + * - pipelines with no sources are not valid + * - for every source, all executions paths must lead to a target. Otherwise, pipeline is not valid + * - pipelines with loops are not valid + * - pipelines with null pointers along any execution path are not valid + * - choose: pipelines with unreachable stages are valid? Throw an error or just log it instead? 
+ * */ } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/JoinDataFlowTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/JoinDataFlowTest.scala index 5d8ebc2..ab7d43e 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/JoinDataFlowTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/JoinDataFlowTest.scala @@ -2,10 +2,7 @@ package com.damavis.spark.dataflow import com.damavis.spark.database.{Database, DbManager} import com.damavis.spark.testdata._ -import com.damavis.spark.resource.datasource.{ - TableReaderBuilder, - TableWriterBuilder -} +import com.damavis.spark.resource.datasource.{TableReaderBuilder, TableWriterBuilder} import com.damavis.spark.utils.{SparkTestBase} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.expressions.Window @@ -32,37 +29,38 @@ class JoinDataFlowTest extends SparkTestBase { TableWriterBuilder(authorsTable).writer().write(authorsData) val booksTable = db.getTable(booksTableName).get - val booksData = dfFromBooks(farewell, - oldMan, - timeMachine, - moreau, - oliverTwist, - expectations) + val booksData = + dfFromBooks(farewell, oldMan, timeMachine, moreau, oliverTwist, expectations) TableWriterBuilder(booksTable).writer().write(booksData) } private val joinAuthorsProcessor = new JoinProcessor { + override def computeImpl(left: DataFrame, right: DataFrame): DataFrame = { left .join(right, left("name") === right("author"), "inner") .select("author", "title", "publicationYear") } + } private val groupByProcessor = new LinealProcessor { + override def computeImpl(data: DataFrame): DataFrame = { val window = Window .partitionBy("author") .orderBy(data("publicationYear") asc) data - .select(col("author"), - col("title"), - col("publicationYear"), - rank().over(window) as "rank") + .select( + col("author"), + col("title"), + col("publicationYear"), + rank().over(window) as "rank") .filter(col("rank") === lit(1)) .drop("rank") } + } "a pipeline with a join" should { @@ -75,15 +73,14 @@ class JoinDataFlowTest extends SparkTestBase { val authorsReader = TableReaderBuilder(authorsTable).reader() val oldBookWriter = TableWriterBuilder(oldestBooksTable).writer() - val pipeline = DataFlowBuilder.create { - implicit definition: DataFlowDefinition => - import implicits._ + val pipeline = DataFlowBuilder.create { implicit definition: DataFlowDefinition => + import implicits._ - val joinStage = new DataFlowStage(joinAuthorsProcessor) - val authorOldestBook = new DataFlowStage(groupByProcessor) + val joinStage = new DataFlowStage(joinAuthorsProcessor) + val authorOldestBook = new DataFlowStage(groupByProcessor) - authorsReader -> joinStage.left -> authorOldestBook -> oldBookWriter - booksReader -> joinStage.right + authorsReader -> joinStage.left -> authorOldestBook -> oldBookWriter + booksReader -> joinStage.right } pipeline.run() @@ -100,8 +97,7 @@ class JoinDataFlowTest extends SparkTestBase { val schema = StructType( StructField("author", StringType, nullable = true) :: StructField("title", StringType, nullable = true) :: - StructField("publicationYear", IntegerType, nullable = true) :: Nil - ) + StructField("publicationYear", IntegerType, nullable = true) :: Nil) val expected = session.createDataFrame(expectedData, schema) assertDataFrameEquals(generated, expected) diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/utils/package.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/utils/package.scala index 
5890313..329f023 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/utils/package.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/dataflow/utils/package.scala @@ -15,6 +15,7 @@ package object utils { new DataFlowSource(processor) } + } } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/pipeline/PipelineTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/pipeline/PipelineTest.scala index bb40278..1cc3ccd 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/pipeline/PipelineTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/pipeline/PipelineTest.scala @@ -30,9 +30,10 @@ class PipelineTest extends SparkTestBase { val nationalitiesTable = db.getTable("nationalities").get val inTable = - db.getUnmanagedTable("external_authors_table", - s"/$name/external-authors", - Format.Parquet) + db.getUnmanagedTable( + "external_authors_table", + s"/$name/external-authors", + Format.Parquet) .get val extractNationality = new PipelineStage { @@ -61,9 +62,7 @@ class PipelineTest extends SparkTestBase { StructType( StructField("nationality", StringType, nullable = true) :: StructField("count", LongType, nullable = true) :: - Nil - ) - ) + Nil)) assertDataFrameEquals(written, expectedDf) } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/DeltaTableWriterBuilderTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/DeltaTableWriterBuilderTest.scala index b6c09c7..3f5f1a2 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/DeltaTableWriterBuilderTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/DeltaTableWriterBuilderTest.scala @@ -21,8 +21,7 @@ class DeltaTableWriterBuilderTest extends SparkTestBase { val sink = TableWriterBuilder(table) .withFormat(Format.Parquet) .partitionedBy("nationality") - .overwritePartitionBehavior( - OverwritePartitionBehavior.OVERWRITE_MATCHING) + .overwritePartitionBehavior(OverwritePartitionBehavior.OVERWRITE_MATCHING) .writer() sink.write(authors) @@ -38,8 +37,7 @@ class DeltaTableWriterBuilderTest extends SparkTestBase { val sink2 = TableWriterBuilder(table2) .withFormat(Format.Delta) .pk(Seq("name")) - .overwritePartitionBehavior( - OverwritePartitionBehavior.OVERWRITE_MATCHING) + .overwritePartitionBehavior(OverwritePartitionBehavior.OVERWRITE_MATCHING) .writer() sink2.write(authors) diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceReaderTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceReaderTest.scala index 73b32d3..63e4794 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceReaderTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceReaderTest.scala @@ -32,10 +32,11 @@ class TableResourceReaderTest extends SparkTestBase { } "read successfully a managed table registered in the catalog" in { - session.catalog.createTable("uk_authors", - "parquet", - authorsSchema, - Map[String, String]()) + session.catalog.createTable( + "uk_authors", + "parquet", + authorsSchema, + Map[String, String]()) val authors = dfFromAuthors(dickens, wells) authors.write.mode(SaveMode.Overwrite).saveAsTable("uk_authors") @@ -45,8 +46,7 @@ class TableResourceReaderTest extends SparkTestBase { val obtained = TableReaderBuilder(tryTable.get).reader().read() - 
assertDataFrameEquals(obtained.sort("birthDate"), - authors.sort("birthDate")) + assertDataFrameEquals(obtained.sort("birthDate"), authors.sort("birthDate")) } "fail to read a table not yet present in the catalog" in { diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceWriterTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceWriterTest.scala index e027f53..160bca9 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceWriterTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/datasource/TableResourceWriterTest.scala @@ -59,8 +59,7 @@ class TableResourceWriterTest extends SparkTestBase { assert(before == after) val written = session.read.table(tableName) - assertDataFrameEquals(written.sort("birthDate"), - personDf.sort("birthDate")) + assertDataFrameEquals(written.sort("birthDate"), personDf.sort("birthDate")) } "there is no partitioning" should { @@ -96,8 +95,7 @@ class TableResourceWriterTest extends SparkTestBase { assert(written.count() == 2) val expectedDf = dfFromAuthors(hemingway, wells) - assertDataFrameEquals(written.sort("birthDate"), - expectedDf.sort("birthDate")) + assertDataFrameEquals(written.sort("birthDate"), expectedDf.sort("birthDate")) } "apply properly overwrite save mode" in { @@ -179,8 +177,7 @@ class TableResourceWriterTest extends SparkTestBase { val table = nextTable() val writer = TableWriterBuilder(table) .partitionedBy("nationality") - .overwritePartitionBehavior( - OverwritePartitionBehavior.OVERWRITE_MATCHING) + .overwritePartitionBehavior(OverwritePartitionBehavior.OVERWRITE_MATCHING) .writer() val authors = dfFromAuthors(hemingway, wells) @@ -214,8 +211,9 @@ class TableResourceWriterTest extends SparkTestBase { val expectedAuthors = dfFromAuthors(hemingway, wells, bradbury) assert(finalDf.count() == 3) - assertDataFrameEquals(finalDf.sort("birthDate"), - expectedAuthors.sort("birthDate")) + assertDataFrameEquals( + finalDf.sort("birthDate"), + expectedAuthors.sort("birthDate")) } } @@ -233,20 +231,17 @@ class TableResourceWriterTest extends SparkTestBase { StructField("name", StringType, nullable = true) :: StructField("nationality", StringType, nullable = true) :: StructField("deceaseAge", IntegerType, nullable = true) :: - Nil - ) + Nil) - val dickensList = (Row(dickens.name, - dickens.nationality, - dickens.deceaseAge) :: Nil).asJava + val dickensList = + (Row(dickens.name, dickens.nationality, dickens.deceaseAge) :: Nil).asJava val newDf = session.createDataFrame(dickensList, anotherSchema) val ex = intercept[TableAccessException] { writer.write(newDf) } - assert( - ex.getMessage.contains("does not have columns in required order")) + assert(ex.getMessage.contains("does not have columns in required order")) } } } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/file/FileReaderTest.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/file/FileReaderTest.scala index 0b6df12..09b6320 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/resource/file/FileReaderTest.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/resource/file/FileReaderTest.scala @@ -86,8 +86,7 @@ class FileReaderTest extends SparkTestBase { StructField("year", IntegerType, nullable = true) :: StructField("month", IntegerType, nullable = true) :: StructField("day", IntegerType, nullable = true) :: - Nil - ) + Nil) val expected = spark.createDataFrame(rows, schema) 
assertDataFrameEquals(actorsFromTwo, expected) @@ -119,8 +118,8 @@ class FileReaderTest extends SparkTestBase { Log("::1", Timestamp.valueOf("2020-01-01 20:00:00"), "DEBUG", "20"), Log("::1", Timestamp.valueOf("2020-01-01 21:00:00"), "DEBUG", "21"), Log("::1", Timestamp.valueOf("2020-01-01 22:00:00"), "DEBUG", "22"), - Log("::1", Timestamp.valueOf("2020-01-01 23:00:00"), "DEBUG", "23") - ).withColumn("year", date_format(col("ts"), "yyyy")) + Log("::1", Timestamp.valueOf("2020-01-01 23:00:00"), "DEBUG", "23")) + .withColumn("year", date_format(col("ts"), "yyyy")) .withColumn("month", date_format(col("ts"), "MM")) .withColumn("day", date_format(col("ts"), "dd")) .withColumn("hour", date_format(col("ts"), "H")) diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/testdata/package.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/testdata/package.scala index 530288a..7444e4f 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/testdata/package.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/testdata/package.scala @@ -6,12 +6,7 @@ import java.util.concurrent.atomic.AtomicInteger import com.damavis.spark.database.{Database, Table} import com.damavis.spark.entities.{Author, Book, Log} -import org.apache.spark.sql.types.{ - IntegerType, - StringType, - StructField, - StructType -} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SparkSession} import collection.JavaConverters._ @@ -20,14 +15,19 @@ package object testdata { val hemingway: Author = Author("Hemingway", 61, LocalDate.parse("1899-07-21"), "USA") + val wells: Author = Author("H.G. Wells", 79, LocalDate.parse("1866-09-21"), "UK") + val dickens: Author = Author("Dickens", 58, LocalDate.parse("1812-02-07"), "UK") + val bradbury: Author = Author("Ray Bradbury", 91, LocalDate.parse("1920-08-22"), "USA") + val dumas: Author = Author("Alexandre Dumas", 68, LocalDate.parse("1802-07-24"), "FR") + val hugo: Author = Author("Victor Hugo", 83, LocalDate.parse("1802-02-26"), "FR") @@ -44,16 +44,15 @@ package object testdata { StructField("deceaseAge", IntegerType, nullable = true) :: StructField("birthDate", StringType, nullable = true) :: StructField("nationality", StringType, nullable = true) :: - Nil - ) + Nil) - def dfFromAuthors(authors: Author*)( - implicit session: SparkSession): DataFrame = { + def dfFromAuthors(authors: Author*)(implicit session: SparkSession): DataFrame = { def rowFromAuthor(author: Author): Row = { - Row(author.name, - author.deceaseAge, - author.birthDate.format(ISO_LOCAL_DATE), - author.nationality) + Row( + author.name, + author.deceaseAge, + author.birthDate.format(ISO_LOCAL_DATE), + author.nationality) } val data = authors.map(rowFromAuthor).asJava @@ -65,8 +64,7 @@ package object testdata { StructField("title", StringType, nullable = true) :: StructField("publicationYear", IntegerType, nullable = true) :: StructField("author", StringType, nullable = false) :: - Nil - ) + Nil) def rowFromBook(book: Book): Row = { Row(book.title, book.publicationYear, book.author) @@ -91,4 +89,5 @@ package object testdata { tryTable.get } + } diff --git a/damavis-spark-core/src/test/scala/com/damavis/spark/utils/SparkTestBase.scala b/damavis-spark-core/src/test/scala/com/damavis/spark/utils/SparkTestBase.scala index dff894b..0ca6b80 100644 --- a/damavis-spark-core/src/test/scala/com/damavis/spark/utils/SparkTestBase.scala +++ b/damavis-spark-core/src/test/scala/com/damavis/spark/utils/SparkTestBase.scala 
@@ -11,7 +11,7 @@ import org.apache.spark.sql.SparkSession import org.scalatest.WordSpec class SparkTestBase - extends WordSpec + extends WordSpec with DataFrameSuiteBase with SparkTestSupport with SparkContextProvider @@ -33,8 +33,9 @@ class SparkTestBase .setMaster("local[*]") .set("spark.sql.catalogImplementation", "hive") .set("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - .set("spark.sql.catalog.spark_catalog", - "org.apache.spark.sql.delta.catalog.DeltaCatalog") + .set( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog") .set("spark.hadoop.fs.default.name", HDFSCluster.uri) .set("spark.sql.warehouse.dir", warehouseConf) // Ignored by Holden Karau } diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala index 4e736ee..92154fe 100644 --- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala +++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala @@ -5,10 +5,11 @@ import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} case class SnowflakeJoinReader(reader: SnowflakeReader, joinData: DataFrame)( implicit spark: SparkSession) - extends ResourceReader { + extends ResourceReader { - assert(reader.table.isDefined, - "SnowflakeJoinReader only accept table reads, not queries.") + assert( + reader.table.isDefined, + "SnowflakeJoinReader only accept table reads, not queries.") val stagingTable = s"join_tmp__${reader.table.get}" val targetTable = s"${reader.table.get}" @@ -28,14 +29,16 @@ case class SnowflakeJoinReader(reader: SnowflakeReader, joinData: DataFrame)( } override def read(): DataFrame = { - SnowflakeWriter(reader.account, - reader.user, - reader.password, - reader.warehouse, - reader.database, - reader.schema, - stagingTable, - SaveMode.Overwrite).write(joinData) + SnowflakeWriter( + reader.account, + reader.user, + reader.password, + reader.warehouse, + reader.database, + reader.schema, + stagingTable, + SaveMode.Overwrite, + reader.sfExtraOptions).write(joinData) reader.copy(table = None, query = Some(query)).read() } diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala index 0fe16fd..a10d683 100644 --- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala +++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala @@ -3,16 +3,17 @@ package com.damavis.spark.resource.datasource.snowflake import net.snowflake.spark.snowflake.Utils import org.apache.spark.sql.SparkSession -case class SnowflakeMerger( - account: String, - user: String, - password: String, - warehouse: String, - database: String, - schema: String, - sourceTable: String, - targetTable: String, - pkColumns: Seq[String])(implicit spark: SparkSession) { +case class SnowflakeMerger(account: String, + user: String, + password: String, + warehouse: String, + database: String, + schema: String, + sourceTable: String, + targetTable: String, + pkColumns: Seq[String], + sfExtraOptions: Map[String, String] = Map())( + implicit spark: SparkSession) { val sfOptions = Map( 
"sfURL" -> s"${account}.snowflakecomputing.com", @@ -22,8 +23,7 @@ case class SnowflakeMerger( "sfSchema" -> schema, "sfWarehouse" -> warehouse, "sfCompress" -> "on", - "sfSSL" -> "on" - ) + "sfSSL" -> "on") private def mergeExpression(pkColumns: Seq[String]): String = { pkColumns @@ -40,7 +40,7 @@ case class SnowflakeMerger( |WHEN MATCHED THEN DELETE |""".stripMargin - Utils.runQuery(sfOptions, deleteQuery) + Utils.runQuery(sfOptions ++ sfExtraOptions, deleteQuery) } } diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala index b830780..12b0a19 100644 --- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala +++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala @@ -3,23 +3,25 @@ package com.damavis.spark.resource.datasource.snowflake import com.damavis.spark.resource.ResourceReader import org.apache.spark.sql.{DataFrame, SparkSession} -case class SnowflakeReader( - account: String, - user: String, - password: String, - warehouse: String, - database: String, - schema: String, - table: Option[String] = None, - query: Option[String] = None)(implicit spark: SparkSession) - extends ResourceReader { +case class SnowflakeReader(account: String, + user: String, + password: String, + warehouse: String, + database: String, + schema: String, + table: Option[String] = None, + query: Option[String] = None, + sfExtraOptions: Map[String, String] = Map())( + implicit spark: SparkSession) + extends ResourceReader { val settings = (table, query) match { case (Some(tableName), None) => ("dbtable", tableName) case (None, Some(queryBody)) => ("query", queryBody) case (Some(_), Some(_)) => - throw new IllegalArgumentException( - "SnowflakeReader cannot read table and query.") + throw new IllegalArgumentException("SnowflakeReader cannot read table and query.") + case (None, None) => + throw new IllegalArgumentException("SnowflakeReader cannot read table and query.") } val sfOptions = Map( @@ -30,13 +32,13 @@ case class SnowflakeReader( "sfSchema" -> schema, "sfWarehouse" -> warehouse, "sfCompress" -> "on", - "sfSSL" -> "on" - ) + settings + "sfSSL" -> "on") + settings override def read(): DataFrame = { spark.read .format("net.snowflake.spark.snowflake") - .options(sfOptions) + .options(sfOptions ++ sfExtraOptions) .load() } + } diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala index e6a336f..76a8263 100644 --- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala +++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala @@ -3,17 +3,17 @@ package com.damavis.spark.resource.datasource.snowflake import com.damavis.spark.resource.ResourceWriter import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} -case class SnowflakeWriter( - account: String, - user: String, - password: String, - warehouse: String, - database: String, - schema: String, - table: String, - mode: SaveMode = SaveMode.Ignore, - preScript: Option[String] = None)(implicit spark: SparkSession) - extends ResourceWriter { +case class SnowflakeWriter(account: 
String, + user: String, + password: String, + warehouse: String, + database: String, + schema: String, + table: String, + mode: SaveMode = SaveMode.Ignore, + sfExtraOptions: Map[String, String] = Map(), + preScript: Option[String] = None)(implicit spark: SparkSession) + extends ResourceWriter { val sfOptions = Map( "sfURL" -> s"${account}.snowflakecomputing.com", @@ -24,13 +24,12 @@ case class SnowflakeWriter( "sfWarehouse" -> warehouse, "dbtable" -> table, "sfCompress" -> "on", - "sfSSL" -> "on" - ) + "sfSSL" -> "on") override def write(data: DataFrame): Unit = { data.write .format("net.snowflake.spark.snowflake") - .options(sfOptions) + .options(sfOptions ++ sfExtraOptions) .mode(mode) .save() } diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala index 0ce6fb3..bb4b5c3 100644 --- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala +++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala @@ -1,12 +1,25 @@ package com.damavis.spark.resource.datasource.snowflake import com.damavis.spark.resource.ResourceWriter +import net.snowflake.spark.snowflake.Utils import org.apache.spark.sql.functions.col import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} -case class SnowflakeWriterMerger(writer: SnowflakeWriter, columns: Seq[String])( +/** + * Merges the writer's data into the target table on the specified columns + * + * @param writer Snowflake writer providing connection settings and the target table + * @param columns columns to merge on + * @param stagingSchema staging database schema; defaults to the writer's schema + * @param deleteStagingTable drop the staging table after the merge when true + * @param spark SparkSession object + */ +case class SnowflakeWriterMerger(writer: SnowflakeWriter, + columns: Seq[String], + stagingSchema: Option[String] = None, + deleteStagingTable: Boolean = false)( implicit spark: SparkSession) - extends ResourceWriter { + extends ResourceWriter { val stagingTable = s"merge_tmp_delta__${writer.table}" val targetTable = s"${writer.table}" @@ -25,15 +38,19 @@ case class SnowflakeWriterMerger(writer: SnowflakeWriter, columns: Seq[String])( .copy(table = stagingTable, mode = SaveMode.Overwrite) .write(data.select(columns.map(col): _*).distinct()) - SnowflakeMerger(writer.account, - writer.user, - writer.password, - writer.warehouse, - writer.database, - writer.schema, - stagingTable, - targetTable, - columns).merge() + SnowflakeMerger( + writer.account, + writer.user, + writer.password, + writer.warehouse, + writer.database, + stagingSchema.getOrElse(writer.schema), + stagingTable, + targetTable, + columns, + writer.sfExtraOptions).merge() + + if (deleteStagingTable) dropStagingTable() } private def targetExists(): Boolean = { @@ -44,9 +61,9 @@ case class SnowflakeWriterMerger(writer: SnowflakeWriter, columns: Seq[String])( writer.warehouse, writer.database, "INFORMATION_SCHEMA", - query = Some( - s"SELECT COUNT(1) = 1 FROM TABLES WHERE TABLE_NAME = '${targetTable}'") - ) + query = + Some(s"SELECT COUNT(1) = 1 FROM TABLES WHERE TABLE_NAME = '${targetTable}'"), + sfExtraOptions = writer.sfExtraOptions) reader .read() @@ -55,4 +72,12 @@ case class SnowflakeWriterMerger(writer: SnowflakeWriter, columns: Seq[String])( .head } + private def dropStagingTable(): Unit = { + + val deleteSourceTableQuery = + s"DROP TABLE IF EXISTS $stagingTable RESTRICT" + + 
Utils.runQuery(writer.sfOptions, deleteSourceTableQuery) + } + } diff --git a/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReaderTest.scala b/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReaderTest.scala index 6c02585..1a390f6 100644 --- a/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReaderTest.scala +++ b/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReaderTest.scala @@ -17,17 +17,16 @@ class SnowflakeJoinReaderTest extends WordSpec with DataFrameSuiteBase { "is performed it" should { "filter reader with given data" ignore { import spark.implicits._ - val data = Seq( - (Date.valueOf("2020-01-01")), - ).toDF("dt") + val data = Seq((Date.valueOf("2020-01-01"))).toDF("dt") - val reader = SnowflakeReader(account, - user, - password, - warehouse, - db, - "PUBLIC", - Some("MY_TEST_TABLE"))(spark) + val reader = SnowflakeReader( + account, + user, + password, + warehouse, + db, + "PUBLIC", + Some("MY_TEST_TABLE"))(spark) val joinReader = SnowflakeJoinReader(reader, data)(spark) assert(joinReader.read().count() == 1) } diff --git a/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMergerTest.scala b/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMergerTest.scala index 98be960..4c04c3c 100644 --- a/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMergerTest.scala +++ b/damavis-spark-snowflake/src/test/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMergerTest.scala @@ -20,16 +20,16 @@ class SnowflakeWriterMergerTest extends WordSpec with DataFrameSuiteBase { val data = Seq( (1, "user1", Date.valueOf("2020-01-02")), (2, "user1", Date.valueOf("2020-01-02")), - (3, "user1", Date.valueOf("2020-01-01")) - ).toDF("id", "username", "dt") + (3, "user1", Date.valueOf("2020-01-01"))).toDF("id", "username", "dt") - val writer = SnowflakeWriter(account, - user, - password, - warehouse, - db, - "PUBLIC", - "MY_TEST_TABLE")(spark) + val writer = SnowflakeWriter( + account, + user, + password, + warehouse, + db, + "PUBLIC", + "MY_TEST_TABLE")(spark) val merger = SnowflakeWriterMerger(writer, Seq("dt"))(spark) merger.write(data) } diff --git a/project/Publish.scala b/project/Publish.scala index ef1f125..8cf6dbe 100644 --- a/project/Publish.scala +++ b/project/Publish.scala @@ -5,8 +5,10 @@ object Publish { val PASSWORD: String = sys.env.getOrElse("PASSWORD", "") val credentials = - Credentials("Sonatype Nexus Repository Manager", - "oss.sonatype.org", - USERNAME, - PASSWORD) + Credentials( + "Sonatype Nexus Repository Manager", + "oss.sonatype.org", + USERNAME, + PASSWORD) + } diff --git a/project/plugins.sbt b/project/plugins.sbt index 22d9f8f..82717e6 100755 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -5,3 +5,5 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.8.1") addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1") + +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") diff --git a/scalastyle.xml b/scalastyle.xml new file mode 100644 index 0000000..df56d78 --- /dev/null +++ b/scalastyle.xml @@ -0,0 +1,202 @@
[scalastyle.xml adds a 202-line "Scalastyle configuration based on Scala style guide"; the XML element tags were lost in extraction, so the file body is not reproduced here. The surviving values indicate a maximum line length of 100, an indent of 2, naming regexes ^[a-z][a-z\._]*$ (packages) and ^[A-Z][A-Za-z]*$ (classes/objects), and whitespace checks around COLON, COMMA, DOT, EQUALS, IF, FOR, WHILE, LBRACE and RBRACE tokens.]
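
Taken together, the Snowflake changes in this patch thread an optional sfExtraOptions map through the reader, writer and merger classes (the map is appended to the connector options each class builds), and SnowflakeWriterMerger additionally gains an optional staging schema plus a flag to drop its staging table after the merge. The following sketch is not part of the patch; it only illustrates how the new parameters fit together. The account, warehouse, database, table and role values are placeholders, and sfRole is just one example of an extra connector option, not something the patch itself sets.

import org.apache.spark.sql.{SaveMode, SparkSession}
import com.damavis.spark.resource.datasource.snowflake.{SnowflakeReader, SnowflakeWriter, SnowflakeWriterMerger}

implicit val spark: SparkSession =
  SparkSession.builder().master("local[*]").getOrCreate()

// Placeholder connection values; a real job would load these from configuration.
val account   = "my_account"
val user      = "my_user"
val password  = sys.env.getOrElse("SNOWFLAKE_PASSWORD", "")
val warehouse = "MY_WH"
val database  = "MY_DB"

// Extra Snowflake connector options, appended to the defaults each class builds
// (sfURL, sfUser, sfPassword, sfDatabase, sfSchema, sfWarehouse, sfCompress, sfSSL).
val extraOptions = Map("sfRole" -> "MY_ROLE")

val reader = SnowflakeReader(
  account,
  user,
  password,
  warehouse,
  database,
  "PUBLIC",
  table = Some("MY_TEST_TABLE"),
  sfExtraOptions = extraOptions)

val data = reader.read()

val writer = SnowflakeWriter(
  account,
  user,
  password,
  warehouse,
  database,
  "PUBLIC",
  "MY_TEST_TABLE",
  mode = SaveMode.Append,
  sfExtraOptions = extraOptions)

// Merge on "dt": SnowflakeMerger runs against stagingSchema (or the writer's
// schema when None), and the merge_tmp_delta__ staging table is dropped
// afterwards because deleteStagingTable is true.
SnowflakeWriterMerger(
  writer,
  columns = Seq("dt"),
  stagingSchema = Some("PUBLIC"),
  deleteStagingTable = true).write(data)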