Skip to content

Commit

Permalink
Merge branch 'main' into feat-add-to-parquet-row-groups
Browse files Browse the repository at this point in the history
  • Loading branch information
zbilodea authored Jan 31, 2024
2 parents 1da5b07 + f2a2340 commit 14087c8
Show file tree
Hide file tree
Showing 7 changed files with 2,068 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ERROR awkward_ListArray_rpad_axis1(
}
offset = (target > rangeval) ? tostarts[i] + target : tostarts[i] + rangeval;
tostops[i] = offset;
}
}
return success();
}
ERROR awkward_ListArray32_rpad_axis1_64(
Expand Down
96 changes: 60 additions & 36 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,23 @@ def checkuint(test_args, args):
return flag


def checkintrange(test_args, args):
def checkintrange(test_args, error, args):
flag = True
for arg, val in test_args:
typename = remove_const(
next(argument for argument in args if argument.name == arg).typename
)
if "int" in typename or "uint" in typename:
dtype = gettypename(typename)
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
if "List" in typename:
for data in val:
if not (min_val <= data <= max_val):
if not error:
for arg, val in test_args:
typename = remove_const(
next(argument for argument in args if argument.name == arg).typename
)
if "int" in typename or "uint" in typename:
dtype = gettypename(typename)
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
if "List" in typename:
for data in val:
if not (min_val <= data <= max_val):
flag = False
else:
if not (min_val <= val <= max_val):
flag = False
else:
if not (min_val <= val <= max_val):
flag = False
return flag


Expand Down Expand Up @@ -581,7 +582,7 @@ def gencpuunittests(specdict):
test["inputs"], test["outputs"]
)
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
num += 1
f.write(funcName)
Expand Down Expand Up @@ -628,17 +629,25 @@ def gencpuunittests(specdict):
count += 1
else:
args += ", " + arg.name
f.write(" " * 4 + "ret_pass = funcC(" + args + ")\n")
for arg, val in test["outputs"].items():
f.write(" " * 4 + "pytest_" + arg + " = " + str(val) + "\n")
if isinstance(val, list):
if not test["error"]:
f.write(" " * 4 + "ret_pass = funcC(" + args + ")\n")
for arg, val in test["outputs"].items():
f.write(
" " * 4
+ f"assert {arg}[:len(pytest_{arg})] == pytest.approx(pytest_{arg})\n"
" " * 4 + "pytest_" + arg + " = " + str(val) + "\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write(" " * 4 + "assert not ret_pass.str\n")
if isinstance(val, list):
f.write(
" " * 4
+ f"assert {arg}[:len(pytest_{arg})] == pytest.approx(pytest_{arg})\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write(" " * 4 + "assert not ret_pass.str\n")
else:
f.write(
" " * 4
+ f"assert funcC({args}).str.decode('utf-8') == \"{test['message']}\"\n"
)
f.write("\n")


Expand Down Expand Up @@ -896,6 +905,7 @@ def gencudaunittests(specdict):
)

f.write(
"import re\n"
"import cupy\n"
"import pytest\n\n"
"import awkward as ak\n"
Expand All @@ -915,7 +925,7 @@ def gencudaunittests(specdict):
test["inputs"], test["outputs"]
)
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
num += 1
if not status:
Expand Down Expand Up @@ -973,24 +983,38 @@ def gencudaunittests(specdict):
else:
args += ", " + arg.name
f.write(" " * 4 + "funcC(" + args + ")\n")
f.write(
"""
if test["error"]:
f.write(
f"""
error_message = re.escape("{test['message']} in compiled CUDA code ({spec.templatized_kernel_name})")
"""
)
f.write(
""" with pytest.raises(ValueError, match=rf"{error_message}"):
ak_cu.synchronize_cuda()
"""
)
else:
f.write(
"""
try:
ak_cu.synchronize_cuda()
except:
pytest.fail("This test case shouldn't have raised an error")
"""
)
for arg, val in test["outputs"].items():
f.write(" " * 4 + "pytest_" + arg + " = " + str(val) + "\n")
if isinstance(val, list):
)
for arg, val in test["outputs"].items():
f.write(
" " * 4
+ f"assert cupy.array_equal({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
" " * 4 + "pytest_" + arg + " = " + str(val) + "\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write("\n")
if isinstance(val, list):
f.write(
" " * 4
+ f"assert cupy.array_equal({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write("\n")


def genunittests():
Expand Down
20 changes: 20 additions & 0 deletions header-only/layout-builder/awkward/LayoutBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,26 @@ namespace awkward {
size_t id_;
};

/// @class String
///
/// @brief Convenience builder for an array of strings, offering a
/// single-call `append` in the same spirit as the Numpy builder.
///
/// It is a ListOffset of `uint8_t` whose list node is tagged with the
/// `"__array__": "string"` parameter and whose content node is tagged with
/// `"__array__": "char"`.
///
/// @tparam PRIMITIVE Type forwarded to ListOffset as its first template
/// argument (presumably the offsets type -- confirm against ListOffset).
template<class PRIMITIVE>
class String : public ListOffset<PRIMITIVE, Numpy<uint8_t>> {
public:
  /// @brief Creates an empty builder and attaches the string/char
  /// parameters to the list and to its content.
  String() : ListOffset<PRIMITIVE, Numpy<uint8_t>>() {
    this->set_parameters(R"""("__array__": "string")""");
    this->content().set_parameters(R"""("__array__": "char")""");
  }

  /// @brief Appends one string as a new list entry.
  ///
  /// Every `char` of @p value is appended to the content builder between a
  /// `begin_list()` / `end_list()` pair.
  ///
  /// @param value The string to append; an empty string adds an empty list.
  void append(const std::string& value) {
    this->begin_list();
    auto& chars = this->content();
    for (const char ch : value) {
      // Spell out the char -> uint8_t conversion the content builder needs.
      chars.append(static_cast<uint8_t>(ch));
    }
    this->end_list();
  }
};

/// @class Empty
///
Expand Down
31 changes: 31 additions & 0 deletions header-only/tests/test_1494-layout-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ using UnionBuilder8_U32 = awkward::LayoutBuilder::Union<int8_t, uint32_t, BUILDE
template<class... BUILDERS>
using UnionBuilder8_64 = awkward::LayoutBuilder::Union<int8_t, int64_t, BUILDERS...>;

template<class PRIMITIVE>
using StringBuilder = awkward::LayoutBuilder::String<PRIMITIVE>;

void
test_Numpy_bool() {
Expand Down Expand Up @@ -1860,6 +1862,33 @@ test_categorical_form() {
}


void test_string_builder() {
StringBuilder<int64_t> builder;
assert(builder.length() == 0);

builder.append("one");
builder.append("two");
builder.append("three");

assert(builder.length() == 3);
}

// Checks that StringBuilder can be nested as the content of a
// ListOffsetBuilder: two lists holding three and two strings respectively.
//
// The outer builder's length is expected to count closed lists, while the
// inner StringBuilder's length counts every string appended so far.
void test_list_string_builder() {
  ListOffsetBuilder<int64_t, StringBuilder<int64_t>> builder;
  assert(builder.length() == 0);

  builder.begin_list();
  builder.content().append("one");
  builder.content().append("two");
  builder.content().append("three");
  builder.end_list();

  // One outer list closed, three strings stored in the content.
  assert(builder.length() == 1);
  assert(builder.content().length() == 3);

  builder.begin_list();
  builder.content().append("four");
  builder.content().append("five");
  builder.end_list();

  // Two outer lists closed, five strings stored in total.
  assert(builder.length() == 2);
  assert(builder.content().length() == 5);
}

int main(int /* argc */, char ** /* argv */) {
test_Numpy_bool();
test_Numpy_int();
Expand Down Expand Up @@ -1888,6 +1917,8 @@ int main(int /* argc */, char ** /* argv */) {
test_char_form();
test_string_form();
test_categorical_form();
test_string_builder();
test_list_string_builder();

return 0;
}
Loading

0 comments on commit 14087c8

Please sign in to comment.