Commit

[OpenCL] Fix type casting (apache#11038)
* [OpenCL] Fix type casting

The previous PR apache#11021 was reverted in apache#11035 because it
degraded the performance of the generated OpenCL code.

This PR fixes the same issue without the performance regression.
Tested on the Resnet50_v2 network.

* Implement using the select built-in
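For orientation, a minimal sketch (not part of the commit) of the kind of compute that exercises the new code path, modeled on the unit test added below; it assumes a TVM build with the OpenCL target enabled:

import tvm
from tvm import te

n, block_size = 16, 4
C = te.compute(
    (n,),
    lambda i: tvm.tir.Select(
        tvm.tir.all(
            i // block_size == tvm.tir.const(3, "int32"),
            i % block_size == tvm.tir.const(3, "int32"),
        ),
        tvm.tir.const(1, "float32"),
        tvm.tir.const(0, "float32"),
    ),
    name="C",
)
s = te.create_schedule(C.op)
tx, vx = s[C].split(s[C].op.axis[0], factor=block_size)
s[C].vectorize(vx)  # vectorizing the inner axis produces the vector Select
s[C].bind(tx, te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [C], "opencl")
# The emitted kernel should contain OpenCL's select(...) built-in instead of
# a C-style ternary over vector operands.
print(fun.imported_modules[0].get_source())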
echuraev authored and driazati committed Apr 19, 2022
1 parent 1efd7df commit 30314c9
Showing 4 changed files with 89 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/target/source/codegen_opencl.cc
@@ -327,6 +327,10 @@ void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) {

std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) {
  if (from == target) return value;
  return CastTo(value, target);
}

std::string CodeGenOpenCL::CastTo(std::string value, DataType target) {
  std::ostringstream os;
  if (target.lanes() == 1) {
    os << "((";
@@ -512,6 +516,40 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) {
  PrintBinaryExpr(op, "max", os, this);
}

void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) {
  std::ostringstream oss;
  os << "(";
  this->PrintExpr(op->a, oss);
  os << CastTo(oss.str(), op->dtype);
  oss.str("");
  os << " && ";
  this->PrintExpr(op->b, oss);
  os << CastTo(oss.str(), op->dtype);
  os << ")";
}

void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) {
  std::ostringstream oss;
  os << "(";
  this->PrintExpr(op->a, oss);
  os << CastTo(oss.str(), op->dtype);
  oss.str("");
  os << " || ";
  this->PrintExpr(op->b, oss);
  os << CastTo(oss.str(), op->dtype);
  os << ")";
}
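Both visitors above route their operands through the new CastTo helper, so vector boolean operands get an explicit conversion (printed as convert_uint4 in the test below) rather than relying on C's implicit conversion rules inside && and ||. As a hedged check, reusing fun from the sketch in the commit message:

src = fun.imported_modules[0].get_source()
# Each operand of the fused condition should carry an explicit convert_ cast.
assert "&&" in src and src.count("convert_uint4") >= 2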

void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {
  os << "select(";
  PrintExpr(op->false_value, os);
  os << ", ";
  PrintExpr(op->true_value, os);
  os << ", ";
  PrintExpr(op->condition, os);
  os << ")";
}
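Note the argument order: OpenCL's select(a, b, c) returns b where c is true (per lane for vector types, based on the most significant bit of c), so the visitor prints the false value first and the condition last, the reverse of the C ternary cond ? t : f. A scalar Python model of the built-in, for orientation only:

def opencl_select(a, b, c):
    # OpenCL semantics: the result is b where c is "true", otherwise a.
    return b if c else a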

void CodeGenOpenCL::SetTextureScope(
    const std::unordered_map<const VarNode*, std::string>& scope) {  // NOLINT(*)
  for (auto& texture : scope) {
4 changes: 4 additions & 0 deletions src/target/source/codegen_opencl.h
@@ -55,6 +55,7 @@ class CodeGenOpenCL final : public CodeGenC {
                      std::ostream& os);  // NOLINT(*)
  void PrintRestrict(const Var& v, std::ostream& os) final;  // NOLINT(*)
  std::string CastFromTo(std::string value, DataType from, DataType target);  // NOLINT(*)
  std::string CastTo(std::string value, DataType target);  // NOLINT(*)
  void SetTextureScope(const std::unordered_map<const VarNode*, std::string>&);  // NOLINT(*)

  // overload visitor
@@ -69,6 +70,9 @@
  // overload min and max to avoid ambiguous call errors
  void VisitExpr_(const MinNode* op, std::ostream& os) final;
  void VisitExpr_(const MaxNode* op, std::ostream& os) final;
  void VisitExpr_(const AndNode* op, std::ostream& os) final;
  void VisitExpr_(const OrNode* op, std::ostream& os) final;
  void VisitExpr_(const SelectNode* op, std::ostream& os) final;

 private:
  // whether enable fp16 and fp64 extension
1 change: 1 addition & 0 deletions tests/python/unittest/test_lower_build.py
@@ -130,6 +130,7 @@ def test_lower_build_tir_module():


def test_lower_build_lowered_module():
    # check lowering with the CSE pass disabled as otherwise it would do some commoning
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
        ir_mod = tvm.lower(LoweredTIRModule)
46 changes: 46 additions & 0 deletions tests/python/unittest/test_target_codegen_opencl.py
@@ -139,8 +139,54 @@ def check_erf(dev, n, dtype):
    check_erf(dev, 1, "float64")


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_type_casting():
    def check_type_casting(ctx, n, dtype):
        block_size = 4
        C = te.compute(
            (n,),
            lambda i: tvm.tir.Select(
                tvm.tir.all(
                    *[
                        i // block_size == tvm.tir.const(3, "int32"),
                        i % block_size == tvm.tir.const(3, "int32"),
                    ]
                ),
                tvm.tir.const(1, dtype),
                tvm.tir.const(0, dtype),
            ),
            name="C",
        )
        s = te.create_schedule(C.op)
        (tx, vx) = s[C].split(s[C].op.axis[0], factor=block_size)
        s[C].vectorize(vx)
        thrx = te.thread_axis("threadIdx.x")

        s[C].bind(tx, thrx)
        fun = tvm.build(s, [C], target)

        c = tvm.nd.empty((n,), dtype, ctx)
        assembly = fun.imported_modules[0].get_source()
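        # The expected strings below spell out the OpenCL the backend now
        # emits: each half of the tvm.tir.all(...) condition is cast with
        # convert_uint4 (the And operands go through CastTo), and the whole
        # Select becomes OpenCL's select(false_val, true_val, cond).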
        false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
        true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
        lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
        rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
        cond = "({} && {})".format(lcond, rcond)
        select = "select({}, {}, {})".format(false_branch, true_branch, cond)
        count = assembly.count(select)
        assert count == 1

        fun(c)
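        # A possible extra numerical check, sketched here rather than taken
        # from the commit (assumes numpy): with n == 16 only i == 15 satisfies
        # i // 4 == 3 and i % 4 == 3, so the output is one-hot at index 15.
        import numpy as np

        expected = np.zeros((n,), dtype=dtype)
        expected[3 * block_size + 3] = 1.0
        np.testing.assert_allclose(c.numpy(), expected)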

    dev = tvm.device(target, 0)

    check_type_casting(dev, 16, "float32")


if __name__ == "__main__":
    test_opencl_ternary_expression()
    test_opencl_inf_nan()
    test_opencl_max()
    test_opencl_erf()
    test_opencl_type_casting()
