From dd743d57e26bc33a589637cc837b03110ad37853 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 13 Mar 2019 13:51:11 -0700 Subject: [PATCH 1/7] refactor grouped map to use StructType return --- python/pyspark/worker.py | 19 +++++-------------- .../python/FlatMapGroupsInPandasExec.scala | 11 ++++++++++- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0e9b6d665a36f..3f570335f013c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -101,10 +101,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - assign_cols_by_name = runner_conf.get( - "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true") - assign_cols_by_name = assign_cols_by_name.lower() == "true" +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd @@ -123,15 +120,9 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. 
" "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result - # Assign result columns by schema name if user labeled with strings, else use position - if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns): - return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] - else: - return [(result[result.columns[i]], to_arrow_type(field.dataType)) - for i, field in enumerate(return_type)] - - return wrapped + return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] def wrap_grouped_agg_pandas_udf(f, return_type): @@ -225,7 +216,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -255,7 +246,7 @@ def read_udfs(pickleSer, infile, eval_type): timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion", "false").lower() == 'true' - # NOTE: this is duplicated from wrap_grouped_map_pandas_udf + # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType assign_cols_by_name = runner_conf.get( "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ .lower() == "true" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 
e9cff1a5a2007..9d0ec2e227fd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] @@ -145,7 +146,15 @@ case class FlatMapGroupsInPandasExec( sessionLocalTimeZone, pythonRunnerConf).compute(grouped, context.partitionId(), context) - columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + columnarBatchIter.flatMap { batch => + // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + // TODO: ColumnVector getChild is protected, so use ArrowColumnVector which is public + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild(_).asInstanceOf[ColumnVector]) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(UnsafeProjection.create(output, output)) } } } From 9a32b4469f7e7127df73d106d259a903d54898f9 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 13 Mar 2019 15:36:53 -0700 Subject: [PATCH 2/7] ArrowStreamPandasSerializer inherits basic ArrowStreamSerializer --- python/pyspark/serializers.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0c3c68ec0bd95..fe8108e5add6a 100644 --- 
a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -229,6 +229,7 @@ def dump_stream(self, iterator, stream): try: for batch in iterator: if writer is None: + self._init_dump_stream(stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) writer.write_batch(batch) finally: @@ -241,6 +242,10 @@ def load_stream(self, stream): for batch in reader: yield batch + def _init_dump_stream(self, stream): + """Called just before writing an Arrow stream""" + pass + def __repr__(self): return "ArrowStreamSerializer" @@ -328,7 +333,7 @@ def create_array(s, t): return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) -class ArrowStreamPandasSerializer(Serializer): +class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ Serializes Pandas.Series as Arrow data with Arrow streaming format. """ @@ -347,33 +352,28 @@ def arrow_to_pandas(self, arrow_column): s = _check_series_localize_timestamps(s, self._timezone) return s + def _init_dump_stream(self, stream): + # Override to signal the start of writing an Arrow stream + # NOTE: this is required by Pandas UDFs to be called after creating first record batch so + # that any errors can be sent back to the JVM, but not interfere with the Arrow stream + write_int(SpecialLengths.START_ARROW_STREAM, stream) + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. 
""" - import pyarrow as pa - writer = None - try: - for series in iterator: - batch = _create_batch(series, self._timezone, self._safecheck, - self._assign_cols_by_name) - if writer is None: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - writer = pa.RecordBatchStreamWriter(stream, batch.schema) - writer.write_batch(batch) - finally: - if writer is not None: - writer.close() + batches = (_create_batch(series, self._timezone, self._safecheck, self._assign_cols_by_name) + for series in iterator) + super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ + batch_iter = super(ArrowStreamPandasSerializer, self).load_stream(stream) import pyarrow as pa - reader = pa.ipc.open_stream(stream) - - for batch in reader: + for batch in batch_iter: yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): From 93bb83151ad84e72ebcd13ed236a8d6ce4f69c95 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Mar 2019 09:23:21 -0700 Subject: [PATCH 3/7] createDataFrame uses ArrowStreamPandasSerializer --- python/pyspark/serializers.py | 25 +++++++------ python/pyspark/sql/session.py | 37 ++++++++++--------- .../sql/execution/arrow/ArrowConverters.scala | 1 - 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fe8108e5add6a..bfb6e701abf79 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -222,6 +222,17 @@ class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. 
""" + def __init__(self, send_start_stream=True): + self._send_start_stream = send_start_stream + + def _init_dump_stream(self, stream): + """ + Called just before writing an Arrow stream + """ + # NOTE: this is required by Pandas UDFs to be called after creating first record batch so + # that any errors can be sent back to the JVM, but not interfere with the Arrow stream + if self._send_start_stream: + write_int(SpecialLengths.START_ARROW_STREAM, stream) def dump_stream(self, iterator, stream): import pyarrow as pa @@ -242,10 +253,6 @@ def load_stream(self, stream): for batch in reader: yield batch - def _init_dump_stream(self, stream): - """Called just before writing an Arrow stream""" - pass - def __repr__(self): return "ArrowStreamSerializer" @@ -338,8 +345,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): Serializes Pandas.Series as Arrow data with Arrow streaming format. """ - def __init__(self, timezone, safecheck, assign_cols_by_name): - super(ArrowStreamPandasSerializer, self).__init__() + def __init__(self, timezone, safecheck, assign_cols_by_name, send_start_stream=True): + super(ArrowStreamPandasSerializer, self).__init__(send_start_stream) self._timezone = timezone self._safecheck = safecheck self._assign_cols_by_name = assign_cols_by_name @@ -352,12 +359,6 @@ def arrow_to_pandas(self, arrow_column): s = _check_series_localize_timestamps(s, self._timezone) return s - def _init_dump_stream(self, stream): - # Override to signal the start of writing an Arrow stream - # NOTE: this is required by Pandas UDFs to be called after creating first record batch so - # that any errors can be sent back to the JVM, but not interfere with the Arrow stream - write_int(SpecialLengths.START_ARROW_STREAM, stream) - def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. 
Input is a single series or diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 32a2c8a67252d..b5c530c655287 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -530,8 +530,8 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ - from pyspark.serializers import ArrowStreamSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from pyspark.serializers import ArrowStreamPandasSerializer + from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,6 +539,15 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): require_minimum_pyarrow_version() from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + import pyarrow as pa + + # Create the Spark schema from list of names passed in with Arrow types + if isinstance(schema, (list, tuple)): + arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) + struct = StructType() + for name, field in zip(schema, arrow_schema): + struct.add(name, from_arrow_type(field.type), nullable=field.nullable) + schema = struct # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): @@ -555,23 +564,16 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) - # Create Arrow record batches - safecheck = self._wrapped._conf.arrowSafeTypeConversion() - col_by_name = True # col by name only applies to StructType columns, can't happen here - batches = [_create_batch([(c, t) for (_, c), t in 
zip(pdf_slice.iteritems(), arrow_types)], - timezone, safecheck, col_by_name) - for pdf_slice in pdf_slices] - - # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) - if isinstance(schema, (list, tuple)): - struct = from_arrow_schema(batches[0].schema) - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct + # Create list of Arrow (columns, type) for serializer dump_stream + arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)] + for pdf_slice in pdf_slices] jsqlContext = self._wrapped._jsqlContext + safecheck = self._wrapped._conf.arrowSafeTypeConversion() + col_by_name = True # col by name only applies to StructType columns, can't happen here + ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name, send_start_stream=False) + def reader_func(temp_filename): return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) @@ -579,8 +581,7 @@ def create_RDD_server(): return self._jvm.ArrowRDDServer(jsqlContext) # Create Spark DataFrame from Arrow stream file, using one batch per partition - jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func, - create_RDD_server) + jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server) jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) df = DataFrame(jdf, self._wrapped) df._schema = schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 2bf6a58b55658..884dc8c6215ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -31,7 +31,6 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import 
org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.network.util.JavaUtils -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ From bc08d1bcf745271ca6a9a31f92532815ddf616a4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 18 Mar 2019 13:36:26 -0700 Subject: [PATCH 4/7] change arrow start stream length to be written in ArrowStreamPandasSerializer --- python/pyspark/serializers.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index bfb6e701abf79..062fbfcd32042 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -222,17 +222,6 @@ class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. """ - def __init__(self, send_start_stream=True): - self._send_start_stream = send_start_stream - - def _init_dump_stream(self, stream): - """ - Called just before writing an Arrow stream - """ - # NOTE: this is required by Pandas UDFs to be called after creating first record batch so - # that any errors can be sent back to the JVM, but not interfere with the Arrow stream - if self._send_start_stream: - write_int(SpecialLengths.START_ARROW_STREAM, stream) def dump_stream(self, iterator, stream): import pyarrow as pa @@ -240,7 +229,6 @@ def dump_stream(self, iterator, stream): try: for batch in iterator: if writer is None: - self._init_dump_stream(stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) writer.write_batch(batch) finally: @@ -346,10 +334,11 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ def __init__(self, timezone, safecheck, assign_cols_by_name, send_start_stream=True): - super(ArrowStreamPandasSerializer, self).__init__(send_start_stream) + super(ArrowStreamPandasSerializer, self).__init__() 
self._timezone = timezone self._safecheck = safecheck self._assign_cols_by_name = assign_cols_by_name + self._send_start_stream = send_start_stream def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ @@ -366,15 +355,28 @@ def dump_stream(self, iterator, stream): """ batches = (_create_batch(series, self._timezone, self._safecheck, self._assign_cols_by_name) for series in iterator) - super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) + + def init_stream_yield_batches(): + # NOTE: START_ARROW_STREAM is required by Pandas UDFs, called after creating the first + # record batch so any errors can be sent back to the JVM before the Arrow stream starts + should_write_start_length = True + for batch in batches: + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + iterator = init_stream_yield_batches() if self._send_start_stream else batches + + super(ArrowStreamPandasSerializer, self).dump_stream(iterator, stream) def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
""" - batch_iter = super(ArrowStreamPandasSerializer, self).load_stream(stream) + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) import pyarrow as pa - for batch in batch_iter: + for batch in batches: yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): From 5832d639e0338416e06d0fdcdce097da635a6870 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 20 Mar 2019 09:45:28 -0700 Subject: [PATCH 5/7] assign FlatMapGroupsInPandasExec unsafe projection to a variable --- .../sql/execution/python/FlatMapGroupsInPandasExec.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 9d0ec2e227fd0..ce755ffb7c9fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -146,15 +146,16 @@ case class FlatMapGroupsInPandasExec( sessionLocalTimeZone, pythonRunnerConf).compute(grouped, context.partitionId(), context) + val unsafeProj = UnsafeProjection.create(output, output) + columnarBatchIter.flatMap { batch => // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here - // TODO: ColumnVector getChild is protected, so use ArrowColumnVector which is public val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild(_).asInstanceOf[ColumnVector]) + val outputVectors = output.indices.map(structVector.getChild) val flattenedBatch = new ColumnarBatch(outputVectors.toArray) flattenedBatch.setNumRows(batch.numRows()) flattenedBatch.rowIterator.asScala - }.map(UnsafeProjection.create(output, output)) + }.map(unsafeProj) } } } From 
1809dfecb711a177cd1f5118d984f0c42e97f11a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 20 Mar 2019 10:52:48 -0700 Subject: [PATCH 6/7] added workaround for createDataFrame schema generation --- python/pyspark/sql/session.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b5c530c655287..cb3f6842a045b 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -530,6 +530,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ + from distutils.version import LooseVersion from pyspark.serializers import ArrowStreamPandasSerializer from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ @@ -543,7 +544,11 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): # Create the Spark schema from list of names passed in with Arrow types if isinstance(schema, (list, tuple)): - arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) + if LooseVersion(pa.__version__) < LooseVersion("0.12.0"): + temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False) + arrow_schema = temp_batch.schema + else: + arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) struct = StructType() for name, field in zip(schema, arrow_schema): struct.add(name, from_arrow_type(field.type), nullable=field.nullable) From f6b0e30d818531b09ec4e72556fc11e7706383d4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 20 Mar 2019 13:43:23 -0700 Subject: [PATCH 7/7] move _create_batch to ArrowStreamPandasSerializer and make subclass to handle Pandas UDF serialize --- python/pyspark/serializers.py | 219 ++++++++++++++++++---------------- python/pyspark/sql/session.py | 2 +- 
python/pyspark/worker.py | 4 +- 3 files changed, 121 insertions(+), 104 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 062fbfcd32042..58f7552cab491 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -245,100 +245,20 @@ def __repr__(self): return "ArrowStreamSerializer" -def _create_batch(series, timezone, safecheck, assign_cols_by_name): - """ - Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. - - :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) - :param timezone: A timezone to respect when handling timestamp values - :return: Arrow RecordBatch - """ - import decimal - from distutils.version import LooseVersion - import pandas as pd - import pyarrow as pa - from pyspark.sql.types import _check_series_convert_timestamps_internal - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - def create_array(s, t): - mask = s.isnull() - # Ensure timestamp series are in expected form for Spark internal representation - # TODO: maybe don't need None check anymore as of Arrow 0.9.1 - if t is not None and pa.types.is_timestamp(t): - s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) - # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 - return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) - elif t is not None and pa.types.is_string(t) and sys.version < '3': - # TODO: need decode before converting to Arrow in Python 2 - # TODO: don't need as of Arrow 0.9.1 - return pa.Array.from_pandas(s.apply( - lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) - elif t is not None and pa.types.is_decimal(t) and \ - 
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. - return pa.Array.from_pandas(s.apply( - lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) - elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): - # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. - return pa.Array.from_pandas(s, mask=mask, type=t) - - try: - array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck) - except pa.ArrowException as e: - error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ - "Array (%s). It can be caused by overflows or other unsafe " + \ - "conversions warned by Arrow. Arrow safe type check can be " + \ - "disabled by using SQL config " + \ - "`spark.sql.execution.pandas.arrowSafeTypeConversion`." - raise RuntimeError(error_msg % (s.dtype, t), e) - return array - - arrs = [] - for s, t in series: - if t is not None and pa.types.is_struct(t): - if not isinstance(s, pd.DataFrame): - raise ValueError("A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s))) - - # Input partition and result pandas.DataFrame empty, make empty Arrays with struct - if len(s) == 0 and len(s.columns) == 0: - arrs_names = [(pa.array([], type=field.type), field.name) for field in t] - # Assign result columns by schema name if user labeled with strings - elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns): - arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t] - # Assign result columns by position - else: - arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) - for i, field in enumerate(t)] - - struct_arrs, struct_names = zip(*arrs_names) - - # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version - if LooseVersion(pa.__version__) < LooseVersion("0.9.0"): - 
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs)) - else: - arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) - else: - arrs.append(create_array(s, t)) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - - class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ Serializes Pandas.Series as Arrow data with Arrow streaming format. + + :param timezone: A timezone to respect when handling timestamp values + :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation + :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name """ - def __init__(self, timezone, safecheck, assign_cols_by_name, send_start_stream=True): + def __init__(self, timezone, safecheck, assign_cols_by_name): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck self._assign_cols_by_name = assign_cols_by_name - self._send_start_stream = send_start_stream def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ @@ -348,27 +268,97 @@ def arrow_to_pandas(self, arrow_column): s = _check_series_localize_timestamps(s, self._timezone) return s - def dump_stream(self, iterator, stream): + def _create_batch(self, series): """ - Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or - a list of series accompanied by an optional pyarrow type to coerce the data to. + Create an Arrow record batch from the given pandas.Series or list of Series, + with optional type. 
+ + :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :return: Arrow RecordBatch """ - batches = (_create_batch(series, self._timezone, self._safecheck, self._assign_cols_by_name) - for series in iterator) + import decimal + from distutils.version import LooseVersion + import pandas as pd + import pyarrow as pa + from pyspark.sql.types import _check_series_convert_timestamps_internal + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + def create_array(s, t): + mask = s.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + # TODO: maybe don't need None check anymore as of Arrow 0.9.1 + if t is not None and pa.types.is_timestamp(t): + s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone) + # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 + return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + # TODO: don't need as of Arrow 0.9.1 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) + elif t is not None and pa.types.is_decimal(t) and \ + LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. + return pa.Array.from_pandas(s.apply( + lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) + elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. 
+ return pa.Array.from_pandas(s, mask=mask, type=t) - def init_stream_yield_batches(): - # NOTE: START_ARROW_STREAM is required by Pandas UDFs, called after creating the first - # record batch so any errors can be sent back to the JVM before the Arrow stream starts - should_write_start_length = True - for batch in batches: - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False - yield batch + try: + array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) + except pa.ArrowException as e: + error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ + "Array (%s). It can be caused by overflows or other unsafe " + \ + "conversions warned by Arrow. Arrow safe type check can be " + \ + "disabled by using SQL config " + \ + "`spark.sql.execution.pandas.arrowSafeTypeConversion`." + raise RuntimeError(error_msg % (s.dtype, t), e) + return array + + arrs = [] + for s, t in series: + if t is not None and pa.types.is_struct(t): + if not isinstance(s, pd.DataFrame): + raise ValueError("A field of type StructType expects a pandas.DataFrame, " + "but got: %s" % str(type(s))) + + # Input partition and result pandas.DataFrame empty, make empty Arrays with struct + if len(s) == 0 and len(s.columns) == 0: + arrs_names = [(pa.array([], type=field.type), field.name) for field in t] + # Assign result columns by schema name if user labeled with strings + elif self._assign_cols_by_name and any(isinstance(name, basestring) + for name in s.columns): + arrs_names = [(create_array(s[field.name], field.type), field.name) + for field in t] + # Assign result columns by position + else: + arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) + for i, field in enumerate(t)] + + struct_arrs, struct_names = zip(*arrs_names) + + # TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version + if LooseVersion(pa.__version__) < LooseVersion("0.9.0"): + 
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs)) + else: + arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) + else: + arrs.append(create_array(s, t)) - iterator = init_stream_yield_batches() if self._send_start_stream else batches + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - super(ArrowStreamPandasSerializer, self).dump_stream(iterator, stream) + def dump_stream(self, iterator, stream): + """ + Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or + a list of series accompanied by an optional pyarrow type to coerce the data to. + """ + batches = (self._create_batch(series) for series in iterator) + super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) def load_stream(self, stream): """ @@ -383,6 +373,33 @@ def __repr__(self): return "ArrowStreamPandasSerializer" +class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): + """ + Serializer used by Python worker to evaluate Pandas UDFs + """ + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. 
+ """ + + def init_stream_yield_batches(): + should_write_start_length = True + for series in iterator: + batch = self._create_batch(series) + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) + + def __repr__(self): + return "ArrowStreamPandasUDFSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index cb3f6842a045b..b11e0f3ff69de 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -577,7 +577,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): safecheck = self._wrapped._conf.arrowSafeTypeConversion() col_by_name = True # col by name only applies to StructType columns, can't happen here - ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name, send_start_stream=False) + ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) def reader_func(temp_filename): return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3f570335f013c..f59fb443b4db3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,7 +38,7 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasSerializer + BatchedSerializer, ArrowStreamPandasUDFSerializer from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -251,7 +251,7 @@ def read_udfs(pickleSer, infile, eval_type): "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ .lower() == "true" - ser = 
ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name) + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name) else: ser = BatchedSerializer(PickleSerializer(), 100)