Skip to content

Commit

Permalink
Squeeze p2: hook up Squeeze to LazyView (pytorch#73067)
Browse files Browse the repository at this point in the history
Summary:
This PR hooks up Squeeze op to LazyView.
The end goal to reduce the instances where we need to rely on explicit shapes.
In the final PR we will make `aten_ops` to use the right ViewInfo and update the lowering in ts_lowering.

Pull Request resolved: pytorch#73067

Reviewed By: wconstab, mikaylagawarecki

Differential Revision: D34345163

Pulled By: Krovatkin

fbshipit-source-id: 6bfadedbded7521312019ead0dfc7c6a334ff0f5
(cherry picked from commit 4b3b10f)
  • Loading branch information
Krovatkin authored and pytorchmergebot committed Feb 24, 2022
1 parent 6c6ae0e commit 5a7778c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
21 changes: 21 additions & 0 deletions torch/csrc/lazy/core/lazy_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <torch/csrc/lazy/core/view_ops/permute.h>
#include <torch/csrc/lazy/core/view_ops/resize.h>
#include <torch/csrc/lazy/core/view_ops/select.h>
#include <torch/csrc/lazy/core/view_ops/squeeze.h>
#include <torch/csrc/lazy/core/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/core/view_ops/select_view_update.h>
#include <torch/csrc/lazy/core/view_ops/view.h>

Expand Down Expand Up @@ -43,6 +45,10 @@ Value ApplyViewInfo(Value ir_value, const ViewInfo& view_info) {
return MakeNode<View>(ir_value, view_info.shape.sizes().vec());
case ViewInfo::Type::kResize:
return MakeNode<Resize>(ir_value, view_info.shape.sizes().vec());
case ViewInfo::Type::kSqueeze:
return MakeNode<torch::lazy::Squeeze>(ir_value, view_info.squeeze_index);
case ViewInfo::Type::kUnsqueeze:
return MakeNode<torch::lazy::Unsqueeze>(ir_value, view_info.squeeze_index);
case ViewInfo::Type::kAsStrided:
return MakeNode<AsStrided>(
ir_value,
Expand Down Expand Up @@ -98,6 +104,12 @@ Value ApplyUpdate(Value ir_value, const Alias::UpdateData& update_data) {
case ViewInfo::Type::kResize:
result = MakeNode<Resize>(result, view_info.source_shape.sizes().vec());
break;
case ViewInfo::Type::kSqueeze:
result = MakeNode<torch::lazy::Unsqueeze>(ir_value, view_info.squeeze_index);
break;
case ViewInfo::Type::kUnsqueeze:
result = MakeNode<torch::lazy::Squeeze>(ir_value, view_info.squeeze_index);
break;
case ViewInfo::Type::kAsStrided:
result = MakeNode<AsStridedViewUpdate>(
tmp_values[i - 1],
Expand Down Expand Up @@ -130,6 +142,15 @@ ViewInfo::ViewInfo(Type view_type, Shape shape, Shape source_shape)
indices(source_shape.dim(), 0),
source_shape(std::move(source_shape)) {}

ViewInfo::ViewInfo(Type view_type, Shape shape, Shape source_shape, int64_t sqi)
: view_type(view_type),
shape(std::move(shape)),
source_shape(std::move(source_shape)),
squeeze_index(sqi)
{
TORCH_CHECK(view_type == Type::kSqueeze);
}

ViewInfo::ViewInfo(
Type view_type,
Shape source_shape,
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/lazy/core/lazy_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ struct TORCH_API ViewInfo {
kSelect,
kAsStrided,
kDiagonal,
kSqueeze,
kUnsqueeze,
};

ViewInfo() = default;
ViewInfo(Type view_type, Shape shape, Shape source_shape);
ViewInfo(Type view_type, Shape shape, Shape source_shape, int64_t sqi);
ViewInfo(
Type view_type,
Shape source_shape,
Expand Down Expand Up @@ -92,6 +95,8 @@ struct TORCH_API ViewInfo {
c10::optional<AsStridedInfo> as_strided;
// Information used for diagonal views.
c10::optional<DiagonalInfo> diagonal;
// Squeeze/Unsqueeze Index
int64_t squeeze_index;
};

// When a "view" (capture by reference) is taken on a node, an Alias object is
Expand Down

0 comments on commit 5a7778c

Please sign in to comment.