
[mlir][sparse] code simplification: always use synthetical tensor for… #73597

Merged
merged 1 commit into from
Nov 28, 2023

Conversation

PeimingLiu

… loop bound.

@PeimingLiu PeimingLiu requested a review from aartbik November 28, 2023 01:19
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Nov 28, 2023
@llvmbot

llvmbot commented Nov 28, 2023

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

… loop bound.


Full diff: https://github.com/llvm/llvm-project/pull/73597.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+1-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+9-17)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 69072b91b2fa523..a245344755f0404 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
       const SparseTensorType stt(rtp);
       lvlRank = stt.getLvlRank();
 
-      // We always treat sparse output tensor as dense so that we always iterate
-      // it based on lvl size.
-      if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+      if (stt.hasEncoding()) {
         const auto enc = stt.getEncoding();
         isSparseSlices[tid] = enc.isSlice();
         for (auto lvlTp : enc.getLvlTypes())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3fb90ef379a5778..e0d3ce241e454d0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs(
       }
       if (isUndefLT(lt)) {
         // An undefined lt in the lattices, we probably mean to
-        // iterate based on the level of output tensor.  E.g., this
-        // could be a synthetic tensor (for invariants and sparse
-        // output tensor).
-        auto itType = env.op().getIteratorTypesArray()[ldx];
-        if (linalg::isReductionIterator(itType) &&
-            env.merger().getSynTensorID() == tid) {
-          // Coiterating with an invariant, and this is a reduction loop
+        // generate a dense loop according to the synthetic tensor (for
+        // invariants and sparse output tensor).
+        if (env.merger().getSynTensorID() == tid) {
+          // Coiterating with an invariant
           // e.g., out = prod(in[i][j] op invariant);
-          // In this case, we can not infer the loop bound from output
-          // (whose level is reduced). Instead we use the synthetic tensor
-          // to infer the bound.
+          // or a broadcast
+          // e.g., out[i][j] = in[i] (j is undef for input)
+          //
           // The level of the synthetic tensor is the current loop depth;
           // the rank of the synthetic tensor equals to number of loops.
           lvl = env.emitter().getCurrentDepth();
-        } else {
-          // or a broadcast
-          // out[i][j] = in[i] (j is undef for input)
-          tid = outTid;
-          lvl = outLvl;
+        } else if (!lvl) {
           // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-          if (!lvl)
-            return;
+          return;
         }
       }
       hasNonUnique = !isUniqueLT(lt) || hasNonUnique;

@llvmbot commented Nov 28, 2023: @llvm/pr-subscribers-mlir (same comment body and diff as above)

@PeimingLiu PeimingLiu merged commit 1ece4d3 into llvm:main Nov 28, 2023
Labels: mlir:sparse (Sparse compiler in MLIR), mlir