Skip to content

Commit

Permalink
Merge 2032945 into f16deb5
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHastings authored Nov 3, 2022
2 parents f16deb5 + 2032945 commit 5009d5c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 8 additions & 0 deletions cpp/src/c_api/bfs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ extern "C" cugraph_error_code_t cugraph_bfs(const cugraph_resource_handle_t* han
cugraph_paths_result_t** result,
cugraph_error_t** error)
{
CAPI_EXPECTS(
reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)->vertex_type_ ==
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(sources)
->type_,
CUGRAPH_INVALID_INPUT,
"vertex type of graph and sources must match",
*error);

cugraph::c_api::bfs_functor functor(handle,
graph,
sources,
Expand Down
8 changes: 6 additions & 2 deletions python/cugraph/cugraph/traversal/bfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def _ensure_args(G, start, i_start, directed):
else:
if not isinstance(start, cudf.DataFrame):
if not isinstance(start, dask_cudf.DataFrame):
start = cudf.DataFrame({"starts": cudf.Series(start)})
vertex_dtype = G.nodes().dtype
start = cudf.DataFrame(
{"starts": cudf.Series(start, dtype=vertex_dtype)}
)

if G.is_renumbered():
validlen = len(
Expand Down Expand Up @@ -224,7 +227,8 @@ def bfs(
if is_dataframe:
start = start[start.columns[0]]
else:
start = cudf.Series(start, name="starts")
vertex_dtype = G.nodes().dtype
start = cudf.Series(start, dtype=vertex_dtype)

distances, predecessors, vertices = pylibcugraph_bfs(
handle=ResourceHandle(),
Expand Down

0 comments on commit 5009d5c

Please sign in to comment.