diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index b5a87d9446d8..31815fc71060 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -228,6 +228,10 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   }

   // Step 4. Create block body.
+  // helper to transform the expr and remap iters to the block domain
+  auto f_transform_and_remap = [&](const PrimExpr& e) {
+    return Substitute(info->transformer(e), var_map);
+  };
   String block_name{nullptr};
   Optional<Stmt> init = NullOpt;
   Stmt body;
@@ -246,8 +250,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
     // - A RHS operand is the value to be reduced.
     for (int i = 0; i < n_buffers; ++i) {
       const PrimExpr& left = BufferLoad(buffers[i], indices);
-      const PrimExpr& right =
-          analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
+      const PrimExpr& right = analyzer->Simplify(f_transform_and_remap(reduce->source[i]));
       lhs.push_back(left);
       rhs.push_back(right);
       ICHECK_EQ(left->dtype, right->dtype);
@@ -267,13 +270,15 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
     // then store the value of the variables into the target buffer positions.
     for (int i = 0; i < n_buffers; ++i) {
       const Buffer& buffer = buffers[i];
-      init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
+      PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]);
+      init_stmts.push_back(BufferStore(buffer, identity, indices));
       PrimExpr value{nullptr};
       if (n_buffers > 1) {
         temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
         value = temp_vars.back();
       } else {
-        value = reduce->combiner.get()->operator()(lhs, rhs)[i];
+        PrimExpr combined = reduce->combiner.get()->operator()(lhs, rhs)[i];
+        value = f_transform_and_remap(combined);
       }
       body_stmts.push_back(BufferStore(buffer, value, indices));
     }
@@ -283,7 +288,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
     if (n_buffers > 1) {
       // When there are multiple buffers, we wrap the body with LetStmts.
       for (int i = n_buffers - 1; i >= 0; --i) {
-        PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
+        PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]);
         body = LetStmt(temp_vars[i], std::move(value), std::move(body));
       }
     }
@@ -291,7 +296,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
     // Case 2. Data parallel compute
     ICHECK_EQ(tensors.size(), 1);
     block_name = info->FreshName(tensors[0]->GetNameHint());
-    const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
+    const PrimExpr& compute_body = f_transform_and_remap(expr_body);
     body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
   }

diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py
index ade414f4234f..1a7e03188a25 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -814,5 +814,78 @@ def test_with_var_input():
     _check_workload(te_slice_with_var_input, tir_slice_with_var_input, index_dtype_override="int64")


+def test_loop_aware_initial_value():
+    """Test initial value aware of spatial iter position"""
+
+    @T.prim_func
+    def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+        a = T.match_buffer(var_a, (5, 5))
+        b = T.match_buffer(var_b, (5,))
+        sum_red = T.match_buffer(var_sum_red, (5,))
+        for i, ax in T.grid(5, 5):
+            with T.block("sum_red"):
+                v_i, v_ax = T.axis.remap("SR", [i, ax])
+                T.reads(b[v_i], a[v_i, v_ax])
+                T.writes(sum_red[v_i])
+                with T.init():
+                    sum_red[v_i] = b[v_i]
+                sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax]
+
+    def te_workload():
+        data = te.placeholder((5, 5), "float32", "a")
+        init = te.placeholder((5,), "float32", "b")
+        ax = te.reduce_axis((0, 5), "ax")
+        sum_red = te.compute(
+            (5,),
+            lambda i: te.comm_reducer(
+                lambda x, y: x + y,
+                lambda t: init[i],
+            )(data[i, ax], axis=[ax]),
+            name="sum_red",
+        )
+        return [data, init, sum_red]
+
+    _check_workload(te_workload, tir_workload)
+
+
+def test_loop_aware_reducer_combiner():
+    """Test combiner aware of spatial iter position"""
+
+    @T.prim_func
+    def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+        a = T.match_buffer(var_a, (5, 5))
+        b = T.match_buffer(var_b, (5,))
+        sum_red = T.match_buffer(var_sum_red, (5,))
+        for i, ax in T.grid(5, 5):
+            with T.block("sum_red"):
+                v_i = T.axis.spatial(5, i)
+                v_ax = T.axis.reduce(5, ax)
+                T.reads(a[v_i, 0:5])
+                T.writes(sum_red[v_i])
+                with T.init():
+                    sum_red[v_i] = T.float32(0.0)
+                sum_red[v_i] = T.if_then_else(
+                    a[v_i, sum_red[v_i]] < a[v_i, v_ax], sum_red[v_i], T.Cast("float32", v_ax)
+                )
+
+    def te_workload():
+        data = te.placeholder((5, 5), "float32", "a")
+        init = te.placeholder((5,), "float32", "b")
+        ax = te.reduce_axis((0, 5), "ax")
+        sum_red = te.compute(
+            (5,),
+            lambda i: te.comm_reducer(
+                lambda x, y: te.if_then_else(data[i, x] < y, x, ax),
+                lambda _: te.const(0, "float32"),
+            )(data[i, ax], axis=[ax]),
+            name="sum_red",
+        )
+        return [data, init, sum_red]
+
+    _check_workload(te_workload, tir_workload)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
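
For reference, a minimal standalone sketch (not part of the patch) of the pattern the new f_transform_and_remap helper enables: a te.comm_reducer whose identity element refers to the spatial iterator, lowered through te.create_prim_func. It mirrors test_loop_aware_initial_value above; the names data, init_val, k, and out are illustrative only.

# Sketch only, assuming the behaviour shown in test_loop_aware_initial_value.
import tvm
from tvm import te

data = te.placeholder((5, 5), "float32", name="data")
init_val = te.placeholder((5,), "float32", name="init_val")
k = te.reduce_axis((0, 5), name="k")

# The identity element (lambda t: init_val[i]) depends on the spatial iter i.
# With this change it goes through the same transform-and-remap substitution
# as the block body, so the generated init block reads init_val[v_i] rather
# than referencing the raw TE loop var.
out = te.compute(
    (5,),
    lambda i: te.comm_reducer(
        lambda x, y: x + y,       # combiner
        lambda t: init_val[i],    # loop-aware identity element
    )(data[i, k], axis=[k]),
    name="out",
)

prim_func = te.create_prim_func([data, init_val, out])
print(prim_func.script())  # init block: out[v_i] = init_val[v_i]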