Skip to content

Commit

Permalink
tests: remove out of bound tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ManasviGoyal committed Jan 19, 2024
1 parent 8e70170 commit 8c886b8
Showing 1 changed file with 55 additions and 77 deletions.
132 changes: 55 additions & 77 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from itertools import product

import yaml
import numpy as np
from numpy import uint8 # noqa: F401 (used in evaluated strings)

CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -210,23 +211,55 @@ def getdtypes(args):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
typename = typename + "_"
if count == 1:
dtypes.append("cupy." + typename)
return dtypes


def checkuint(inputs, args):
def checkuint(test_args, args):
flag = True
for arg, val in inputs:
for arg, val in test_args:
typename = remove_const(
next(argument for argument in args if argument.name == arg).typename
)
if "List[uint" in typename and (any(n < 0 for n in val)):
flag = False
return flag

def checkintrange(test_args, args):
flag = True
for arg, val in test_args:
typename = remove_const(
next(argument for argument in args if argument.name == arg).typename
)
if "int" in typename or "uint" in typename:
dtype = gettypename(typename)
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
if "List" in typename:
for data in val:
if not (min_val <= data <= max_val):
flag = False
else:
if not (min_val <= val <= max_val):
flag = False
return flag


def unittestmap():
with open(os.path.join(CURRENT_DIR, "..", "kernel-test-data.json")) as f:
data = json.load(f)["unit-tests"]
unit_tests_map = {}
for function in data:
tests = function["tests"]
status = function["status"]
unit_tests_map[function["name"]] = {"tests": tests, "status": status}
return unit_tests_map


def getunittests(test_inputs, test_outputs):
unit_tests = {**test_inputs, **test_outputs}
return unit_tests


def gettypename(spectype):
typename = spectype.replace("List", "").replace("[", "").replace("]", "")
Expand Down Expand Up @@ -488,7 +521,7 @@ def gencpukerneltests(specdict):
def gencpuunittests(specdict):
print("Generating Unit Tests for CPU kernels")

unit_tests = unittestmap()
unit_test_map = unittestmap()
unit_tests_cuda_kernels = os.path.join(
CURRENT_DIR, "..", "awkward-cpp", "tests-cpu-kernels-explicit"
)
Expand All @@ -511,7 +544,7 @@ def gencpuunittests(specdict):
)

for spec in specdict.values():
if spec.templatized_kernel_name in list(unit_tests.keys()):
if spec.templatized_kernel_name in list(unit_test_map.keys()):
func = "test_cpu" + spec.name + ".py"
num = 0
with open(os.path.join(unit_tests_cuda_kernels, func), "w") as f:
Expand All @@ -534,39 +567,17 @@ def gencpuunittests(specdict):
"import pytest\n\n"
"from awkward_cpp.cpu_kernels import lib\n\n"
)
unit_test_values = unit_tests[spec.templatized_kernel_name]
unit_test_values = unit_test_map[spec.templatized_kernel_name]
tests = unit_test_values["tests"]
for test in tests:
num += 1
funcName = "def test_" + spec.name + "_" + str(num) + "():\n"
flag = checkuint(test["inputs"].items(), spec.args)
if flag is True:
unit_tests = getunittests(test["inputs"], test["outputs"])
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), spec.args)
if flag and range:
f.write(funcName)
for arg, val in test["outputs"].items():
f.write(" " * 4 + arg + " = " + str(val) + "\n")
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if count == 1:
f.write(
" " * 4
+ f"{arg} = (ctypes.c_{typename}*len({arg}))(*{arg})\n"
)
elif count == 2:
f.write(
" " * 4
+ "{0} = ctypes.pointer(ctypes.cast((ctypes.c_{1}*len({0}[0]))(*{0}[0]),ctypes.POINTER(ctypes.c_{1})))\n".format(
arg, typename
)
)
for arg, val in test["inputs"].items():
for arg, val in unit_tests.items():
f.write(" " * 4 + arg + " = " + str(val) + "\n")
typename = remove_const(
next(
Expand Down Expand Up @@ -813,21 +824,10 @@ def gencudakerneltests(specdict):
f.write("\n")


def unittestmap():
with open(os.path.join(CURRENT_DIR, "..", "kernel-test-data.json")) as f:
data = json.load(f)["unit-tests"]
unit_tests_map = {}
for function in data:
tests = function["tests"]
status = function["status"]
unit_tests_map[function["name"]] = {"tests": tests, "status": status}
return unit_tests_map


def gencudaunittests(specdict):
print("Generating Unit Tests for CUDA kernels")

unit_tests = unittestmap()
unit_test_map = unittestmap()
unit_tests_cuda_kernels = os.path.join(
CURRENT_DIR, "..", "tests-cuda-kernels-explicit"
)
Expand All @@ -852,7 +852,7 @@ def gencudaunittests(specdict):
for spec in specdict.values():
if (
spec.templatized_kernel_name in cuda_kernels_tests
and spec.templatized_kernel_name in list(unit_tests.keys())
and spec.templatized_kernel_name in list(unit_test_map.keys())
):
func = "test_cuda" + spec.name + ".py"
num = 0
Expand Down Expand Up @@ -882,21 +882,23 @@ def gencudaunittests(specdict):
"from awkward._backends.cupy import CupyBackend\n\n"
"cupy_backend = CupyBackend.instance()\n\n"
)
unit_test_values = unit_tests[spec.templatized_kernel_name]
unit_test_values = unit_test_map[spec.templatized_kernel_name]
tests = unit_test_values["tests"]
status = unit_test_values["status"]
for test in tests:
num += 1
funcName = "def test_" + spec.name + "_" + str(num) + "():\n"
dtypes = getdtypes(spec.args)
flag = checkuint(test["inputs"].items(), spec.args)
if flag is True:
unit_tests = getunittests(test["inputs"], test["outputs"])
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), spec.args)
if flag and range:
if not status:
f.write(
"@pytest.mark.skip(reason='Unable to generate any tests for kernel')\n"
"@pytest.mark.skip(reason='Kernel is not implemented properly')\n"
)
f.write(funcName)
for arg, val in test["outputs"].items():
for arg, val in unit_tests.items():
typename = remove_const(
next(
argument
Expand All @@ -906,32 +908,8 @@ def gencudaunittests(specdict):
)
if "List" not in typename:
f.write(
" " * 4 + arg + " = " + str([123] * len(val)) + "\n"
" " * 4 + arg + " = " + str(val) + "\n"
)
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
typename = typename + "_"
if count == 1:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, str([123] * len(val)), typename
)
)
elif count == 2:
raise NotImplementedError
for arg, val in test["inputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)
if "List" not in typename:
f.write(" " * 4 + arg + " = " + str(val) + "\n")
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
Expand Down

0 comments on commit 8c886b8

Please sign in to comment.