Skip to content

Commit

Permalink
Bug dask cudf personalization (#1764)
Browse files Browse the repository at this point in the history
fixes pagerank error when personalization is passed

Authors:
  - https://github.com/Iroy30

Approvers:
  - Brad Rees (https://github.com/BradReesWork)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #1764
  • Loading branch information
Iroy30 authored Aug 12, 2021
1 parent 0b3ab53 commit 35c53f7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
4 changes: 3 additions & 1 deletion python/cugraph/dask/common/part_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ async def _extract_partitions(dask_obj, client=None, batch_enabled=False):
if batch_enabled:
persisted = client.persist(dask_obj, workers=worker_list[0])
else:
persisted = client.persist(dask_obj)
persisted = [client.persist(
dask_obj.get_partition(p), workers=w) for p, w in enumerate(
worker_list)]
parts = futures_of(persisted)
# iterable of dask collections (need to colocate them)
elif isinstance(dask_obj, collections.Sequence):
Expand Down
37 changes: 32 additions & 5 deletions python/cugraph/dask/link_analysis/pagerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cugraph.dask.link_analysis import mg_pagerank_wrapper as mg_pagerank
import cugraph.comms.comms as Comms
import dask_cudf
from dask.dataframe.shuffle import rearrange_by_column


def call_pagerank(sID,
Expand Down Expand Up @@ -124,8 +125,6 @@ def pagerank(input_graph,
edge_attr='value')
>>> pr = dcg.pagerank(dg)
"""
from cugraph.structure.graph_classes import null_check

nstart = None

client = default_client()
Expand All @@ -139,13 +138,41 @@ def pagerank(input_graph,
data = get_distributed_data(ddf)

if personalization is not None:
null_check(personalization["vertex"])
null_check(personalization["values"])
if input_graph.renumbered is True:
personalization = input_graph.add_internal_vertex_id(
personalization, "vertex", "vertex"
)
p_data = get_distributed_data(personalization)

# Function to assign partition id to personalization dataframe
def _set_partitions_pre(s, divisions):
partitions = divisions.searchsorted(s, side="right") - 1
partitions[
divisions.tail(1).searchsorted(s, side="right").astype("bool")
] = (len(divisions) - 2)
return partitions

# Assign partition id column as per vertex_partition_offsets
df = personalization
by = ['vertex']
meta = df._meta._constructor_sliced([0])
divisions = vertex_partition_offsets
partitions = df[by].map_partitions(
_set_partitions_pre, divisions=divisions, meta=meta
)

df2 = df.assign(_partitions=partitions)

# Shuffle personalization values according to the partition id
df3 = rearrange_by_column(
df2,
"_partitions",
max_branch=None,
npartitions=len(divisions) - 1,
shuffle="tasks",
ignore_index=False,
).drop(columns=["_partitions"])

p_data = get_distributed_data(df3)

result = [client.submit(call_pagerank,
Comms.get_session_id(),
Expand Down
6 changes: 6 additions & 0 deletions python/cugraph/link_analysis/pagerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from cugraph.link_analysis import pagerank_wrapper
import cugraph
import cudf


def pagerank(
Expand Down Expand Up @@ -97,6 +98,11 @@ def pagerank(
G, isNx = cugraph.utilities.check_nx_graph(G, weight)

if personalization is not None:
if not isinstance(personalization, cudf.DataFrame):
raise NotImplementedError(
"personalization other than a cudf dataframe "
"currently not supported"
)
if G.renumbered is True:
if len(G.renumber_map.implementation.col_names) > 1:
cols = personalization.columns[:-1].to_list()
Expand Down

0 comments on commit 35c53f7

Please sign in to comment.