Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameter checks to BFS and SSSP in C API #2844

Merged
merged 13 commits into from
Nov 3, 2022
8 changes: 8 additions & 0 deletions cpp/src/c_api/bfs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ extern "C" cugraph_error_code_t cugraph_bfs(const cugraph_resource_handle_t* han
cugraph_paths_result_t** result,
cugraph_error_t** error)
{
CAPI_EXPECTS(
reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)->vertex_type_ ==
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(sources)
->type_,
CUGRAPH_INVALID_INPUT,
"vertex type of graph and sources must match",
*error);

cugraph::c_api::bfs_functor functor(handle,
graph,
sources,
Expand Down
53 changes: 53 additions & 0 deletions cpp/tests/c_api/bfs_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,58 @@ int generic_bfs_test(vertex_t* h_src,
return test_ret_value;
}

int test_bfs_exceptions()
{
size_t num_edges = 8;
size_t num_vertices = 6;
size_t depth_limit = 1;
size_t num_seeds = 1;

vertex_t src[] = {0, 1, 1, 2, 2, 2, 3, 4};
vertex_t dst[] = {1, 3, 4, 0, 1, 3, 5, 5};
weight_t wgt[] = {0.1f, 2.1f, 1.1f, 5.1f, 3.1f, 4.1f, 7.2f, 3.2f};
int64_t seeds[] = {0};

int test_ret_value = 0;

cugraph_error_code_t ret_code = CUGRAPH_SUCCESS;
cugraph_error_t* ret_error = NULL;

cugraph_resource_handle_t* p_handle = NULL;
cugraph_graph_t* p_graph = NULL;
cugraph_paths_result_t* p_result = NULL;
cugraph_type_erased_device_array_t* p_sources = NULL;
cugraph_type_erased_device_array_view_t* p_source_view = NULL;

p_handle = cugraph_create_resource_handle(NULL);
TEST_ASSERT(test_ret_value, p_handle != NULL, "resource handle creation failed.");

ret_code = create_test_graph(
p_handle, src, dst, wgt, num_edges, FALSE, FALSE, FALSE, &p_graph, &ret_error);

/*
* FIXME: in create_graph_test.c, variables are defined but then hard-coded to
* the constant INT32. It would be better to pass the types into the functions
* in both cases so that the test cases could be parameterized in the main.
*/
ret_code =
cugraph_type_erased_device_array_create(p_handle, num_seeds, INT64, &p_sources, &ret_error);
TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "p_sources create failed.");

p_source_view = cugraph_type_erased_device_array_view(p_sources);

ret_code = cugraph_type_erased_device_array_view_copy_from_host(
p_handle, p_source_view, (byte_t*)seeds, &ret_error);
TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "src copy_from_host failed.");

ret_code = cugraph_bfs(
p_handle, p_graph, p_source_view, FALSE, depth_limit, TRUE, FALSE, &p_result, &ret_error);

TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_INVALID_INPUT, "cugraph_bfs expected to fail");

return test_ret_value;
}

int test_bfs()
{
size_t num_edges = 8;
Expand Down Expand Up @@ -176,5 +228,6 @@ int main(int argc, char** argv)
int result = 0;
result |= RUN_TEST(test_bfs);
result |= RUN_TEST(test_bfs_with_transpose);
result |= RUN_TEST(test_bfs_exceptions);
return result;
}
8 changes: 6 additions & 2 deletions python/cugraph/cugraph/traversal/bfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def _ensure_args(G, start, i_start, directed):
else:
if not isinstance(start, cudf.DataFrame):
if not isinstance(start, dask_cudf.DataFrame):
start = cudf.DataFrame({"starts": cudf.Series(start)})
vertex_dtype = G.nodes().dtype
start = cudf.DataFrame(
{"starts": cudf.Series(start, dtype=vertex_dtype)}
)

if G.is_renumbered():
validlen = len(
Expand Down Expand Up @@ -224,7 +227,8 @@ def bfs(
if is_dataframe:
start = start[start.columns[0]]
else:
start = cudf.Series(start, name="starts")
vertex_dtype = G.nodes().dtype
start = cudf.Series(start, dtype=vertex_dtype)

distances, predecessors, vertices = pylibcugraph_bfs(
handle=ResourceHandle(),
Expand Down