Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
WeichenXu123 committed May 24, 2021
1 parent 1530876 commit af94332
Showing 1 changed file with 80 additions and 38 deletions.
118 changes: 80 additions & 38 deletions python/pyspark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.
#

import functools
import itertools
import os
import platform
Expand All @@ -25,8 +26,6 @@
import traceback
import types

from py4j.clientserver import ClientServer

__all__ = [] # type: ignore


Expand Down Expand Up @@ -263,6 +262,82 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


def inheritable_thread_target(f):
"""
Return thread target wrapper which is recommended to be used in PySpark when the
pinned thread mode is enabled. The wrapper function, before calling original
thread target, it inherits the inheritable properties specific
to JVM thread such as ``InheritableThreadLocal``.
Also, note that pinned thread mode does not close the connection from Python
to JVM when the thread is finished in the Python side. With this wrapper, Python
garbage-collects the Python thread instance and also closes the connection
which finishes JVM thread correctly.
When the pinned thread mode is off, it return the original ``f``.
Parameters
----------
f : function
the original thread target.
.. versionadded:: 3.2.0
Notes
-----
This API is experimental.
It captures the local properties when you decorate it. Therefore, it is encouraged
to decorate it when you want to capture the local properties.
For example, the local properties from the current Spark context is captured
when you define a function here:
>>> @inheritable_thread_target
... def target_func():
... pass # your codes.
If you have any updates on local properties afterwards, it would not be reflected to
the Spark context in ``target_func()``.
The example below mimics the behavior of JVM threads as close as possible:
>>> Thread(target=inheritable_thread_target(target_func)).start() # doctest: +SKIP
"""
from pyspark import SparkContext
if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
sc = SparkContext._active_spark_context

# Get local properties from main thread
properties = sc._jsc.sc().getLocalProperties().clone()

@functools.wraps(f)
def wrapped_f(*args, **kwargs):
try:
# Set local properties in child thread.
sc._jsc.sc().setLocalProperties(properties)
return f(*args, **kwargs)
finally:
thread_connection = sc._jvm._gateway_client.thread_connection.connection()
if thread_connection is not None:
connections = sc._jvm._gateway_client.deque
# Reuse the lock for Py4J in PySpark
with SparkContext._lock:
for i in range(len(connections)):
if connections[i] is thread_connection:
connections[i].close()
del connections[i]
break
else:
# Just in case the connection was not closed but removed from the
# queue.
thread_connection.close()
return wrapped_f
else:
return f


class InheritableThread(threading.Thread):
"""
Thread that is recommended to be used in PySpark instead of :class:`threading.Thread`
Expand All @@ -285,42 +360,9 @@ class InheritableThread(threading.Thread):
This API is experimental.
"""
def __init__(self, target, *args, **kwargs):
from pyspark import SparkContext

sc = SparkContext._active_spark_context

if isinstance(sc._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
properties = sc._jsc.sc().getLocalProperties().clone()
self._sc = sc

def copy_local_properties(*a, **k):
sc._jsc.sc().setLocalProperties(properties)
return target(*a, **k)

super(InheritableThread, self).__init__(
target=copy_local_properties, *args, **kwargs)
else:
super(InheritableThread, self).__init__(target=target, *args, **kwargs)

def __del__(self):
from pyspark import SparkContext

if isinstance(SparkContext._gateway, ClientServer):
thread_connection = self._sc._jvm._gateway_client.thread_connection.connection()
if thread_connection is not None:
connections = self._sc._jvm._gateway_client.deque

# Reuse the lock for Py4J in PySpark
with SparkContext._lock:
for i in range(len(connections)):
if connections[i] is thread_connection:
connections[i].close()
del connections[i]
break
else:
# Just in case the connection was not closed but removed from the queue.
thread_connection.close()
super(InheritableThread, self).__init__(
target=inheritable_thread_target(target), *args, **kwargs
)


if __name__ == "__main__":
Expand Down

0 comments on commit af94332

Please sign in to comment.