[BACKPORT] Add support for dask.persist (#2953) (#2990)
Xuye (Chris) Qin authored May 5, 2022
1 parent ec72145 commit 0918713
Showing 2 changed files with 68 additions and 17 deletions.
47 changes: 30 additions & 17 deletions mars/contrib/dask/scheduler.py
@@ -14,12 +14,13 @@

from dask.core import istask, ishashable

-from typing import List, Tuple
+from typing import List, Tuple, Union
from .utils import reduce
from ...remote import spawn
+from ...deploy.oscar.session import execute


-def mars_scheduler(dsk: dict, keys: List[List[str]]):
+def mars_scheduler(dsk: dict, keys: Union[List[List[str]], List[str]]):
"""
A Dask-Mars scheduler
@@ -30,22 +31,29 @@ def mars_scheduler(dsk: dict, keys: List[List[str]]):
----------
dsk: Dict
Dask graph, represented as a task DAG dictionary.
-keys: List[List[str]]
-2d-list of Dask graph keys whose values we wish to compute and return.
+keys: Union[List[List[str]], List[str]]
+1d or 2d list of Dask graph keys whose values we wish to compute and return.
Returns
-------
Object
-Computed values corresponding to the provided keys.
+Computed values corresponding to the provided keys with same dimension.
"""
-res = reduce(mars_dask_get(dsk, keys)).execute().fetch()
-if not isinstance(res, List):
-return [[res]]
-else:
-return res

+if isinstance(keys, List) and not isinstance(keys[0], List): # 1d keys
+task = execute(mars_dask_get(dsk, keys))
+if not isinstance(task, List):
+task = [task]
+return map(lambda x: x.fetch(), task)
+else: # 2d keys
+res = execute(reduce(mars_dask_get(dsk, keys))).fetch()
+if not isinstance(res, List):
+return [[res]]
+else:
+return res

-def mars_dask_get(dsk: dict, keys: List[List]):

+def mars_dask_get(dsk: dict, keys: Union[List[List[str]], List[str]]):
"""
A Dask-Mars convert function. This function will send the dask graph layers
to Mars Remote API, generating mars objects correspond to the provided keys.
@@ -54,21 +62,21 @@ def mars_dask_get(dsk: dict, keys: List[List]):
----------
dsk: Dict
Dask graph, represented as a task DAG dictionary.
-keys: List[List[str]]
-2d-list of Dask graph keys whose values we wish to compute and return.
+keys: Union[List[List[str]], List[str]]
+1d or 2d list of Dask graph keys whose values we wish to compute and return.
Returns
-------
Object
-Spawned mars objects corresponding to the provided keys.
+Spawned mars objects corresponding to the provided keys with same dimension.
"""

def _get_arg(a):
# if arg contains layer index or callable objs, handle it
-if ishashable(a) and a in dsk.keys():
+while ishashable(a) and a in dsk.keys():
a = dsk[a]
-return _execute_task(a)
+return _spawn_task(a)
elif not isinstance(a, str) and hasattr(a, "__getitem__"):
if istask(
a
@@ -80,9 +88,14 @@ def _get_arg(a):
return type(a)(_get_arg(i) for i in a)
return a

-def _execute_task(task: tuple):
+def _spawn_task(task: tuple):
if not istask(task):
return _get_arg(task)
return spawn(task[0], args=tuple(_get_arg(a) for a in task[1:]))

-return [[_execute_task(dsk[k]) for k in keys_d] for keys_d in keys]
+return [
+[_spawn_task(dsk[k]) for k in keys_d]
+if isinstance(keys_d, List)
+else _spawn_task(dsk[keys_d])
+for keys_d in keys
+]
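
The two branches added to mars_scheduler above correspond to the two ways dask hands keys to a scheduler: dask.compute passes a nested (2d) key list, while dask.persist passes a flat (1d) key list. The following is a hedged usage sketch, not part of this commit, assuming dask is installed, a local Mars session has been created, and mars_scheduler is importable from mars.contrib.dask as in the tests below.

# Hedged sketch (not part of the diff above): exercising both dispatch branches.
import dask
import mars
from mars.contrib.dask import mars_scheduler


def inc(x):
    return x + 1


mars.new_session()  # assumed local Mars session

d = dask.delayed(inc)(1)

# dask.compute passes a nested (2d) key list -> reduce/execute/fetch branch.
print(dask.compute(d, scheduler=mars_scheduler))  # expected: (2,)

# dask.persist passes a flat (1d) key list -> per-key execute/fetch branch.
p = d.persist(scheduler=mars_scheduler)

# The persisted result can feed further delayed work, mirroring test_persist below.
print(dask.delayed(inc)(p).compute(scheduler=mars_scheduler))  # expected: 3
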
38 changes: 38 additions & 0 deletions mars/contrib/dask/tests/test_dask.py
@@ -143,3 +143,41 @@ def inc(x: int):
assert dask.compute(test_obj) == dask.compute(
test_obj, scheduler=mars_scheduler
)


@pytest.mark.skipif(not dask_installed, reason="dask not installed")
def test_persist(setup_cluster):
import dask

def inc(x):
return x + 1

a = dask.delayed(inc)(1)
task_mars_persist = dask.delayed(inc)(a.persist(scheduler=mars_scheduler))
task_dask_persist = dask.delayed(inc)(a.persist())

assert task_dask_persist.compute() == task_mars_persist.compute(
scheduler=mars_scheduler
)


@pytest.mark.skipif(not dask_installed, reason="dask not installed")
def test_partitioned_dataframe_persist(setup_cluster):
import numpy as np
import pandas as pd
from dask import dataframe as dd
from pandas._testing import assert_frame_equal

data = np.random.randn(10000, 100)
df = dd.from_pandas(
pd.DataFrame(data, columns=[f"col{i}" for i in range(100)]), npartitions=4
)
df["col0"] = df["col0"] + df["col1"] / 2
col2_mean = df["col2"].mean()

df_mars_persist = df[df["col2"] > col2_mean.persist(scheduler=mars_scheduler)]
df_dask_persist = df[df["col2"] > col2_mean.persist()]

assert_frame_equal(
df_dask_persist.compute(), df_mars_persist.compute(scheduler=mars_scheduler)
)
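
The new tests above persist single-key collections through the .persist() method. As a further hedged sketch (not part of this commit), the top-level dask.persist() entry point routes through the same 1d-keys branch of mars_scheduler, so several delayed objects could be persisted in one call; note that multi-key persist is not directly covered by the tests in this commit.

# Hedged sketch (not part of the diff above); assumes the same setup as the
# earlier sketch (dask installed, a running Mars session).
import dask
from mars.contrib.dask import mars_scheduler


def inc(x):
    return x + 1


a = dask.delayed(inc)(1)
b = dask.delayed(inc)(10)

# dask.persist flattens the keys of both collections into one 1d list, so a
# single mars_scheduler call persists both delayed objects on Mars.
pa, pb = dask.persist(a, b, scheduler=mars_scheduler)

# The persisted results should behave like ordinary dask collections afterwards.
total = dask.delayed(sum)([pa, pb])
print(total.compute(scheduler=mars_scheduler))  # expected: 13 (i.e. 2 + 11)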
