This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[Compile] Error with compile for aot_eager Tensor subclass #117

Closed
drisspg opened this issue Oct 3, 2023 · 3 comments
Labels
Compile Issues related with subclass compilation

Comments

Contributor

drisspg commented Oct 3, 2023

Summary

Using nightly: torch==2.2.0.dev20231003+cu121

Repro

Running the following tests:
pytest tests/test_compile.py -k "aot_eager"

Output

t = _to_functional_tensor(FakeTensor(..., device='cuda:0', size=(16, 16), dtype=torch.float8_e4m3fn),
       device='cuda:0')

    def to_fun(t):
        if isinstance(t, Tensor):
            if t in memo:
                return memo[t]
>           r = torch._to_functional_tensor(t)
E           torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
E           RuntimeError: !at::functionalization::impl::isFunctionalTensor(base_) INTERNAL ASSERT FAILED at "../aten/src/ATen/FunctionalStorageImpl.cpp":101, please report a bug to PyTorch.
E
E           Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True

../../miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:732: BackendCompilerFailed
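
The error message above suggests rerunning with TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1. A minimal sketch of doing that from Python for the repro command in this issue, assuming pytest is on PATH and the script is run from the repository root. The names `debug_env`, `REPRO_CMD`, and `rerun_with_debug_logs` are hypothetical helpers, not part of PyTorch or this repository:

```python
import os
import subprocess


def debug_env(base=None):
    """Return a copy of the environment with the Dynamo debug variables set."""
    env = dict(os.environ if base is None else base)
    env["TORCH_LOGS"] = "+dynamo"        # verbose Dynamo logging, per the error message
    env["TORCHDYNAMO_VERBOSE"] = "1"     # full internal stack traces, per the error message
    return env


# Repro command quoted in the issue description.
REPRO_CMD = ["pytest", "tests/test_compile.py", "-k", "aot_eager"]


def rerun_with_debug_logs():
    """Run the repro with the debug variables set; returns the pytest exit code."""
    return subprocess.run(REPRO_CMD, env=debug_env(), check=False).returncode
```

Calling `rerun_with_debug_logs()` is equivalent to prefixing the pytest command with the two variables in a shell; nothing runs until the function is called.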
@drisspg drisspg added the Compile Issues related with subclass compilation label Oct 3, 2023
Contributor Author

drisspg commented Oct 3, 2023

Spoke with @bdhirsh: This is expected as of this nightly.
See: pytorch/pytorch#110079 for tracking of changes

Contributor Author

drisspg commented Oct 10, 2023

With fullgraph=False, this is the latest error on top of Brian's landed PR (ba86dfcd83eca5c12247aace06226ece27d185b9):

        if fake_modes:
            fake_mode, desc1, i1 = fake_modes[0]
            for m, desc2, i2 in fake_modes[1:]:
>               assert fake_mode is m, (
                    f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                    f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
                    f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
                )
E               torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
E               AssertionError: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4df0521720>) from tracing context 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4df7fd3160>) from fake tensor input 0
E
E               fake mode from tracing context 0 allocated at:
E                 File "/home/drisspg/miniconda3/envs/dev/bin/pytest", line 8, in <module>
E                   sys.exit(console_main())
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/config/__init__.py", line 189, in console_main
E                   code = main()
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/config/__init__.py", line 166, in main
E                   ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 317, in pytest_cmdline_main
E                   return wrap_session(config, _main)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 270, in wrap_session
E                   session.exitstatus = doit(config, session) or 0
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 324, in _main
E                   config.hook.pytest_runtestloop(session=session)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 349, in pytest_runtestloop
E                   item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 114, in pytest_runtest_protocol
E                   runtestprotocol(item, nextitem=nextitem)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 133, in runtestprotocol
E                   reports.append(call_and_report(item, "call", log))
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 222, in call_and_report
E                   call = call_runtest_hook(item, when, **kwds)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 261, in call_runtest_hook
E                   return CallInfo.from_call(
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 341, in from_call
E                   result: Optional[TResult] = func()
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 262, in <lambda>
E                   lambda: ihook(item=item, **kwds), when=when, reraise=reraise
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 169, in pytest_runtest_call
E                   item.runtest()
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/python.py", line 1788, in runtest
E                   self.ihook.pytest_pyfunc_call(pyfuncitem=self)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/python.py", line 194, in pytest_pyfunc_call
E                   result = testfunction(**testargs)
E                 File "/home/drisspg/meta/float8_experimental/tests/test_compile.py", line 61, in test_aot_eager
E                   _test_compile_base("aot_eager", fullgraph, emulate, use_subclass, dtype)
E                 File "/home/drisspg/meta/float8_experimental/tests/test_compile.py", line 35, in _test_compile_base
E                   y_fp8 = m_fp8(x)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
E                   return self._call_impl(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
E                   return forward_call(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 401, in _fn
E                   return fn(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
E                   return self._call_impl(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
E                   return forward_call(*args, **kwargs)
E                 File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_linear.py", line 252, in forward
E                   x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
E                 File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_linear.py", line 253, in <resume in forward>
E                   w_fp8 = self.cast_w_to_float8(
E                 File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_linear.py", line 255, in <resume in forward>
E                   y = self.float8_mm(
E                 File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_linear.py", line 228, in float8_mm
E                   y = float8_linear.apply(
E                 File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 551, in apply
E                   return super().apply(*args, **kwargs)  # type: ignore[misc]
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 549, in catch_errors
E                   return callback(frame, cache_entry, hooks, frame_state)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 643, in _convert_frame
E                   result = inner_convert(frame, cache_entry, hooks, frame_state)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 142, in _fn
E                   return fn(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 384, in _convert_frame_assert
E                   return _compile(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 570, in _compile
E                   guarded_code = compile_inner(code, one_graph, hooks, transform)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/utils.py", line 221, in time_wrapper
E                   r = func(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 492, in compile_inner
E                   out_code = transform_code_object(code, transform)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
E                   transformations(instructions, code_options)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 445, in transform
E                   tracer = InstructionTranslator(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 2011, in __init__
E                   output=OutputGraph(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 291, in __init__
E                   fake_mode = torch._subclasses.FakeTensorMode(
E                 File "/home/drisspg/meta/pytorch/torch/_subclasses/fake_tensor.py", line 1295, in __init__
E                   self.stack = "".join(traceback.format_stack())
E
E               fake mode from fake tensor input 0 allocated at:
E                 File "/home/drisspg/miniconda3/envs/dev/bin/pytest", line 8, in <module>
E                   sys.exit(console_main())
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/config/__init__.py", line 189, in console_main
E                   code = main()
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/config/__init__.py", line 166, in main
E                   ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 317, in pytest_cmdline_main
E                   return wrap_session(config, _main)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 270, in wrap_session
E                   session.exitstatus = doit(config, session) or 0
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 324, in _main
E                   config.hook.pytest_runtestloop(session=session)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/main.py", line 349, in pytest_runtestloop
E                   item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 114, in pytest_runtest_protocol
E                   runtestprotocol(item, nextitem=nextitem)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 133, in runtestprotocol
E                   reports.append(call_and_report(item, "call", log))
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 222, in call_and_report
E                   call = call_runtest_hook(item, when, **kwds)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 261, in call_runtest_hook
E                   return CallInfo.from_call(
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 341, in from_call
E                   result: Optional[TResult] = func()
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 262, in <lambda>
E                   lambda: ihook(item=item, **kwds), when=when, reraise=reraise
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/runner.py", line 169, in pytest_runtest_call
E                   item.runtest()
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/python.py", line 1788, in runtest
E                   self.ihook.pytest_pyfunc_call(pyfuncitem=self)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 493, in __call__
E                   return self._hookexec(self.name, self._hookimpls, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_manager.py", line 115, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/pluggy/_callers.py", line 77, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/_pytest/python.py", line 194, in pytest_pyfunc_call
E                   result = testfunction(**testargs)
E                 File "/home/drisspg/meta/float8_experimental/tests/test_compile.py", line 61, in test_aot_eager
E                   _test_compile_base("aot_eager", fullgraph, emulate, use_subclass, dtype)
E                 File "/home/drisspg/meta/float8_experimental/tests/test_compile.py", line 35, in _test_compile_base
E                   y_fp8 = m_fp8(x)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
E                   return self._call_impl(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
E                   return forward_call(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 401, in _fn
E                   return fn(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
E                   return self._call_impl(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
E                   return forward_call(*args, **kwargs)
E                 File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_linear.py", line 252, in forward
E                   x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
E                 File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_linear.py", line 204, in cast_x_to_float8
E                   x_fp8 = Float8Tensor.to_float8(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 549, in catch_errors
E                   return callback(frame, cache_entry, hooks, frame_state)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 643, in _convert_frame
E                   result = inner_convert(frame, cache_entry, hooks, frame_state)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 142, in _fn
E                   return fn(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 384, in _convert_frame_assert
E                   return _compile(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 570, in _compile
E                   guarded_code = compile_inner(code, one_graph, hooks, transform)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/utils.py", line 221, in time_wrapper
E                   r = func(*args, **kwargs)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 492, in compile_inner
E                   out_code = transform_code_object(code, transform)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
E                   transformations(instructions, code_options)
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 445, in transform
E                   tracer = InstructionTranslator(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 2011, in __init__
E                   output=OutputGraph(
E                 File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 291, in __init__
E                   fake_mode = torch._subclasses.FakeTensorMode(
E                 File "/home/drisspg/meta/pytorch/torch/_subclasses/fake_tensor.py", line 1295, in __init__
E                   self.stack = "".join(traceback.format_stack())
E
E
E               Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E               You can suppress this exception and fall back to eager by setting:
E                   import torch._dynamo
E                   torch._dynamo.config.suppress_errors = True

../pytorch/torch/_guards.py:830: BackendCompilerFailed
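
The assertion that fails above compares fake modes by identity (`fake_mode is m`). Judging from the two allocation stacks, one mode appears to have been created while compiling the frame reached through `float8_linear.apply` and the other while compiling the frame reached through `Float8Tensor.to_float8`, so a fake tensor from one frame met the tracing context of the other. The following is a toy sketch of that kind of check in plain Python. `ToyFakeMode` and `check_single_mode` are hypothetical stand-ins for illustration and are not PyTorch's implementation:

```python
class ToyFakeMode:
    """Hypothetical stand-in for FakeTensorMode; instances compare by identity."""

    def __init__(self, origin):
        self.origin = origin  # where this mode was created, used in the message


def check_single_mode(modes):
    """Require every entry in `modes` to be the same object.

    `modes` is a list of (mode, description) pairs, loosely mirroring the
    (fake_mode, desc, index) triples in the traceback above.
    """
    if not modes:
        return None
    first, first_desc = modes[0]
    for mode, desc in modes[1:]:
        assert first is mode, (
            f"fake mode from {first_desc} ({first.origin}) doesn't match "
            f"mode from {desc} ({mode.origin})"
        )
    return first


# One mode shared by the tracing context and the input: the check passes.
shared = ToyFakeMode("single compiled frame")
assert check_single_mode(
    [(shared, "tracing context"), (shared, "fake tensor input")]
) is shared

# Two frames, each with its own mode: the check fails.
mode_a = ToyFakeMode("frame compiled for float8_linear.apply")
mode_b = ToyFakeMode("frame compiled for Float8Tensor.to_float8")
try:
    check_single_mode([(mode_a, "tracing context"), (mode_b, "fake tensor input")])
except AssertionError as err:
    mismatch = str(err)
```

Because the comparison is by identity rather than by configuration, two separately constructed modes fail the check even if they were built with the same settings.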

Contributor Author

drisspg commented Nov 3, 2023

By removing all graph breaks, we were able to avoid this issue.

@drisspg drisspg closed this as completed Nov 3, 2023