Skip to content

Commit

Permalink
Fixes #82 - added auxilary constructor and registerBody method
Browse files Browse the repository at this point in the history
  • Loading branch information
TebaleloS committed Mar 9, 2023
1 parent 5976fec commit e4daeaf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package za.co.absa.spark.commons

import org.apache.spark.sql.SparkSession
import org.fusesource.hawtjni.runtime.Library

import java.util.concurrent.ConcurrentHashMap

Expand All @@ -32,11 +33,18 @@ import java.util.concurrent.ConcurrentHashMap
*
* @param sparkToRegisterTo Spark session to which we wish to attach objects
*/
abstract class OncePerSparkSession()(implicit sparkToRegisterTo: SparkSession) extends Serializable {
abstract class OncePerSparkSession() extends Serializable {

protected def register(implicit spark: SparkSession): Unit
def this()(implicit sparkToRegisterTo: SparkSession) = {
this()
register(sparkToRegisterTo)
}
def register(implicit spark: SparkSession): Unit = {
OncePerSparkSession.registerMe(this, spark)
}

protected def registerBody(spark: SparkSession): Unit

OncePerSparkSession.registerMe(this, sparkToRegisterTo)
}

object OncePerSparkSession {
Expand All @@ -54,6 +62,7 @@ object OncePerSparkSession {
protected def registerMe(library: OncePerSparkSession, spark: SparkSession): Unit = {
// the function is `protected` to make it visible to `ScalaDoc`
Option(registry.putIfAbsent(makeKey(library, spark), Unit))
.getOrElse(library.register(spark))
.getOrElse(library.registerBody(spark))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class OncePerSparkSessionTest extends AnyFunSuite with MockitoSugar with SparkTe
var libraryBInitCounter = 0

val anotherSpark: SparkSession = mock[SparkSession]
class UDFLibraryA()(implicit sparkToRegisterTo: SparkSession) extends OncePerSparkSession()(sparkToRegisterTo) {
override protected def register(implicit spark: SparkSession): Unit = {
class UDFLibraryA()(implicit sparkToRegisterTo: Option[SparkSession]) extends OncePerSparkSession()(sparkToRegisterTo) {
override def register(implicit spark: Option[SparkSession]): Unit = {
libraryAInitCounter += 1
}
}
Expand Down

0 comments on commit e4daeaf

Please sign in to comment.