
[luigi.contrib.pyspark_runner] SparkSession support in PySparkTask #2862

Merged
6 commits merged on Jan 28, 2020
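With this change, PySparkTask.main can receive a SparkSession instead of a bare SparkContext when the new pyspark_runner.use_spark_session option is enabled. A minimal sketch of a task written against the new entry point (the task, target, and file names below are illustrative, not part of this PR):

import luigi
from luigi.contrib.spark import PySparkTask


class WordLengths(PySparkTask):
    def output(self):
        return luigi.LocalTarget('word_lengths.csv')

    def main(self, spark, *args):
        # 'spark' is the SparkSession created by SparkSessionEntryPoint;
        # spark.sparkContext remains available for RDD-style code.
        df = spark.read.text('words.txt')
        df.selectExpr('value', 'length(value) AS len') \
            .write.mode('overwrite').csv(self.output().path)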
89 changes: 84 additions & 5 deletions luigi/contrib/pyspark_runner.py
@@ -28,6 +28,8 @@

from __future__ import print_function

import abc

try:
import cPickle as pickle
except ImportError:
@@ -36,11 +38,70 @@
import sys
import os

from luigi import configuration
from luigi import six

# this prevents the modules in the directory of this script from shadowing global packages
sys.path.append(sys.path.pop(0))


class PySparkRunner(object):
@six.add_metaclass(abc.ABCMeta)
class _SparkEntryPoint(object):
def __init__(self, conf):
self.conf = conf

@abc.abstractmethod
def __enter__(self):
pass

@abc.abstractmethod
def __exit__(self, exc_type, exc_val, exc_tb):
pass


class SparkContextEntryPoint(_SparkEntryPoint):
sc = None

def __enter__(self):
from pyspark import SparkContext
self.sc = SparkContext(conf=self.conf)
return self.sc, self.sc

def __exit__(self, exc_type, exc_val, exc_tb):
self.sc.stop()


class SparkSessionEntryPoint(_SparkEntryPoint):
spark = None

def _check_major_spark_version(self):
from pyspark import __version__ as spark_version
major_version = int(spark_version.split('.')[0])
if major_version < 2:
raise RuntimeError(
'''
Apache Spark {} does not support SparkSession entrypoint.
Try to set 'pyspark_runner.use_spark_session' to 'False' and switch to old-style syntax
'''.format(spark_version)
)

def __enter__(self):
self._check_major_spark_version()
from pyspark.sql import SparkSession
self.spark = SparkSession \
.builder \
.config(conf=self.conf) \
.enableHiveSupport() \
.getOrCreate()

return self.spark, self.spark.sparkContext

def __exit__(self, exc_type, exc_val, exc_tb):
self.spark.stop()


class AbstractPySparkRunner(object):
_entry_point_class = None

def __init__(self, job, *args):
# Append job directory to PYTHON_PATH to enable dynamic import
@@ -51,14 +112,32 @@ def __init__(self, job, *args):
self.args = args

def run(self):
from pyspark import SparkContext, SparkConf
from pyspark import SparkConf
conf = SparkConf()
self.job.setup(conf)
with SparkContext(conf=conf) as sc:
with self._entry_point_class(conf=conf) as (entry_point, sc):
self.job.setup_remote(sc)
self.job.main(sc, *self.args)
self.job.main(entry_point, *self.args)


def _pyspark_runner_with(name, entry_point_class):
return type(name, (AbstractPySparkRunner,), {'_entry_point_class': entry_point_class})


PySparkRunner = _pyspark_runner_with('PySparkRunner', SparkContextEntryPoint)
PySparkSessionRunner = _pyspark_runner_with('PySparkSessionRunner', SparkSessionEntryPoint)


def _use_spark_session():
return bool(configuration.get_config().get('pyspark_runner', "use_spark_session", False))


def _get_runner_class():
if _use_spark_session():
return PySparkSessionRunner
return PySparkRunner


if __name__ == '__main__':
logging.basicConfig(level=logging.WARN)
PySparkRunner(*sys.argv[1:]).run()
_get_runner_class()(*sys.argv[1:]).run()
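The SparkSession entry point is opt-in: _get_runner_class only returns PySparkSessionRunner when the pyspark_runner.use_spark_session option is set in the luigi configuration. A sketch of the corresponding client setting, something like the following in luigi.cfg (note that _use_spark_session applies bool() to the raw config string, so any non-empty value, even 'false', would enable the SparkSession runner):

[pyspark_runner]
use_spark_session: True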
59 changes: 58 additions & 1 deletion test/contrib/spark_test.py
@@ -100,6 +100,17 @@ def main(self, sc, *args):
sc.textFile(self.input().path).saveAsTextFile(self.output().path)


class TestPySparkSessionTask(PySparkTask):
def input(self):
return MockTarget('input')

def output(self):
return MockTarget('output')

def main(self, spark, *args):
spark.sql(self.input().path).write.saveAsTable(self.output().path)


class MessyNamePySparkTask(TestPySparkTask):
name = 'AppName(a,b,c,1:2,3/4)'

@@ -292,7 +303,7 @@ def test_run_with_cluster(self, proc):
@patch.dict('sys.modules', {'pyspark': MagicMock()})
@patch('pyspark.SparkContext')
def test_pyspark_runner(self, spark_context):
sc = spark_context.return_value.__enter__.return_value
sc = spark_context.return_value

def mock_spark_submit(task):
from luigi.contrib.pyspark_runner import PySparkRunner
@@ -315,6 +326,52 @@ def mock_spark_submit(task):

sc.textFile.assert_called_with('input')
sc.textFile.return_value.saveAsTextFile.assert_called_with('output')
sc.stop.assert_called_once_with()

def test_pyspark_session_runner_use_spark_session_true(self):
pyspark = MagicMock()
pyspark.__version__ = '2.1.0'
pyspark_sql = MagicMock()
with patch.dict(sys.modules, {'pyspark': pyspark, 'pyspark.sql': pyspark_sql}):
spark = pyspark_sql.SparkSession.builder.config.return_value.enableHiveSupport.return_value.getOrCreate.return_value
sc = spark.sparkContext

def mock_spark_submit(task):
from luigi.contrib.pyspark_runner import PySparkSessionRunner
PySparkSessionRunner(*task.app_command()[1:]).run()
# Check py-package exists
self.assertTrue(os.path.exists(sc.addPyFile.call_args[0][0]))
# Check that main module containing the task exists.
run_path = os.path.dirname(task.app_command()[1])
self.assertTrue(os.path.exists(os.path.join(run_path, os.path.basename(__file__))))
# Check that the python path contains the run_path
self.assertTrue(run_path in sys.path)
# Check if find_class finds the class for the correct module name.
with open(task.app_command()[1], 'rb') as fp:
self.assertTrue(pickle.Unpickler(fp).find_class('spark_test', 'TestPySparkSessionTask'))

with patch.object(SparkSubmitTask, 'run', mock_spark_submit):
job = TestPySparkSessionTask()
with temporary_unloaded_module(b'') as task_module:
with_config({'spark': {'py-packages': task_module}})(job.run)()

spark.sql.assert_called_with('input')
spark.sql.return_value.write.saveAsTable.assert_called_with('output')
spark.stop.assert_called_once_with()

def test_pyspark_session_runner_use_spark_session_true_spark1(self):
pyspark = MagicMock()
pyspark.__version__ = '1.6.3'
pyspark_sql = MagicMock()
with patch.dict(sys.modules, {'pyspark': pyspark, 'pyspark.sql': pyspark_sql}):
def mock_spark_submit(task):
from luigi.contrib.pyspark_runner import PySparkSessionRunner
self.assertRaises(RuntimeError, PySparkSessionRunner(*task.app_command()[1:]).run)

with patch.object(SparkSubmitTask, 'run', mock_spark_submit):
job = TestPySparkSessionTask()
with temporary_unloaded_module(b'') as task_module:
with_config({'spark': {'py-packages': task_module}})(job.run)()

@patch('luigi.contrib.external_program.subprocess.Popen')
def test_name_cleanup(self, proc):