Skip to content

Commit

Permalink
fix test_gather.cpp&&test_gather.json (#1084)
Browse files Browse the repository at this point in the history
* fix test_gather.cpp&&test_gather.json

* fix

* fix

* fix

---------

Co-authored-by: hejunchao <[email protected]>
  • Loading branch information
HeJunchao100813 and hejunchao authored Sep 6, 2023
1 parent 7d4c9a8 commit f828e43
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 9 deletions.
17 changes: 14 additions & 3 deletions tests/kernels/test_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,22 @@ class GatherTest : public KernelTest,
READY_SUBCASE()

auto shape = GetShapeArray("lhs_shape");
auto indices_shape = GetShapeArray("indices_shape");
auto indices_value = GetDataArray("indices_value");
auto value = GetNumber("axis");
auto typecode = GetDataType("lhs_type");

input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(input);

int64_t indices_array[] = {0, 0, -1, -1};
indices = hrt::create(dt_int64, {4},
size_t indices_value_size = indices_value.size();
auto *indices_array =
(int64_t *)malloc(indices_value_size * sizeof(int64_t));
std::copy(indices_value.begin(), indices_value.end(), indices_array);
indices = hrt::create(dt_int64, indices_shape,
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
indices_value_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

Expand Down Expand Up @@ -112,15 +117,21 @@ TEST_P(GatherTest, gather) {
int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(indices_shape, l)
FOR_LOOP(indices_value, h)
FOR_LOOP(axis, j)
FOR_LOOP(lhs_type, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(indices_shape, l)
SPLIT_ELEMENT(indices_value, h)
SPLIT_ELEMENT(axis, j)
SPLIT_ELEMENT(lhs_type, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/test_gather.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"lhs_shape":[[2, 3, 5, 7], [2, 2], [2, 3, 1], [5, 5, 7, 7], [11]],
"indices_shape":[[4], [2, 2], [4, 1]],
"axis":[0, 1, -1, 2, 3, -2, -3, -4],
"indices_value": [[0, 0, -1, -1]],
"lhs_type":["dt_float32", "dt_int8", "dt_int32", "dt_uint8", "dt_int16", "dt_uint16", "dt_uint32", "dt_uint64", "dt_int64", "dt_float16", "dt_float64", "dt_bfloat16", "dt_boolean"]
}
17 changes: 14 additions & 3 deletions tests/kernels/test_gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,22 @@ class GatherElementsTest : public KernelTest,
READY_SUBCASE()

auto shape = GetShapeArray("lhs_shape");
auto indices_shape = GetShapeArray("indices_shape");
auto indices_value = GetDataArray("indices_value");
auto value = GetNumber("axis");
auto typecode = GetDataType("lhs_type");

input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(input);

int64_t indices_array[] = {0, 0, 1, 1};
indices = hrt::create(dt_int64, {2, 2},
size_t indices_value_size = indices_value.size();
auto *indices_array =
(int64_t *)malloc(indices_value_size * sizeof(int64_t));
std::copy(indices_value.begin(), indices_value.end(), indices_array);
indices = hrt::create(dt_int64, indices_shape,
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
indices_value_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

Expand Down Expand Up @@ -110,15 +115,21 @@ TEST_P(GatherElementsTest, gather_elements) {
int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(indices_shape, l)
FOR_LOOP(indices_value, h)
FOR_LOOP(axis, j)
FOR_LOOP(lhs_type, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(indices_shape, l)
SPLIT_ELEMENT(indices_value, h)
SPLIT_ELEMENT(axis, j)
SPLIT_ELEMENT(lhs_type, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/test_gather_elements.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"lhs_shape":[[2, 2]],
"axis":[0],
"indices_shape":[[2, 2], [4, 1]],
"indices_value": [[0, 0, 1, 1]],
"lhs_type":["dt_float32", "dt_int8", "dt_int32", "dt_uint8", "dt_int16", "dt_uint16", "dt_uint32", "dt_uint64", "dt_int64", "dt_float16", "dt_float64", "dt_bfloat16", "dt_boolean"]
}
17 changes: 14 additions & 3 deletions tests/kernels/test_gather_nd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,22 @@ class GatherNDTest : public KernelTest,
READY_SUBCASE()

auto shape = GetShapeArray("lhs_shape");
auto indices_shape = GetShapeArray("indices_shape");
auto indices_value = GetDataArray("indices_value");
auto value = GetNumber("axis");
auto typecode = GetDataType("lhs_type");

input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(input);

int64_t indices_array[] = {0, 0, 0, 0};
indices = hrt::create(dt_int64, {2, 2},
size_t indices_value_size = indices_value.size();
auto *indices_array =
(int64_t *)malloc(indices_value_size * sizeof(int64_t));
std::copy(indices_value.begin(), indices_value.end(), indices_array);
indices = hrt::create(dt_int64, indices_shape,
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
indices_value_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

Expand Down Expand Up @@ -109,15 +114,21 @@ TEST_P(GatherNDTest, gather_nd) {
int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(indices_shape, l)
FOR_LOOP(indices_value, h)
FOR_LOOP(axis, j)
FOR_LOOP(lhs_type, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(indices_shape, l)
SPLIT_ELEMENT(indices_value, h)
SPLIT_ELEMENT(axis, j)
SPLIT_ELEMENT(lhs_type, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/test_gather_nd.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"lhs_shape":[[3, 5], [2, 2], [2, 3, 1], [5, 5, 7, 7]],
"axis":[0],
"indices_shape":[[2, 2], [4, 1]],
"indices_value": [[0, 0, 0, 0]],
"lhs_type":["dt_float32", "dt_int8", "dt_int32", "dt_uint8", "dt_int16", "dt_uint16", "dt_uint32", "dt_uint64", "dt_int64", "dt_float16", "dt_float64", "dt_bfloat16", "dt_boolean"]
}

0 comments on commit f828e43

Please sign in to comment.