-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse] code simplification: always use synthetical tensor for… #73597
Merged
+10
−20
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir-sparse Author: Peiming Liu (PeimingLiu) Changes… loop bound. Full diff: https://github.com/llvm/llvm-project/pull/73597.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 69072b91b2fa523..a245344755f0404 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
const SparseTensorType stt(rtp);
lvlRank = stt.getLvlRank();
- // We always treat sparse output tensor as dense so that we always iterate
- // it based on lvl size.
- if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+ if (stt.hasEncoding()) {
const auto enc = stt.getEncoding();
isSparseSlices[tid] = enc.isSlice();
for (auto lvlTp : enc.getLvlTypes())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3fb90ef379a5778..e0d3ce241e454d0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs(
}
if (isUndefLT(lt)) {
// An undefined lt in the lattices, we probably mean to
- // iterate based on the level of output tensor. E.g., this
- // could be a synthetic tensor (for invariants and sparse
- // output tensor).
- auto itType = env.op().getIteratorTypesArray()[ldx];
- if (linalg::isReductionIterator(itType) &&
- env.merger().getSynTensorID() == tid) {
- // Coiterating with an invariant, and this is a reduction loop
+ // generate a dense loop according to the synthetic tensor (for
+ // invariants and sparse output tensor).
+ if (env.merger().getSynTensorID() == tid) {
+ // Coiterating with an invariant
// e.g., out = prod(in[i][j] op invariant);
- // In this case, we can not infer the loop bound from output
- // (whose level is reduced). Instead we use the synthetic tensor
- // to infer the bound.
+ // or a broadcast
+ // e.g., out[i][j] = in[i] (j is undef for input)
+ //
// The level of the synthetic tensor is the current loop depth;
// the rank of the synthetic tensor equals to number of loops.
lvl = env.emitter().getCurrentDepth();
- } else {
- // or a broadcast
- // out[i][j] = in[i] (j is undef for input)
- tid = outTid;
- lvl = outLvl;
+ } else if (!lvl) {
// Skips invalid lvl (e.g., when this is a zero ranked tensor).
- if (!lvl)
- return;
+ return;
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
|
@llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu) Changes… loop bound. Full diff: https://github.com/llvm/llvm-project/pull/73597.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 69072b91b2fa523..a245344755f0404 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
const SparseTensorType stt(rtp);
lvlRank = stt.getLvlRank();
- // We always treat sparse output tensor as dense so that we always iterate
- // it based on lvl size.
- if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+ if (stt.hasEncoding()) {
const auto enc = stt.getEncoding();
isSparseSlices[tid] = enc.isSlice();
for (auto lvlTp : enc.getLvlTypes())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3fb90ef379a5778..e0d3ce241e454d0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs(
}
if (isUndefLT(lt)) {
// An undefined lt in the lattices, we probably mean to
- // iterate based on the level of output tensor. E.g., this
- // could be a synthetic tensor (for invariants and sparse
- // output tensor).
- auto itType = env.op().getIteratorTypesArray()[ldx];
- if (linalg::isReductionIterator(itType) &&
- env.merger().getSynTensorID() == tid) {
- // Coiterating with an invariant, and this is a reduction loop
+ // generate a dense loop according to the synthetic tensor (for
+ // invariants and sparse output tensor).
+ if (env.merger().getSynTensorID() == tid) {
+ // Coiterating with an invariant
// e.g., out = prod(in[i][j] op invariant);
- // In this case, we can not infer the loop bound from output
- // (whose level is reduced). Instead we use the synthetic tensor
- // to infer the bound.
+ // or a broadcast
+ // e.g., out[i][j] = in[i] (j is undef for input)
+ //
// The level of the synthetic tensor is the current loop depth;
// the rank of the synthetic tensor equals to number of loops.
lvl = env.emitter().getCurrentDepth();
- } else {
- // or a broadcast
- // out[i][j] = in[i] (j is undef for input)
- tid = outTid;
- lvl = outLvl;
+ } else if (!lvl) {
// Skips invalid lvl (e.g., when this is a zero ranked tensor).
- if (!lvl)
- return;
+ return;
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
|
aartbik
approved these changes
Nov 28, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
… loop bound.