From ff9e9b930226aef637633e9888e89e77377cce59 Mon Sep 17 00:00:00 2001
From: Victor Prats <victor.prats@aneior.com>
Date: Mon, 27 Sep 2021 21:01:34 +0200
Subject: [PATCH] Add Snowflake extra options to resources

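Each Snowflake resource (reader, writer, merger and the join/merge
combinations) now accepts an sfExtraOptions map that is merged into the
base sfOptions and forwarded to the Spark Snowflake connector, so
connector options such as sfRole can be set per resource.

A minimal usage sketch follows. The user, password and warehouse
parameter names are assumptions, since those lines are not visible in
this diff; only account, database, schema, table, query and
sfExtraOptions appear in the touched hunks of SnowflakeReader.

    import org.apache.spark.sql.SparkSession
    import com.damavis.spark.resource.datasource.snowflake.SnowflakeReader

    implicit val spark: SparkSession = SparkSession.builder().getOrCreate()

    // Extra connector options, merged into sfOptions before the read.
    val extraOptions = Map(
      "sfRole"     -> "ANALYST",       // role assumed for the session
      "sfTimezone" -> "Europe/Madrid"  // session timezone
    )

    val events = SnowflakeReader(
      account = "my_account",
      user = "my_user",            // assumed parameter name, not shown in this patch
      password = "my_password",    // assumed parameter name, not shown in this patch
      warehouse = "my_warehouse",  // assumed parameter name, not shown in this patch
      database = "MY_DB",
      schema = "PUBLIC",
      table = Some("EVENTS"),
      sfExtraOptions = extraOptions
    ).read()  // reads with sfOptions ++ sfExtraOptions

SnowflakeWriter, SnowflakeMerger and SnowflakeWriterMerger take the same
map and propagate it to the staging writes and MERGE/DELETE queries.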
---
 .../datasource/snowflake/SnowflakeJoinReader.scala       | 3 ++-
 .../resource/datasource/snowflake/SnowflakeMerger.scala  | 5 +++--
 .../resource/datasource/snowflake/SnowflakeReader.scala  | 7 ++++---
 .../resource/datasource/snowflake/SnowflakeWriter.scala  | 5 +++--
 .../datasource/snowflake/SnowflakeWriterMerger.scala     | 9 ++++++---
 5 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala
index a938bbb..92154fe 100644
--- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala
+++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeJoinReader.scala
@@ -37,7 +37,8 @@ case class SnowflakeJoinReader(reader: SnowflakeReader, joinData: DataFrame)(
       reader.database,
       reader.schema,
       stagingTable,
-      SaveMode.Overwrite).write(joinData)
+      SaveMode.Overwrite,
+      reader.sfExtraOptions).write(joinData)
     reader.copy(table = None, query = Some(query)).read()
   }
 
diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala
index 3a3d23c..30753af 100644
--- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala
+++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeMerger.scala
@@ -11,7 +11,8 @@ case class SnowflakeMerger(account: String,
                            schema: String,
                            sourceTable: String,
                            targetTable: String,
-                           pkColumns: Seq[String])(implicit spark: SparkSession) {
+                           pkColumns: Seq[String],
+                           sfExtraOptions: Map[String, String] = Map())(implicit spark: SparkSession) {
 
   val sfOptions = Map(
     "sfURL" -> s"${account}.snowflakecomputing.com",
@@ -38,7 +39,7 @@ case class SnowflakeMerger(account: String,
         |WHEN MATCHED THEN DELETE
         |""".stripMargin
 
-    Utils.runQuery(sfOptions, deleteQuery)
+    Utils.runQuery(sfOptions ++ sfExtraOptions, deleteQuery)
   }
 
 }
diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala
index 79df9f2..acbfe45 100644
--- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala
+++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeReader.scala
@@ -10,8 +10,9 @@ case class SnowflakeReader(account: String,
                            database: String,
                            schema: String,
                            table: Option[String] = None,
-                           query: Option[String] = None)(implicit spark: SparkSession)
-  extends ResourceReader {
+                           query: Option[String] = None,
+                           sfExtraOptions: Map[String, String] = Map())(implicit spark: SparkSession)
+    extends ResourceReader {
 
   val settings = (table, query) match {
     case (Some(tableName), None) => ("dbtable", tableName)
@@ -35,7 +36,7 @@ case class SnowflakeReader(account: String,
   override def read(): DataFrame = {
     spark.read
       .format("net.snowflake.spark.snowflake")
-      .options(sfOptions)
+      .options(sfOptions ++ sfExtraOptions)
       .load()
   }
 
diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala
index a10efc5..816e969 100644
--- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala
+++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriter.scala
@@ -11,8 +11,9 @@ case class SnowflakeWriter(account: String,
                            schema: String,
                            table: String,
                            mode: SaveMode = SaveMode.Ignore,
+                           sfExtraOptions: Map[String, String] = Map(),
                            preScript: Option[String] = None)(implicit spark: SparkSession)
-  extends ResourceWriter {
+    extends ResourceWriter {
 
   val sfOptions = Map(
     "sfURL" -> s"${account}.snowflakecomputing.com",
@@ -28,7 +29,7 @@ case class SnowflakeWriter(account: String,
   override def write(data: DataFrame): Unit = {
     data.write
       .format("net.snowflake.spark.snowflake")
-      .options(sfOptions)
+      .options(sfOptions ++ sfExtraOptions)
       .mode(mode)
       .save()
   }
diff --git a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala
index e45796c..5e3528b 100644
--- a/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala
+++ b/damavis-spark-snowflake/src/main/scala/com/damavis/spark/resource/datasource/snowflake/SnowflakeWriterMerger.scala
@@ -34,7 +34,8 @@ case class SnowflakeWriterMerger(writer: SnowflakeWriter, columns: Seq[String])(
       writer.schema,
       stagingTable,
       targetTable,
-      columns).merge()
+      columns,
+      writer.sfExtraOptions).merge()
   }
 
   private def targetExists(): Boolean = {
@@ -45,8 +46,10 @@ case class SnowflakeWriterMerger(writer: SnowflakeWriter, columns: Seq[String])(
       writer.warehouse,
       writer.database,
       "INFORMATION_SCHEMA",
-      query =
-        Some(s"SELECT COUNT(1) = 1 FROM TABLES WHERE TABLE_NAME = '${targetTable}'"))
+      query = Some(
+        s"SELECT COUNT(1) = 1 FROM TABLES WHERE TABLE_NAME = '${targetTable}'"),
+      sfExtraOptions = writer.sfExtraOptions
+    )
 
     reader
       .read()