
Failed lowering for select scatter #112411

Closed
drisspg opened this issue Oct 30, 2023 · 4 comments
Assignees
Chillee
Labels
module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@drisspg
Contributor

drisspg commented Oct 30, 2023

### 🐛 Describe the bug

Inductor fails lowering for select_scatter. I originally did not have the to(a.dtype) cast, which is also valid eager code (same failure).

Repro:

```python
import torch

def main():
    a = torch.rand(16, 16, device="cuda", dtype=torch.bfloat16)
    b = torch.rand(16, device="cuda", dtype=torch.float32)

    @torch.compile(fullgraph=True)
    def func(a, b):
        # Reduce to a bfloat16 scalar, then scatter it into a float32 tensor.
        abs_max = torch.abs(a).max()
        b[0] = abs_max.to(a.dtype)

    func(a, b)

if __name__ == "__main__":
    main()
```
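
For reference, here is the variant without the explicit cast, which is also valid eager code and hits the same lowering failure (a sketch reconstructed from the description above; the function name is illustrative):

```python
@torch.compile(fullgraph=True)
def func_no_cast(a, b):
    # abs_max is bfloat16; eager mode type-promotes it to float32 on
    # assignment into b, but Inductor's select_scatter lowering asserts
    # equal dtypes and raises.
    abs_max = torch.abs(a).max()
    b[0] = abs_max
```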

### Error logs

```shell
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 381, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 584, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/home/drisspg/meta/pytorch/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 455, in run
    return super().run(*args)
  File "/home/drisspg/meta/pytorch/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 757, in run_node
    result = super().run_node(n)
  File "/home/drisspg/meta/pytorch/torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 627, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 624, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 289, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 2219, in select_scatter
    assert x.get_dtype() == src.get_dtype()
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
  target: aten.select_scatter.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[16], stride=[1]))
  ))
  args[1]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('cuda', torch.bfloat16, size=[], stride=[]), data=Reduction(
      'cuda',
      torch.bfloat16,
      def inner_fn(index, rindex):
          r0, r1 = rindex
          tmp0 = ops.load(arg0_1, r1 + 16 * r0)
          tmp1 = ops.abs(tmp0)
          return tmp1
      ,
      ranges=[],
      reduction_ranges=[16, 16],
      reduction_type=max,
      origin_node=max_1,
      origins={abs_1, max_1}
    ))
  ))
  args[2]: 0
  args[3]: 0

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

### Minified repro

No response

### Versions

Nightly

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@williamwen42
Member

cc inductor people @desertfire @eellison @Chillee @shunting314 @mlazos

@williamwen42 williamwen42 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Oct 30, 2023
@Chillee Chillee self-assigned this Oct 30, 2023
@drisspg
Contributor Author

drisspg commented Oct 31, 2023

Also, I commented out the top-level equal-dtypes assert, and this fixed my case. This is likely because the dtype being scattered into is float32, which is the default type-promotion result?
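
(As an untested user-side workaround sketch, not from the lowering code itself: casting to the destination tensor's dtype instead should satisfy the equal-dtype assert, since b is float32 here.)

```python
# Hypothetical workaround: cast to the destination dtype so
# aten.select_scatter sees matching dtypes on both sides.
b[0] = abs_max.to(b.dtype)
```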

@Chillee
Contributor

Chillee commented Oct 31, 2023

I think this is my fault 😅 , particularly this commit (#112093), which has since been reverted.

Very interestingly, this particular pattern (I'm guessing the b[0] = ...) triggers a copy that converts between different dtypes, which we weren't checking for before removing copy nodes from the graph.

This also fixes the issue: https://github.com/pytorch/pytorch/pull/112476/files#diff-5d0c4891ce10ae399cd6d44f7408532d1a9eb56103ab0cd7db3613cdf4fbaf50R496
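
A minimal sketch of the kind of dtype check described, with hypothetical names (the actual change is in the PR diff linked above):

```python
# Hypothetical sketch, not the PR code: only remove a copy_ node if it
# is a true no-op. A copy_ whose source and destination dtypes differ
# performs a conversion, so dropping it changes semantics.
def copy_is_removable(copy_node):
    dst, src = copy_node.args[0], copy_node.args[1]
    return dst.meta["val"].dtype == src.meta["val"].dtype
```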

@Chillee
Contributor

Chillee commented Nov 1, 2023

Fixed by #112476

@Chillee Chillee closed this as completed Nov 1, 2023