Skip to content

Commit

Permalink
style: pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] authored and ManasviGoyal committed Jan 19, 2024
1 parent 307f36c commit 8e70170
Showing 1 changed file with 7 additions and 20 deletions.
27 changes: 7 additions & 20 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,7 @@ def checkuint(inputs, args):
flag = True
for arg, val in inputs:
typename = remove_const(
next(
argument
for argument in args
if argument.name == arg
).typename
next(argument for argument in args if argument.name == arg).typename
)
if "List[uint" in typename and (any(n < 0 for n in val)):
flag = False
Expand Down Expand Up @@ -515,15 +511,10 @@ def gencpuunittests(specdict):
)

for spec in specdict.values():
if (
spec.templatized_kernel_name in list(unit_tests.keys())
):
if spec.templatized_kernel_name in list(unit_tests.keys()):
func = "test_cpu" + spec.name + ".py"
num = 0
with open(
os.path.join(unit_tests_cuda_kernels, func),
"w"
) as f:
with open(os.path.join(unit_tests_cuda_kernels, func), "w") as f:
f.write(
f"""# AUTO GENERATED ON {reproducible_datetime()}
# DO NOT EDIT BY HAND!
Expand Down Expand Up @@ -608,12 +599,10 @@ def gencpuunittests(specdict):
args += arg.name
count += 1
else:
args += ", " + arg.name
args += ", " + arg.name
f.write(" " * 4 + "ret_pass = funcC(" + args + ")\n")
for arg, val in test["outputs"].items():
f.write(
" " * 4 + "pytest_" + arg + " = " + str(val) + "\n"
)
f.write(" " * 4 + "pytest_" + arg + " = " + str(val) + "\n")
if isinstance(val, list):
f.write(
" " * 4
Expand Down Expand Up @@ -896,7 +885,7 @@ def gencudaunittests(specdict):
unit_test_values = unit_tests[spec.templatized_kernel_name]
tests = unit_test_values["tests"]
status = unit_test_values["status"]
for test in tests:
for test in tests:
num += 1
funcName = "def test_" + spec.name + "_" + str(num) + "():\n"
dtypes = getdtypes(spec.args)
Expand Down Expand Up @@ -982,9 +971,7 @@ def gencudaunittests(specdict):
"""
)
for arg, val in test["outputs"].items():
f.write(
" " * 4 + "pytest_" + arg + " = " + str(val) + "\n"
)
f.write(" " * 4 + "pytest_" + arg + " = " + str(val) + "\n")
if isinstance(val, list):
f.write(
" " * 4
Expand Down

0 comments on commit 8e70170

Please sign in to comment.