From af94332023c65df92a5b1a781dd3771e115b279b Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 24 May 2021 11:09:08 +0800
Subject: [PATCH] init

---
 python/pyspark/util.py | 118 ++++++++++++++++++++++++++++-------------
 1 file changed, 80 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 09c5963927456..bc3faf0e7d7a6 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #
 
+import functools
 import itertools
 import os
 import platform
@@ -25,8 +26,6 @@
 import traceback
 import types
 
-from py4j.clientserver import ClientServer
-
 __all__ = []  # type: ignore
 
 
@@ -263,6 +262,82 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])
 
 
+def inheritable_thread_target(f):
+    """
+    Return a thread target wrapper that is recommended to be used in PySpark when the
+    pinned thread mode is enabled. Before calling the original thread target, the
+    wrapper inherits the inheritable properties specific to the JVM thread, such as
+    ``InheritableThreadLocal``.
+
+    Also, note that pinned thread mode does not close the connection from Python
+    to the JVM when the thread finishes on the Python side. With this wrapper, Python
+    garbage-collects the Python thread instance and also closes the connection,
+    which terminates the JVM thread correctly.
+
+    When the pinned thread mode is off, it returns the original ``f``.
+
+    Parameters
+    ----------
+    f : function
+        the original thread target.
+
+    .. versionadded:: 3.2.0
+
+    Notes
+    -----
+    This API is experimental.
+
+    It captures the local properties when you decorate it. Therefore, decorate the
+    target at the point where you want the local properties to be captured.
+
+    For example, the local properties from the current Spark context are captured
+    when you define a function here:
+
+    >>> @inheritable_thread_target
+    ... def target_func():
+    ...     pass  # your code.
+
+    Any updates to the local properties afterwards are not reflected in the Spark
+    context seen by ``target_func()``.
+
+    The example below mimics the behavior of JVM threads as closely as possible:
+
+    >>> Thread(target=inheritable_thread_target(target_func)).start()  # doctest: +SKIP
+    """
+    from pyspark import SparkContext
+    if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
+        # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
+        sc = SparkContext._active_spark_context
+        # Get local properties from the main thread.
+        properties = sc._jsc.sc().getLocalProperties().clone()
+
+        @functools.wraps(f)
+        def wrapped_f(*args, **kwargs):
+            try:
+                # Set local properties in the child thread.
+                sc._jsc.sc().setLocalProperties(properties)
+                return f(*args, **kwargs)
+            finally:
+                thread_connection = sc._jvm._gateway_client.thread_connection.connection()
+                if thread_connection is not None:
+                    connections = sc._jvm._gateway_client.deque
+
+                    # Reuse the lock for Py4J in PySpark.
+                    with SparkContext._lock:
+                        for i in range(len(connections)):
+                            if connections[i] is thread_connection:
+                                connections[i].close()
+                                del connections[i]
+                                break
+                        else:
+                            # Just in case the connection was not closed but removed from the
+                            # queue.
+                            thread_connection.close()
+        return wrapped_f
+    else:
+        return f
+
+
 class InheritableThread(threading.Thread):
     """
     Thread that is recommended to be used in PySpark instead of :class:`threading.Thread`
@@ -285,42 +360,9 @@ class InheritableThread(threading.Thread):
     This API is experimental.
""" def __init__(self, target, *args, **kwargs): - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - - if isinstance(sc._gateway, ClientServer): - # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. - properties = sc._jsc.sc().getLocalProperties().clone() - self._sc = sc - - def copy_local_properties(*a, **k): - sc._jsc.sc().setLocalProperties(properties) - return target(*a, **k) - - super(InheritableThread, self).__init__( - target=copy_local_properties, *args, **kwargs) - else: - super(InheritableThread, self).__init__(target=target, *args, **kwargs) - - def __del__(self): - from pyspark import SparkContext - - if isinstance(SparkContext._gateway, ClientServer): - thread_connection = self._sc._jvm._gateway_client.thread_connection.connection() - if thread_connection is not None: - connections = self._sc._jvm._gateway_client.deque - - # Reuse the lock for Py4J in PySpark - with SparkContext._lock: - for i in range(len(connections)): - if connections[i] is thread_connection: - connections[i].close() - del connections[i] - break - else: - # Just in case the connection was not closed but removed from the queue. - thread_connection.close() + super(InheritableThread, self).__init__( + target=inheritable_thread_target(target), *args, **kwargs + ) if __name__ == "__main__":