[Bug] Missing predicate guarding reduction init for a tensor scheduled with compute_at #9598

Closed
lazycal opened this issue Nov 26, 2021 · 0 comments
lazycal commented Nov 26, 2021

```python
import numpy as np
import tvm
import tvm.testing
from tvm import te

F = 100    # split factor
N = F + 1  # extent not divisible by F, so the split leaves a partial tail
A = te.placeholder((N, N), name="A")
k = te.reduce_axis((0, N), name="k")
B = te.compute((N,), lambda i: te.sum(A[i, k], k), name="B")  # row sums of A
C = te.compute((N,), lambda i: B[i], name="C")

s = te.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=F)
s[B].compute_at(s[C], xi)  # compute B inside C's inner split loop

foo = tvm.build(s, [A, B, C], "llvm")
print(tvm.lower(s, [A, B, C], simple_mode=True))

dev = tvm.cpu()
anp = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), dev)
bnp = tvm.nd.array(np.random.uniform(size=(N,)).astype(A.dtype), dev)
cnp = tvm.nd.array(np.random.uniform(size=(N,)).astype(A.dtype), dev)
foo(anp, bnp, cnp)  # triggers the segmentation fault
tvm.testing.assert_allclose(bnp.asnumpy(), cnp.asnumpy())
```

This triggers a segmentation fault. The generated IR is:

```
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [101], []),
             A: Buffer(A_2: Pointer(float32), float32, [101, 101], []),
             B: Buffer(B_2: Pointer(float32), float32, [101], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i.outer: int32, 0, 2) {
    for (i.inner: int32, 0, 100) {
      B_2[((i.outer*100) + i.inner)] = 0f32
      if @tir.likely((((i.outer*100) + i.inner) < 101), dtype=bool) {
        for (k: int32, 0, 101) {
          B_2[((i.outer*100) + i.inner)] = ((float32*)B_2[((i.outer*100) + i.inner)] + (float32*)A_2[(((i.outer*10100) + (i.inner*101)) + k)])
        }
      }
      if @tir.likely((((i.outer*100) + i.inner) < 101), dtype=bool) {
        C_2[((i.outer*100) + i.inner)] = (float32*)B_2[((i.outer*100) + i.inner)]
      }
    }
  }
}
```

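Here the reduction init `B_2[((i.outer*100) + i.inner)] = 0f32` is not wrapped in the `likely(... < 101)` predicate that guards the reduction body. For `i.outer = 1` and `i.inner >= 1` the index goes up to 199, so the init writes past the end of the 101-element buffer `B`.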
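To make the failure concrete, the sketch below replays the init statement of the lowered loop nest in plain Python (no TVM needed), using the extents from the IR above: `i.outer` in [0, 2), `i.inner` in [0, 100), and a 101-element `B`. It counts how many init writes land outside `B` when the init is unguarded versus when it is wrapped in the same `< 101` predicate as the reduction body. The function `count_oob_init_writes` is a hypothetical helper written for illustration only.

```python
def count_oob_init_writes(n: int, factor: int, guard_init: bool) -> int:
    """Replay `B_2[(i.outer*factor) + i.inner] = 0f32` and count out-of-bounds writes.

    n: extent of B (101 in the issue); factor: split factor (100 in the issue).
    guard_init: whether the init is wrapped in `index < n`, like the reduction body.
    """
    outer_extent = (n + factor - 1) // factor  # ceil(101 / 100) == 2
    oob = 0
    for i_outer in range(outer_extent):
        for i_inner in range(factor):
            index = i_outer * factor + i_inner  # same index expression as the IR
            if guard_init and not index < n:
                continue  # the predicate skips the tail iterations
            if index >= n:
                oob += 1  # this write would land past the end of B
    return oob


print(count_oob_init_writes(101, 100, guard_init=False))  # 99 (indices 101..199)
print(count_oob_init_writes(101, 100, guard_init=True))   # 0
```

With the repro's sizes, 99 unguarded init writes fall outside `B`, which is a plausible source of the crash; when `N` is a multiple of the split factor there is no tail and the count is 0 either way.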

Investigation

The problem goes away if we do not skip the bound check, i.e., if we replace `!stage->rolling_buffer` with `false` in

```cpp
ret.init_predicates =
    MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
```

However, I'm not sure this is the right fix, because I'm having trouble understanding the bound-checking logic. What confuses me is why the reduction body does not skip the bound checks

```cpp
ret.main_predicates =
    MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>());
```

while the init does.
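A simplified, pure-Python model of the root-domain check in `MakeBoundCheck` shows why the two calls above differ. With the fourth argument (`skip_ivar_domain`) set to `false`, as in the reduction body, the interval of `i.outer*100 + i.inner` is [0, 199], which cannot be proven to lie below the root extent 101, so a `< 101` predicate is emitted. With `true`, as in the init whenever `stage->rolling_buffer` is false, the check is skipped entirely. The names `Range`, `affine_bounds`, and `root_domain_check` are hypothetical; the model assumes concrete integer extents and affine index expressions, and it omits the `IsRangeSame` test and TVM's symbolic `CanProve`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    """Half-open integer range [min, min + extent)."""
    min: int
    extent: int


def affine_bounds(coeffs: dict[str, int], loop_doms: dict[str, Range]) -> tuple[int, int]:
    """[lo, hi] of sum(c * v) when each loop var v independently ranges over its domain."""
    lo = hi = 0
    for var, c in coeffs.items():
        r = loop_doms[var]
        first, last = c * r.min, c * (r.min + r.extent - 1)
        lo += min(first, last)
        hi += max(first, last)
    return lo, hi


def root_domain_check(coeffs, loop_doms, root_dom: Range, skip_ivar_domain: bool) -> list[str]:
    """Model of the second loop in MakeBoundCheck for one root iter var."""
    if skip_ivar_domain:
        return []  # the init path: no predicate is generated at all
    vmin, vmax = affine_bounds(coeffs, loop_doms)
    vmin -= root_dom.min  # value = value_map[iv] - iv->dom->min
    vmax -= root_dom.min
    preds = []
    if vmin < 0:
        preds.append("value >= 0")
    if vmax >= root_dom.extent:
        preds.append(f"value < {root_dom.extent}")
    return preds


loops = {"i.outer": Range(0, 2), "i.inner": Range(0, 100)}
index = {"i.outer": 100, "i.inner": 1}  # i.outer*100 + i.inner
b_root = Range(0, 101)  # declared domain of B's root axis i

print(root_domain_check(index, loops, b_root, skip_ivar_domain=False))  # ['value < 101']
print(root_domain_check(index, loops, b_root, skip_ivar_domain=True))   # []
```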

I see that there are two kinds of bound checks (L550-L560 and L561-L577) in the `MakeBoundCheck` function:

```cpp
std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
                                     const std::unordered_map<IterVar, PrimExpr>& value_map,
                                     bool skip_ivar_domain,
                                     const std::unordered_set<IterVar>& skip_iter) {
  arith::Analyzer analyzer;
  std::unordered_map<IterVar, bool> bound_state;
  for (IterVar iv : stage->leaf_iter_vars) {
    bound_state[iv] = false;
  }
  PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
  std::vector<PrimExpr> preds;
  Map<Var, IntSet> iset_dmap;
  // setup domain map for set analysis
  for (const auto& kv : dom_map) {
    iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second));
  }
  for (auto entry : dom_map) {
    analyzer.Bind(entry.first->var, entry.second);
  }
  // First kind: iter vars marked by PassUpBoundCheck, checked against dom_map.
  for (const IterVar& iv : stage->all_iter_vars) {
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
    if (bound_state.at(iv)) {
      Range dom = dom_map.at(iv);
      PrimExpr value = value_map.at(iv) - dom->min;
      PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
      if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
        preds.emplace_back(value < dom->extent);
      }
    }
  }
  // Second kind: root iter vars checked against their declared domain
  // (skipped entirely when skip_ivar_domain is true).
  for (const IterVar& iv : stage->op->root_iter_vars()) {
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
    Range dom = dom_map.at(iv);
    ICHECK(iv->dom.defined());
    if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
      PrimExpr value = value_map.at(iv) - iv->dom->min;
      IntSet s = analyzer.int_set(value, iset_dmap);
      PrimExpr vmin = s.min();
      PrimExpr vmax = s.max();
      // The range of `value` resides in [vmin, vmax]
      if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
        preds.emplace_back(value >= 0);
      }
      if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
        preds.emplace_back(value < iv->dom->extent);
      }
    }
  }
  return preds;
}
```
and the `skip_ivar_domain` flag only affects the second one: passing `true` (as the init does when there is no rolling buffer) disables it, and passing `false` enables it. However, the first check does not seem comprehensive: in the code above, because of the `compute_at`, B's axis is "implicitly" bound to a split axis of C, but the first check cannot see that split relation. As a result, `PassUpBoundCheck` does not mark B's axis as needing a check. So I'm also curious whether this is expected.
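The gap in the first check can be illustrated with a toy model of how split relations propagate a "needs a check" flag from leaf iter vars up to their parents. This is a hypothetical simplification written for illustration, not TVM's implementation: it assumes a split parent needs a check when either child does, or when `outer_extent * inner_extent` does not equal the parent extent. In this model C's root axis is marked (2 * 100 != 101), which matches the `< 101` guard around `C_2[...]`, while B has no split relations of its own after `compute_at`, so its axis is never marked.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Split:
    """A split relation: parent -> (outer, inner)."""
    parent: str
    outer: str
    inner: str


def pass_up_bound_check(leaf_vars, relations, extents):
    """Toy model of propagating 'needs a bound check' flags through split relations.

    Assumption: a split parent needs a check if either child does, or if
    outer_extent * inner_extent differs from the parent extent.
    """
    state = {v: False for v in leaf_vars}
    for rel in reversed(relations):  # walk from the leaves back toward the root
        if isinstance(rel, Split):
            child_needs = state[rel.outer] or state[rel.inner]
            exact = extents[rel.outer] * extents[rel.inner] == extents[rel.parent]
            state[rel.parent] = child_needs or not exact
    return state


# Stage C: root axis i (extent 101) split by factor 100 into i.outer (2) x i.inner (100).
c_state = pass_up_bound_check(
    ["i.outer", "i.inner"],
    [Split("i", "i.outer", "i.inner")],
    {"i": 101, "i.outer": 2, "i.inner": 100},
)
print(c_state["i"])  # True: 2 * 100 != 101

# Stage B: no relations of its own, so neither i nor k is ever marked.
b_state = pass_up_bound_check(["i", "k"], [], {})
print(b_state["i"])  # False
```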

Environment

OS: Ubuntu 18.04
TVM Version: ecd8a9c

areusch added the `needs-triage` label on Oct 19, 2022
hpanda-naut added the `relay:ir` label and removed the `needs-triage` label on Nov 28, 2022
tqchen closed this as completed on Sep 20, 2024