[SPARK-23228][PYSPARK] Add Python Created jsparkSession to JVM's defaultSession #20404

Closed · wants to merge 10 commits
python/pyspark/sql/session.py (7 additions, 2 deletions)
@@ -213,7 +213,10 @@ def __init__(self, sparkContext, jsparkSession=None):
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         if jsparkSession is None:
-            jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+            if self._jvm.SparkSession.getDefaultSession().isDefined():
+                jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
+            else:
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc())
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
@@ -225,7 +228,8 @@ def __init__(self, sparkContext, jsparkSession=None):
         if SparkSession._instantiatedSession is None \
                 or SparkSession._instantiatedSession._sc._jsc is None:
             SparkSession._instantiatedSession = self
-            self._jvm.org.apache.spark.sql.SparkSession.setDefaultSession(self._jsparkSession)
+            if self._jvm.SparkSession.getDefaultSession().isEmpty():
+                self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
Member: Maybe we can simply overwrite the default session.
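
For clarity, the alternative being suggested here would be a one-line unconditional set, rather than the guarded set in the diff above; a minimal sketch (not what the PR does):

```python
# Hypothetical alternative: always install the Python-created session
# as the JVM default, overwriting any existing one.
self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
```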

Contributor Author: @felixcheung has concerns about simply overwriting the default session.

Member: I might be missing something, but I think @felixcheung's concern is addressed by checking that the default session is defined and not stopped; then we can adopt the valid JVM session (or the same one) without any further checks. But I'm okay with leaving it as it is, too.
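
For reference, a minimal sketch of the check being described, assuming the Scala-side `SparkSession.getDefaultSession` and `SparkContext.isStopped` are reachable through the Py4J gateway (`self._jvm`), as the diff above already assumes for `getDefaultSession`:

```python
# Sketch: adopt the JVM default session only if it is defined and its
# underlying SparkContext has not been stopped; otherwise create a new
# JVM session. Py4J reachability of isStopped() is an assumption here.
if jsparkSession is None:
    default = self._jvm.SparkSession.getDefaultSession()
    if default.isDefined() and not default.get().sparkContext().isStopped():
        jsparkSession = default.get()
    else:
        jsparkSession = self._jvm.SparkSession(self._jsc.sc())
```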


     def _repr_html_(self):
         return """
@@ -760,6 +764,7 @@ def stop(self):
         """Stop the underlying :class:`SparkContext`.
         """
         self._sc.stop()
+        self._jvm.SparkSession.clearDefaultSession()
Member: Hmm... if we didn't set it at L231, perhaps we shouldn't clear it? Or, if we are picking up the JVM one at L217, we shouldn't clear it either?

WDYT?
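
One way to address this concern, sketched here with a hypothetical `_set_jvm_default` flag that is not part of this PR: remember whether `__init__` actually installed the JVM default session, and clear it in `stop()` only in that case.

```python
# Hypothetical sketch, not part of the PR: track whether this Python
# session installed the JVM default, and only clear what we set.
def __init__(self, sparkContext, jsparkSession=None):
    ...
    self._set_jvm_default = False  # hypothetical flag
    if self._jvm.SparkSession.getDefaultSession().isEmpty():
        self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
        self._set_jvm_default = True

def stop(self):
    self._sc.stop()
    if self._set_jvm_default:  # leave a default session we merely reused
        self._jvm.SparkSession.clearDefaultSession()
    SparkSession._instantiatedSession = None
```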

Contributor Author: IIUC, since the line above already stops the JVM SparkContext, there is no need to keep the default session around.

Member: BTW, let me open a PR against your branch to deal with the failure soon, @jerryshao. I was looking into this out of curiosity.

Contributor Author: I'm also working on the failure; I've already figured out the cause.

         SparkSession._instantiatedSession = None
 
     @since(2.0)
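
Taken together, the session.py changes make a Python-created session visible as the JVM default and clear it again on shutdown. A minimal end-to-end illustration, assuming a working PySpark installation (it mirrors the new tests below):

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext('local[2]', "default_session_demo")
spark = SparkSession(sc)

# __init__ registered the Python-created session as the JVM default.
assert spark._jvm.SparkSession.getDefaultSession().isDefined()

# stop() clears the JVM default alongside the SparkContext.
spark.stop()
assert spark._jvm.SparkSession.getDefaultSession().isEmpty()
```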
python/pyspark/sql/tests.py (42 additions, 0 deletions)
@@ -204,6 +204,48 @@ def assertPandasEqual(self, expected, result):
         self.assertTrue(expected.equals(result), msg=msg)
 
 
+class PySparkSessionTests(unittest.TestCase):
+
+    def test_set_jvm_default_session(self):
+        spark = None
+        sc = None
+        try:
+            sc = SparkContext('local[4]', "test_spark_session")
+            spark = SparkSession(sc)
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+        finally:
+            if spark is not None:
+                spark.stop()
+                self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
+                spark = None
+                sc = None
+
+            if sc is not None:
+                sc.stop()
+                sc = None
+
+    def test_jvm_default_session_already_set(self):
+        spark = None
+        sc = None
+        try:
+            sc = SparkContext('local[4]', "test_spark_session")
+            jsession = sc._jvm.SparkSession(sc._jsc.sc())
+            sc._jvm.SparkSession.setDefaultSession(jsession)
+
+            spark = SparkSession(sc, jsession)
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+            self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
+        finally:
+            if spark is not None:
+                spark.stop()
+                spark = None
+                sc = None
+
+            if sc is not None:
+                sc.stop()
+                sc = None
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
Expand Down