Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23228][PYSPARK] Add Python Created jsparkSession to JVM's defaultSession #20404

Closed
wants to merge 10 commits into from
10 changes: 9 additions & 1 deletion python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,12 @@ def __init__(self, sparkContext, jsparkSession=None):
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
if jsparkSession is None:
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
if self._jvm.SparkSession.getDefaultSession().isDefined() \
and not self._jvm.SparkSession.getDefaultSession().get() \
.sparkContext().isStopped():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this change at 4ba3aa2 is enough to fix the previous test failure (ERROR: test_sparksession_with_stopped_sparkcontext (pyspark.sql.tests.SQLTests2)) and we can revert moving self._jvm.SparkSession.clearDefaultSession() to SparkContext.stop() at
0319fa5 now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. I will do it.

jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
else:
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
self._jsparkSession = jsparkSession
self._jwrapped = self._jsparkSession.sqlContext()
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
Expand All @@ -225,6 +230,7 @@ def __init__(self, sparkContext, jsparkSession=None):
if SparkSession._instantiatedSession is None \
or SparkSession._instantiatedSession._sc._jsc is None:
SparkSession._instantiatedSession = self
self._jvm.SparkSession.setDefaultSession(self._jsparkSession)

def _repr_html_(self):
return """
Expand Down Expand Up @@ -759,6 +765,8 @@ def stop(self):
"""Stop the underlying :class:`SparkContext`.
"""
self._sc.stop()
# We should clean the default session up. See SPARK-23228.
self._jvm.SparkSession.clearDefaultSession()
SparkSession._instantiatedSession = None

@since(2.0)
Expand Down
28 changes: 27 additions & 1 deletion python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
from pyspark.sql.types import _merge_type
from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
Expand Down Expand Up @@ -2912,6 +2912,32 @@ def test_sparksession_with_stopped_sparkcontext(self):
sc.stop()


class SparkSessionTests(PySparkTestCase):

# This test is separate because it's closely related with session's start and stop.
# See SPARK-23228.
def test_set_jvm_default_session(self):
spark = SparkSession.builder.getOrCreate()
try:
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
finally:
spark.stop()
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())

def test_jvm_default_session_already_set(self):
# Here, we assume there is the default session already set in JVM.
jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
self.sc._jvm.SparkSession.setDefaultSession(jsession)

spark = SparkSession.builder.getOrCreate()
try:
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
# The session should be the same with the exiting one.
self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
finally:
spark.stop()


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
if SparkSession._instantiatedSession is not None:
Expand Down