diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index 69072b91b2fa5..a245344755f04 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, const SparseTensorType stt(rtp); lvlRank = stt.getLvlRank(); - // We always treat sparse output tensor as dense so that we always iterate - // it based on lvl size. - if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) { + if (stt.hasEncoding()) { const auto enc = stt.getEncoding(); isSparseSlices[tid] = enc.isSlice(); for (auto lvlTp : enc.getLvlTypes()) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 3fb90ef379a57..e0d3ce241e454 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs( } if (isUndefLT(lt)) { // An undefined lt in the lattices, we probably mean to - // iterate based on the level of output tensor. E.g., this - // could be a synthetic tensor (for invariants and sparse - // output tensor). - auto itType = env.op().getIteratorTypesArray()[ldx]; - if (linalg::isReductionIterator(itType) && - env.merger().getSynTensorID() == tid) { - // Coiterating with an invariant, and this is a reduction loop + // generate a dense loop according to the synthetic tensor (for + // invariants and sparse output tensor). + if (env.merger().getSynTensorID() == tid) { + // Coiterating with an invariant // e.g., out = prod(in[i][j] op invariant); - // In this case, we can not infer the loop bound from output - // (whose level is reduced). Instead we use the synthetic tensor - // to infer the bound. + // or a broadcast + // e.g., out[i][j] = in[i] (j is undef for input) + // // The level of the synthetic tensor is the current loop depth; // the rank of the synthetic tensor equals to number of loops. lvl = env.emitter().getCurrentDepth(); - } else { - // or a broadcast - // out[i][j] = in[i] (j is undef for input) - tid = outTid; - lvl = outLvl; + } else if (!lvl) { // Skips invalid lvl (e.g., when this is a zero ranked tensor). - if (!lvl) - return; + return; } } hasNonUnique = !isUniqueLT(lt) || hasNonUnique;