[OpenCL] Fix type casting error (apache#11021)
We faced a situation where the generated OpenCL kernel contained the following if
condition:
```
if (uint4(...) && (int4(...) == int4(...)))
```

In this case, the OpenCL compiler reported the following error:
"can't convert between vector values of different size ('uint4' and 'int __attribute__((ext_vector_type(4)))')"

Casts were added for binary ops (see `PrintVecBinaryOp` below), but it was
also necessary to modify `CastFromTo` and add a new method, `CastTo`. With
`CastFromTo` the following code was generated:
```
if (uint4(...) && (convert_uint4(int4(...)) == convert_uint4(int4(...))))
```
and the OpenCL compiler still reported the same error. The reason is that in
OpenCL C a relational operator on vector operands always yields a signed
integer vector (here, comparing two `uint4` values produces an `int4` mask),
whereas in TIR the comparison is already typed `uint4`; since `CastFromTo`
skips the cast when the source and target types match, no conversion was
emitted for the result of the `==`.

This is why the new method `CastTo` was added: it does not inspect the
expression's current type and unconditionally emits a cast to the target type.
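
For illustration, a minimal sketch of what `CastTo` can look like, assuming the
scalar branch emits a plain C-style cast and the vector branch emits OpenCL's
built-in `convert_<type>` functions (consistent with the truncated diff below);
`PrintType` is the existing `CodeGenC` helper that prints the OpenCL spelling
of a `DataType`:
```
// Sketch only: unconditionally cast `value` to `target`, without
// consulting the expression's current TIR type.
std::string CodeGenOpenCL::CastTo(std::string value, DataType target) {
  std::ostringstream os;
  if (target.lanes() == 1) {
    // Scalar: a C-style cast, e.g. ((float)value)
    os << "((";
    this->PrintType(target, os);
    os << ")" << value << ")";
  } else {
    // Vector: C-style casts between vector types are not allowed in
    // OpenCL C, so use the convert_<type>() built-ins instead,
    // e.g. (convert_uint4(value))
    os << "(convert_";
    this->PrintType(target, os);
    os << "(" << value << "))";
  }
  return os.str();
}
```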

Finally, the following code will be generated:
```
if (uint4(...) && convert_uint4(convert_uint4(int4(...)) == convert_uint4(int4(...))))
```
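
The outer `convert_uint4(...)` around the comparison is what actually resolves
the error: it converts the `int4` mask produced by `==` back to `uint4`, so
both operands of `&&` have the same type.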
echuraev authored Apr 15, 2022
1 parent 365fcc8 commit 8aafe5b
Showing 3 changed files with 70 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/target/source/codegen_opencl.cc
@@ -327,6 +327,10 @@ void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) {

std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) {
  if (from == target) return value;
  return CastTo(value, target);
}

std::string CodeGenOpenCL::CastTo(std::string value, DataType target) {
  std::ostringstream os;
  if (target.lanes() == 1) {
    os << "((";
@@ -512,6 +516,30 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) {
  PrintBinaryExpr(op, "max", os, this);
}

void CodeGenOpenCL::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                     std::ostream& os) {
  std::ostringstream oss;
  if (isalpha(op[0])) {
    os << op << "(";
    this->PrintExpr(lhs, oss);
    os << CastTo(oss.str(), t);
    oss.str("");
    os << ", ";
    this->PrintExpr(rhs, oss);
    os << CastTo(oss.str(), t);
    os << ")";
  } else {
    os << "(";
    this->PrintExpr(lhs, oss);
    os << CastTo(oss.str(), t);
    oss.str("");
    os << ' ' << op << ' ';
    this->PrintExpr(rhs, oss);
    os << CastTo(oss.str(), t);
    os << ")";
  }
}
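
With this override, both operands of a vector binary op are explicitly cast to
the op's type before the operator is emitted: for example, printing `==` with
`t = uint4` produces text of the form `(convert_uint4(<lhs>) == convert_uint4(<rhs>))`,
and a named op such as `max` produces `max(convert_uint4(<lhs>), convert_uint4(<rhs>))`
(operand text abbreviated for illustration).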

void CodeGenOpenCL::SetTextureScope(
    const std::unordered_map<const VarNode*, std::string>& scope) {  // NOLINT(*)
  for (auto& texture : scope) {
5 changes: 5 additions & 0 deletions src/target/source/codegen_opencl.h
@@ -55,6 +55,7 @@ class CodeGenOpenCL final : public CodeGenC {
                    std::ostream& os);  // NOLINT(*)
  void PrintRestrict(const Var& v, std::ostream& os) final;  // NOLINT(*)
  std::string CastFromTo(std::string value, DataType from, DataType target);  // NOLINT(*)
  std::string CastTo(std::string value, DataType target);  // NOLINT(*)
  void SetTextureScope(const std::unordered_map<const VarNode*, std::string>&);  // NOLINT(*)

// overload visitor
@@ -70,6 +71,10 @@
  void VisitExpr_(const MinNode* op, std::ostream& os) final;
  void VisitExpr_(const MaxNode* op, std::ostream& os) final;

  // Binary vector op.
  void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs,
                        std::ostream& os) final;

 private:
  // whether enable fp16 and fp64 extension
  bool enable_fp16_{false};
37 changes: 37 additions & 0 deletions tests/python/unittest/test_target_codegen_opencl.py
@@ -139,8 +139,45 @@ def check_erf(dev, n, dtype):
    check_erf(dev, 1, "float64")


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_type_casting():
    def check_type_casting(ctx, n, dtype):
        block_size = 4
        C = te.compute(
            (n,),
            lambda i: tvm.tir.Select(
                tvm.tir.all(
                    *[
                        i // block_size == tvm.tir.const(3, "int32"),
                        i % block_size == tvm.tir.const(3, "int32"),
                    ]
                ),
                tvm.tir.const(1, dtype),
                tvm.tir.const(0, dtype),
            ),
            name="C",
        )
        s = te.create_schedule(C.op)
        (tx, vx) = s[C].split(s[C].op.axis[0], factor=block_size)
        s[C].vectorize(vx)
        thrx = te.thread_axis("threadIdx.x")

        s[C].bind(tx, thrx)
        fun = tvm.build(s, [C], target)

        c = tvm.nd.empty((n,), dtype, ctx)
        # Only need to test compiling here
        fun(c)

    dev = tvm.device(target, 0)

    check_type_casting(dev, 16, "float32")
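
As in the other tests in this file, `target` is the module-level OpenCL target
string; the test only needs to verify that compilation succeeds, since the
vectorized `Select` over `tvm.tir.all(...)` lowers to exactly the kind of mixed
`uint4`/`int4` condition described in the commit message.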


if __name__ == "__main__":
    test_opencl_ternary_expression()
    test_opencl_inf_nan()
    test_opencl_max()
    test_opencl_erf()
    test_opencl_type_casting()
