[Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)
This PR fixes an existing bug in TIR lowering where the TIR below triggers an error:

```python
@T.prim_func
def func(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i in T.serial(10):
        with T.block("b"):
            vi = T.axis.spatial(10, i)
            b[vi] = a[vi]

tvm.build(func, target="llvm")
```

The error message is:

```
  File "/root/Projects/tvm-dev/src/tir/transforms/flatten_buffer.cc", line 173
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

Check failed: store->buffer->dtype == DataType::Int(8) (bool vs. int8) : Expected int8 backing array
for boolean tensor
```

This PR fixes the lowering: boolean loads and stores are now rewritten against the int8 backing array, with casts inserted between `bool` and `int8` at each access.
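
With the fix in place, the repro above should both build and run. A minimal end-to-end sanity check — not part of the commit, just a sketch assuming an LLVM-enabled TVM build and the `func` defined above:

```python
import numpy as np
import tvm

# Build the boolean copy kernel from the repro above.
built = tvm.build(func, target="llvm")

a = tvm.nd.array(np.array([True, False] * 5))  # bool tensors use an int8 backing store
b = tvm.nd.array(np.zeros(10, dtype="bool"))
built(a, b)

# b should now be an element-wise copy of a.
np.testing.assert_array_equal(b.numpy(), a.numpy())
```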
junrushao authored Jun 2, 2022 · 1 parent e60849c · commit 4c513b9
Showing 2 changed files with 45 additions and 10 deletions.
18 changes: 9 additions & 9 deletions src/tir/transforms/flatten_buffer.cc
```diff
@@ -53,9 +53,7 @@ class BufferFlattener : public StmtExprMutator {
   static PrimFunc Flatten(PrimFunc func) {
     Map<Var, Buffer> preflattened_buffer_map =
         Merge(func->buffer_map, func->preflattened_buffer_map);
-
     auto pass = BufferFlattener(func->buffer_map);
-
     auto writer = func.CopyOnWrite();
     writer->body = pass.VisitStmt(func->body);
     writer->preflattened_buffer_map = preflattened_buffer_map;
@@ -137,7 +135,7 @@ class BufferFlattener : public StmtExprMutator {
     } else {
       PrimExpr expr = it->second;
       if (expr.dtype() != var.dtype()) {
-        expr = Cast(var.dtype(), std::move(expr));
+        expr = tvm::cast(var.dtype(), std::move(expr));
       }
       return expr;
     }
@@ -164,33 +162,35 @@ class BufferFlattener : public StmtExprMutator {

   Stmt VisitStmt_(const BufferStoreNode* op) final {
     BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-
+    bool store_returns_bool = (op->value.dtype() == DataType::Bool());
+    store = VisitBufferAccess(store);
     // Handle casts from the value's dtype to the dtype of the
     // backing array.
     // TODO(Lunderberg): Move the handling of boolean into a
     // dedicated pass.
-    if (store->value.dtype() == DataType::Bool()) {
+    if (store_returns_bool) {
       ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
           << "Expected int8 backing array for boolean tensor";
       auto writer = store.CopyOnWrite();
-      writer->value = tir::Cast(DataType::Int(8), store->value);
+      writer->value = tvm::cast(DataType::Int(8), store->value);
+      return store;
     }
-    auto flattened_indices = store->buffer->ElemOffset(store->indices);
-    return VisitBufferAccess(std::move(store));
+    return store;
   }

   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     bool load_returns_bool = (op->dtype == DataType::Bool());
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
     load = VisitBufferAccess(load);

     // Handle casts from dtype of the backing array to value's dtype.
     // TODO(Lunderberg): Move the handling of boolean into a
     // dedicated pass.
     if (load_returns_bool) {
       ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
           << "Expected int8 backing array for boolean tensor";
-      return tir::Cast(DataType::Bool(), load);
+      load.CopyOnWrite()->dtype = DataType::Int(8);
+      return tvm::cast(DataType::Bool(), load);
     } else {
       return std::move(load);
     }
```
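Note the ordering that makes the fix work: whether an access returned `bool` is recorded from the original node (`op`) before `VisitBufferAccess` rewrites it against the flattened int8 backing buffer, and the cast is added afterwards. The pass can be run in isolation to see the result — a small sketch following the (commented-out) debug lines in the new test below, so the API usage matches this commit's era:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
    for i0 in T.serial(10):
        with T.block("b"):
            T.reads(a[i0])
            T.writes(b[i0])
            b[i0] = a[i0]


mod = tvm.IRModule.from_expr(before)
mod = tvm.tir.transform.FlattenBuffer()(mod)
# The printed script should show int8 buffers with a bool round-trip cast,
# i.e. b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
print(mod.script())
```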
37 changes: 36 additions & 1 deletion tests/python/unittest/test_tir_transform_flatten_buffer.py
```diff
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir, te
+from tvm import te, tir
 from tvm.script import tir as T


@@ -268,6 +268,33 @@ def annotated_loops(a: T.handle) -> None:
         A[i] = 0.0


+@T.prim_func
+def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
+    for i0 in T.serial(10):
+        with T.block("b"):
+            T.reads(a[i0])
+            T.writes(b[i0])
+            b[i0] = a[i0]
+
+
+@T.prim_func
+def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
+    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
+    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+    # body
+    for i0 in T.serial(10):
+        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+
+
+@T.prim_func
+def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
+    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
+    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+    # body
+    for i0 in T.serial(10):
+        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+
+
 def test_elementwise():
     _check(compacted_elementwise_func, flattened_elementwise_func)

@@ -319,6 +346,13 @@ def test_annotated_loops():
     tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))


+def test_boolean_handling():
+    _check(boolean_handling_before, boolean_handling_after)
+    # mod = tvm.IRModule.from_expr(boolean_handling_before)
+    # mod = tvm.tir.transform.FlattenBuffer()(mod)
+    # print(mod.script())
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_gpu_workload()
@@ -329,3 +363,4 @@ def test_annotated_loops():
     test_strided_buffer()
     test_lower_te()
     test_annotated_loops()
+    test_boolean_handling()
```
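Reading `boolean_handling_after`: the load side casts the int8 backing value up to `bool` (the dtype the original expression returned), and the store side casts it back down to `int8` (the dtype of the backing array), which is why the flattened copy comes out as the nested `T.cast(T.cast(a[i0], "bool"), "int8")`.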
