Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add variable length loop kernels #3003

Merged
merged 20 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b9c4fe6
feat: add variable length kernels
ManasviGoyal Feb 5, 2024
8c7e066
fix: spec kernel errors
ManasviGoyal Feb 5, 2024
aa653c5
feat: add awkward_ListArray_broadcast_tooffsets
ManasviGoyal Feb 5, 2024
0619525
fix: awkward_ListArray_compact_offsets kernel
ManasviGoyal Feb 5, 2024
6701dc4
test: remove XFAIL
ManasviGoyal Feb 5, 2024
5f6f699
style: pre-commit fixes
pre-commit-ci[bot] Feb 5, 2024
e6dc15b
feat: add awkward_ListArray_getitem_jagged_descend.cu
ManasviGoyal Feb 6, 2024
a5671fd
feat: add awkward_ListArray_getitem_jagged_numvalid
ManasviGoyal Feb 6, 2024
9430c98
feat: add awkward_ListArray_getitem_next_range_spreadadvanced
ManasviGoyal Feb 6, 2024
01da944
feat: add awkward_ListOffsetArray_rpad_length_axis1
ManasviGoyal Feb 6, 2024
00a01e8
feat: add awkward_ListOffsetArray_toRegularArray.cpp
ManasviGoyal Feb 7, 2024
843463f
feat: add awkward_ListArray_localindex
ManasviGoyal Feb 7, 2024
28188e2
feat: add awkward_ListOffsetArray_reduce_local_nextparents_64.cu
ManasviGoyal Feb 8, 2024
ef3c6ed
feat: add awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscop…
ManasviGoyal Feb 8, 2024
f98983f
feat: add awkward_UnionArray_regular_index_getsize
ManasviGoyal Feb 9, 2024
3352507
refactor: remove _a from the name of the kernel
ManasviGoyal Feb 9, 2024
aa5cf58
test: generate tests when outarg is also an inarg
ManasviGoyal Feb 9, 2024
64bd34c
feat: add awkward_ListOffsetArray_drop_none_indexes
ManasviGoyal Feb 9, 2024
4fdad74
fix: awkward_ListOffsetArray_drop_none_indexes
ManasviGoyal Feb 9, 2024
28bd029
Merge branch 'main' into ManasviGoyal/variable-length-kernels
ManasviGoyal Feb 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ERROR awkward_ListOffsetArray_toRegularArray(
*size = count;
}
else if (*size != count) {
return failure("cannot convert to RegularArray because subarray lengths are not " "regular", i, kSliceNone, FILENAME(__LINE__));
return failure("cannot convert to RegularArray because subarray lengths are not regular", i, kSliceNone, FILENAME(__LINE__));
}
}
if (*size == -1) {
Expand Down
15 changes: 14 additions & 1 deletion dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"awkward_ListArray_min_range",
"awkward_ListArray_validity",
"awkward_BitMaskedArray_to_ByteMaskedArray",
"awkward_ListArray_broadcast_tooffsets",
"awkward_ListArray_compact_offsets",
"awkward_ListOffsetArray_flatten_offsets",
"awkward_IndexedArray_overlay_mask",
Expand Down Expand Up @@ -52,14 +53,18 @@
"awkward_RegularArray_reduce_nonlocal_preparenext",
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListArray_getitem_next_range_spreadadvanced",
"awkward_ListArray_localindex",
"awkward_NumpyArray_reduce_adjust_starts_64",
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
"awkward_RegularArray_getitem_next_at",
Expand All @@ -76,14 +81,22 @@
"awkward_IndexedArray_getitem_nextcarry",
"awkward_IndexedArray_getitem_nextcarry_outindex",
"awkward_IndexedArray_index_of_nulls",
"awkward_IndexedArray_ranges_next_64",
"awkward_IndexedArray_ranges_carry_next_64",
"awkward_IndexedArray_reduce_next_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_fromshifts_64",
"awkward_IndexedOptionArray_rpad_and_clip_mask_axis1",
"awkward_ListOffsetArray_rpad_and_clip_axis1",
"awkward_ListOffsetArray_rpad_length_axis1",
"awkward_ListOffsetArray_toRegularArray",
# "awkward_ListOffsetArray_rpad_axis1",
"awkward_MaskedArray_getitem_next_jagged_project",
"awkward_UnionArray_project",
"awkward_ListOffsetArray_drop_none_indexes",
"awkward_ListOffsetArray_reduce_local_nextparents_64",
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
"awkward_UnionArray_regular_index_getsize",
"awkward_UnionArray_simplify",
"awkward_UnionArray_simplify_one",
"awkward_reduce_argmax",
Expand Down
138 changes: 99 additions & 39 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,10 @@ def getdtypes(args):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
ManasviGoyal marked this conversation as resolved.
Show resolved Hide resolved
if count == 1:
dtypes.append("cupy." + typename)
elif count == 2:
Expand Down Expand Up @@ -286,8 +288,7 @@ def unittestmap():

def getunittests(test_inputs, test_outputs):
unit_tests = {**test_outputs, **test_inputs}
num_outputs = len(test_outputs)
return unit_tests, num_outputs
return unit_tests


def gettypename(spectype):
Expand Down Expand Up @@ -602,32 +603,52 @@ def gencpuunittests(specdict):
funcName = (
"def test_unit_cpu" + spec.name + "_" + str(num) + "():\n"
)
unit_tests, num_outputs = getunittests(
test["inputs"], test["outputs"]
)
unit_tests = getunittests(test["inputs"], test["outputs"])
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
num += 1
f.write(funcName)
for i, (arg, val) in enumerate(unit_tests.items()):
for arg, val in test["outputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)
if i < num_outputs:
f.write(
" " * 4
+ arg
+ " = "
+ str([gettypeval(typename)] * len(val))
+ "\n"
)
else:
f.write(" " * 4 + arg + " = " + str(val) + "\n")
f.write(
" " * 4
+ arg
+ " = "
+ str([gettypeval(typename)] * len(val))
+ "\n"
)
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if count == 1:
f.write(
" " * 4
+ f"{arg} = (ctypes.c_{typename}*len({arg}))(*{arg})\n"
)
elif count == 2:
f.write(
" " * 4
+ "{0} = ctypes.pointer(ctypes.cast((ctypes.c_{1}*len({0}[0]))(*{0}[0]),ctypes.POINTER(ctypes.c_{1})))\n".format(
arg, typename
)
)
for arg, val in test["inputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)

f.write(" " * 4 + arg + " = " + str(val) + "\n")
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
Expand Down Expand Up @@ -680,6 +701,7 @@ def gencpuunittests(specdict):
"awkward_ListArray_min_range",
"awkward_ListArray_validity",
"awkward_BitMaskedArray_to_ByteMaskedArray",
"awkward_ListArray_broadcast_tooffsets",
"awkward_ListArray_compact_offsets",
"awkward_ListOffsetArray_flatten_offsets",
"awkward_IndexedArray_overlay_mask",
Expand Down Expand Up @@ -716,14 +738,18 @@ def gencpuunittests(specdict):
"awkward_RegularArray_reduce_nonlocal_preparenext",
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListArray_getitem_next_range_spreadadvanced",
"awkward_ListArray_localindex",
"awkward_NumpyArray_reduce_adjust_starts_64",
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
"awkward_RegularArray_getitem_next_at",
Expand All @@ -740,14 +766,22 @@ def gencpuunittests(specdict):
"awkward_IndexedArray_getitem_nextcarry",
"awkward_IndexedArray_getitem_nextcarry_outindex",
"awkward_IndexedArray_index_of_nulls",
"awkward_IndexedArray_ranges_next_64",
"awkward_IndexedArray_ranges_carry_next_64",
"awkward_IndexedArray_reduce_next_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_fromshifts_64",
"awkward_IndexedOptionArray_rpad_and_clip_mask_axis1",
"awkward_ListOffsetArray_rpad_and_clip_axis1",
"awkward_ListOffsetArray_rpad_length_axis1",
"awkward_ListOffsetArray_toRegularArray",
# "awkward_ListOffsetArray_rpad_axis1",
"awkward_MaskedArray_getitem_next_jagged_project",
"awkward_UnionArray_project",
"awkward_ListOffsetArray_drop_none_indexes",
"awkward_ListOffsetArray_reduce_local_nextparents_64",
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
"awkward_UnionArray_regular_index_getsize",
"awkward_UnionArray_simplify",
"awkward_UnionArray_simplify_one",
"awkward_reduce_argmax",
Expand Down Expand Up @@ -841,8 +875,10 @@ def gencudakerneltests(specdict):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
f.write(
" " * 4
Expand Down Expand Up @@ -957,9 +993,7 @@ def gencudaunittests(specdict):
"def test_unit_cuda" + spec.name + "_" + str(num) + "():\n"
)
dtypes = getdtypes(spec.args)
unit_tests, num_outputs = getunittests(
test["inputs"], test["outputs"]
)
unit_tests = getunittests(test["inputs"], test["outputs"])
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
Expand All @@ -969,7 +1003,7 @@ def gencudaunittests(specdict):
"@pytest.mark.skip(reason='Kernel is not implemented properly')\n"
)
f.write(funcName)
for i, (arg, val) in enumerate(unit_tests.items()):
for arg, val in test["outputs"].items():
typename = remove_const(
next(
argument
Expand All @@ -982,25 +1016,50 @@ def gencudaunittests(specdict):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
if i < num_outputs:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg,
[gettypeval(typename)] * len(val),
typename,
)
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg,
[gettypeval(typename)] * len(val),
typename,
)
else:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, val, typename
)
)
elif count == 2:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, val, typename
)
)
for arg, val in test["inputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)
if "List" not in typename:
f.write(" " * 4 + arg + " = " + str(val) + "\n")
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, val, typename
)
)
elif count == 2:
f.write(
" " * 4
Expand Down Expand Up @@ -1083,7 +1142,8 @@ def genunittests():
for key in test["outputs"]:
line += key + " = " + key + ","
for key in test["inputs"]:
line += key + " = " + key + ","
if key not in test["outputs"]:
line += key + " = " + key + ","
line = line[0 : len(line) - 1]
line += ")\n"
if test["error"]:
Expand Down
9 changes: 5 additions & 4 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1001,10 +1001,10 @@ kernels:
stride = fromstops[i] - fromstarts[i]
tostarts[i] = k
for j in range(stride):
if index[fromstarts[i] + j] > 0:
if index[fromstarts[i] + j] >= 0:
k = k + 1
tostops[i] = k
tolength = k
tolength[0] = k
automatic-tests: false


Expand Down Expand Up @@ -1040,7 +1040,7 @@ kernels:
for i in range(length):
stride = fromstops[i] - fromstarts[i]
for j in range(stride):
if index[fromstarts[i] + j] > 0:
if index[fromstarts[i] + j] >= 0:
tocarry[k] = index[fromstarts[i] + j]
k = k + 1
automatic-tests: false
Expand Down Expand Up @@ -1515,7 +1515,8 @@ kernels:
if slicestop > missinglength:
raise ValueError("jagged slice's offsets extend beyond its content")
for j in range(slicestart, slicestop):
numvalid[0] = numvalid[0] + 1 if missing[j] >= 0 else 0
if missing[j] >= 0:
numvalid[0] = numvalid[0] + 1
automatic-tests: false

- name: awkward_ListArray_getitem_jagged_shrink
Expand Down
Loading
Loading