[SR] Eliminate extra permute ops before aten::sum (pytorch#74481)
Summary:
Pull Request resolved: pytorch#74481

This diff fixes an interesting performance issue related to `permute_copy`.

We see this pattern frequently:
```
y = torch.permute(x, (0, 2, 1))
z = torch.sum(y, dim=-1)
```

With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589

But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597
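
A minimal eager-mode sketch of the difference (the shape here is an illustrative assumption, and `.contiguous()` stands in for `permute_copy`, which has no public eager-mode equivalent):
```
import torch

# Illustrative shape; the real workload shapes are not part of this diff.
x = torch.randn(8, 512, 1024)

# Copy variants off: permute returns a strided view with no data movement.
# sum(y, dim=-1) then reduces along a strided dim, which is the case the
# faster kernel handles.
y_view = x.permute(0, 2, 1)
print(y_view.is_contiguous())  # False
print(y_view.stride())         # (524288, 1, 1024)

# Copy variants on: the output is materialized contiguously (modeled here
# with .contiguous()), so the same sum reduces over the stride-1 innermost
# dim and falls into the slower reduction.
y_copy = y_view.contiguous()
print(y_copy.is_contiguous())  # True
print(y_copy.stride())         # (524288, 512, 1)
```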

But the permute is actually unnecessary; we can statically rewrite the graph into the following form so that the fast kernel is still hit with copy variants on:
```
z = torch.sum(x, dim=1)
```
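
The rewrite preserves semantics: summing the last dim of the `(0, 2, 1)`-permuted tensor visits exactly the elements of dim 1 of the original. A quick eager-mode sanity check (illustrative shape, not from this diff):
```
import torch

x = torch.randn(2, 3, 4)
z_before = torch.sum(torch.permute(x, (0, 2, 1)), dim=-1)
z_after = torch.sum(x, dim=1)
# Both are shape (2, 4); allclose rather than equal, since the two
# paths may accumulate in a different order.
assert torch.allclose(z_before, z_after)
```
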
ghstack-source-id: 152003888

Reviewed By: navahgar

Differential Revision: D34992319

fbshipit-source-id: 0baf493708ee2180c899814a954d220d88ba1d4f
(cherry picked from commit 797b6be)
Mike Iovine authored and pytorchmergebot committed Mar 23, 2022
1 parent: d9f2cf5 · commit: f5a9c36
Showing 4 changed files with 107 additions and 0 deletions.
59 changes: 59 additions & 0 deletions benchmarks/static_runtime/test_static_module.cc
@@ -1529,3 +1529,62 @@ TEST(ForceNonEmptyOutputs, TwoSubBlocks) {
}
}
}

TEST(EliminateExtraPermuteOps, FusesCorrectly) {
const auto src = R"JIT(
def forward(self, x):
y = torch.permute(x, (0, 2, 1))
z = torch.sum(y, dim=-1)
return z
)JIT";
torch::jit::Module mod("m");
mod.define(src);

auto graph = mod.get_method("forward").graph();
// turn the ListConstruct(%constant) into proper constant lists
ConstantPropagation(graph);
EliminateExtraPermuteOps(graph);

EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
auto* sum = getNodeWithKind(graph, "aten::sum");
ASSERT_NE(sum, nullptr);
auto dim = toIValue(sum->input(1));
ASSERT_TRUE(dim.has_value() && dim->isIntList());
EXPECT_EQ(dim->toIntList(), c10::List<int64_t>{1});
}

TEST(EliminateExtraPermuteOps, DoesNotFuseWrongDim) {
const auto src = R"JIT(
def forward(self, x):
y = torch.permute(x, (0, 2, 1))
z = torch.sum(y, dim=1)
return z
)JIT";
torch::jit::Module mod("m");
mod.define(src);

auto graph = mod.get_method("forward").graph();
// turn the ListConstruct(%constant) into proper constant lists
ConstantPropagation(graph);
EliminateExtraPermuteOps(graph);

EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
}

TEST(EliminateExtraPermuteOps, DoesNotFuseNonConstantDim) {
const auto src = R"JIT(
def forward(self, x, dim: int):
y = torch.permute(x, (0, 2, 1))
z = torch.sum(y, dim=dim)
return z
)JIT";
torch::jit::Module mod("m");
mod.define(src);

auto graph = mod.get_method("forward").graph();
// turn the ListConstruct(%constant) into proper constant lists
ConstantPropagation(graph);
EliminateExtraPermuteOps(graph);

EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
}
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/static/impl.cpp
@@ -146,6 +146,7 @@ void OptimizeGraph(
UseVariadicCat(graph);
UseVariadicStack(graph);
EliminateTrivialEquallySplit(graph);
EliminateExtraPermuteOps(graph);

if (opts.enable_out_variant) {
UseVariadicOp(
45 changes: 45 additions & 0 deletions torch/csrc/jit/runtime/static/passes.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/jit/runtime/static/passes.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
@@ -1005,5 +1006,49 @@ void ForceNonEmptyOutputs(Graph& graph) {
}
}

void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph) {
auto input_is_constant_list =
[](Node* node, size_t input_idx, const c10::List<int64_t>& expected) {
auto input_opt = toIValue(node->input(input_idx));
if (!input_opt.has_value() || !input_opt->isIntList()) {
return false;
}
return input_opt->toIntList() == expected;
};

// SubgraphRewriter can't pattern-match on constants, so we use this
// extra filter to make sure the values of the `dim` arguments are
// correct.
auto dims_are_valid_constants =
[&input_is_constant_list](
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
// Get the nodes in the real graph from the nodes in the template
// pattern graph
const auto& node_map = match.nodes_map;
auto* sum_node = node_map.at(vmap.at("c")->node());
auto* permute_node = node_map.at(vmap.at("b")->node());
return input_is_constant_list(sum_node, 1, c10::List<int64_t>{-1}) &&
input_is_constant_list(
permute_node, 1, c10::List<int64_t>{0, 2, 1});
};

const auto pattern = R"IR(
graph(%a, %sum_dim, %permute_dim, %keepdim, %dtype):
%b = aten::permute(%a, %permute_dim)
%c = aten::sum(%b, %sum_dim, %keepdim, %dtype)
return (%c))IR";

const auto fused_pattern = R"IR(
graph(%a, %sum_dim, %permute_dim, %keepdim, %dtype):
%new_sum_dim: int[] = prim::Constant[value=[1]]()
%d = aten::sum(%a, %new_sum_dim, %keepdim, %dtype)
return (%d))IR";

SubgraphRewriter fuse;
fuse.RegisterRewritePattern(pattern, fused_pattern);
fuse.runOnGraph(graph, dims_are_valid_constants);
}

} // namespace jit
} // namespace torch
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/static/passes.h
@@ -61,5 +61,7 @@ TORCH_API void ForceNonEmptyOutputs(Graph& graph);

TORCH_API void UseVariadicGroupedAccessor(const std::shared_ptr<Graph>& graph);

TORCH_API void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
