
[PyTorch Upstream][IGC] Stock pytorch fp16 e2e test got segmentfault from triton. #1073

Closed

etaf opened this issue May 9, 2024 · 7 comments

etaf commented May 9, 2024

We got a segfault when running the stock PyTorch fp16 end-to-end test. We've narrowed it down to the following minimal reproducer:

import faulthandler; faulthandler.enable()

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config

torch._inductor.config.fallback_random = True
torch._inductor.config.freezing = True
torch._inductor.config.triton.cudagraphs = True
torch._functorch.config.unlift_effect_tokens = True
torch._functorch.config.debug_partitioner = True

isolate_fails_code_str = None

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg0_1):
        isnan = torch.ops.aten.isnan.default(arg0_1);  arg0_1 = None
        any_1 = torch.ops.aten.any.default(isnan);  isnan = None
        return (any_1,)

def load_args(reader):
    buf0 = reader.storage(None, 2097152, device=device(type='xpu', index=0), dtype_hint=torch.float16)
    reader.tensor(buf0, (1, 1024, 1024), dtype=torch.float16, is_leaf=True)  # arg0_1
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():
        run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None)

We got the following call stack:

Current thread 0x00007ffb1fd0c740 (most recent call first):
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 369 in _init_handles
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 406 in _precompile_config
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 203 in precompile
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/codecache.py", line 2932 in result
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/codecache.py", line 3128 in wait
  File "/tmp/torchinductor_xinanlin/lq/clqdbzfwkffulbfq265wdizbwhp2hb3vvnzqekc75wzbkcniuxem.py", line 77 in <module>
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/runtime/compile_tasks.py", line 44 in _reload_python_module
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/codecache.py", line 2567 in load_by_key_path
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/graph.py", line 1657 in compile_to_module
  File "/home/xinanlin/xinanlin/pytorch/torch/_dynamo/utils.py", line 273 in time_wrapper
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/graph.py", line 1710 in compile_to_fn
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/compile_fx.py", line 803 in fx_codegen_and_compile
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/contextlib.py", line 79 in inner
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/compile_fx.py", line 507 in compile_fx_inner
  File "/home/xinanlin/xinanlin/pytorch/torch/_dynamo/utils.py", line 273 in time_wrapper
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/contextlib.py", line 79 in inner
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/contextlib.py", line 79 in inner
  File "/home/xinanlin/xinanlin/pytorch/torch/_inductor/debug.py", line 304 in inner
  File "/home/xinanlin/xinanlin/pytorch/torch/_dynamo/repro/after_aot.py", line 83 in debug_wrapper
  File "/home/xinanlin/xinanlin/pytorch/torch/_dynamo/repro/after_aot.py", line 708 in repro_run
  File "/home/xinanlin/xinanlin/pytorch/torch/_dynamo/repro/after_aot.py", line 957 in run_repro
  File "/home/pt-gpu/4T-4652/xinanlin/pytorch/../test_any.py", line 41 in <module>

To reproduce, please build stock PyTorch with the env var "USE_XPU=1" and run the above script.

etaf commented May 9, 2024

@riverliuintel @vlad-penkin this issue blocks the Inductor upstream process, please prioritize. Thanks.

@alexbaden alexbaden self-assigned this May 9, 2024
alexbaden commented

I will investigate and report back.

etaf commented May 9, 2024

Hi @alexbaden, here is the corresponding Triton kernel (the full generated Inductor module):

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_xinanlin/rb/crb7vl4jnylgbdeqtxzpnde4vgafzac4akzqlpwhnozqn3sna272.py
# Source Nodes: [any_1, isnan], Original ATen: [aten.any, aten.isnan]
# any_1 => any_1
# isnan => isnan
triton_red_fused_any_isnan_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.reduction(
    size_hints=[1, 1048576],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i1', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_any_isnan_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '196aefa0a85bda494694e78ae7765ffc5ac720f56091b59c964159c6f79c4a45'}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.int1)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), None, eviction_policy='evict_first').to(tl.float32)
        tmp1 = libdevice.isnan(tmp0).to(tl.int1)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = _tmp3 | tmp2
        _tmp3 = tmp4
    tmp3 = triton_helpers.any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
''', device_str='xpu')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (1, 1024, 1024), (1048576, 1024, 1))
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        buf0 = empty_strided((), (), device='xpu', dtype=torch.bool)
        # Source Nodes: [any_1, isnan], Original ATen: [aten.any, aten.isnan]
        stream0 = get_raw_stream(0)
        triton_red_fused_any_isnan_0.run(arg0_1, buf0, 1, 1048576, grid=grid(1), stream=stream0)
        del arg0_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1, 1024, 1024), (1048576, 1024, 1), device='xpu:0', dtype=torch.float16)
    fn = lambda: call([arg0_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('BartForConditionalGeneration', benchmark_compiled_module)
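
For anyone isolating this outside of Inductor: a minimal standalone kernel exercising the same fp16 load → isnan → any-reduction pattern might look like the sketch below. This is hypothetical code distilled from the generated module above (the name any_isnan_kernel is ours), and it uses the portable x != x NaN test rather than libdevice.isnan, so it is not guaranteed to hit the exact same IGC code path:

import torch
import triton
import triton.language as tl

@triton.jit
def any_isnan_kernel(in_ptr, out_ptr, numel, BLOCK: tl.constexpr):
    # Single program sweeps the tensor in BLOCK-sized chunks,
    # OR-accumulating an "is NaN" flag, like the reduction above.
    acc = tl.full([BLOCK], 0, tl.int1)
    for start in range(0, numel, BLOCK):
        idx = start + tl.arange(0, BLOCK)
        mask = idx < numel
        x = tl.load(in_ptr + idx, mask=mask, other=0.0).to(tl.float32)
        acc = acc | (x != x)  # NaN is the only value for which x != x
    tl.store(out_ptr, tl.max(acc.to(tl.int8), axis=0))

x = torch.randn(1024 * 1024, device='xpu', dtype=torch.float16)
out = torch.zeros(1, device='xpu', dtype=torch.int8)
any_isnan_kernel[(1,)](x, out, x.numel(), BLOCK=1024)
print(bool(out.item()))  # expected: False for randn input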

alexbaden commented

Looks like the crash is in IGC:

Program received signal SIGSEGV, Segmentation fault.
0x00007ffff319d574 in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
(gdb) bt
#0  0x00007ffff319d574 in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#1  0x00007ffff31a7403 in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#2  0x00007ffff31a7ccc in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#3  0x00007ffff3be0aee in llvm::FPPassManager::runOnFunction(llvm::Function&) () from /lib/x86_64-linux-gnu/libigc.so.1
#4  0x00007ffff3be0e14 in llvm::FPPassManager::runOnModule(llvm::Module&) () from /lib/x86_64-linux-gnu/libigc.so.1
#5  0x00007ffff3be1bac in llvm::legacy::PassManagerImpl::run(llvm::Module&) () from /lib/x86_64-linux-gnu/libigc.so.1
#6  0x00007ffff3153b2c in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#7  0x00007ffff2e0c321 in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#8  0x00007ffff3050afb in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#9  0x00007ffff2e0e307 in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#10 0x00007ffff2e7c4a5 in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#11 0x00007ffff2e7da3e in ?? () from /lib/x86_64-linux-gnu/libigc.so.1
#12 0x00007ffff7efac64 in NEO::OfflineCompiler::buildSourceCode() () from /lib/x86_64-linux-gnu/libocloc.so
#13 0x00007ffff7efe4e5 in NEO::OfflineCompiler::build() () from /lib/x86_64-linux-gnu/libocloc.so
#14 0x00007ffff7f366ce in int SafetyGuardLinux::call<int, NEO::OfflineCompiler, int (NEO::OfflineCompiler::*)()>(NEO::OfflineCompiler*, int (NEO::OfflineCompiler::*)(), int) () from /lib/x86_64-linux-gnu/libocloc.so
#15 0x00007ffff7f363ee in buildWithSafetyGuard(NEO::OfflineCompiler*) () from /lib/x86_64-linux-gnu/libocloc.so
#16 0x00007ffff7ef19c8 in Ocloc::Commands::compile(OclocArgHelper*, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > const&) () from /lib/x86_64-linux-gnu/libocloc.so
#17 0x00007ffff7edda35 in oclocInvoke () from /lib/x86_64-linux-gnu/libocloc.so
#18 0x0000555555554787 in main ()

I have the spirv dumps that cause the IGC crash and will submit a ticket to the IGC team.
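
For reference, one way to capture such dumps is via IGC's shader-dump flags, set in the environment before the driver is loaded. A minimal sketch, assuming the standard IGC debug flags IGC_ShaderDumpEnable and IGC_DumpToCustomDir:

# Collect IGC shader dumps (including the SPIR-V input) while running
# the reproducer. The flags must be in the environment before the
# driver, and hence torch, is first loaded.
import os
os.environ["IGC_ShaderDumpEnable"] = "1"
os.environ["IGC_DumpToCustomDir"] = "/tmp/igc_dumps"  # any writable directory

import torch  # import only after the IGC flags are set
# ... then run the reproducer from the first comment ...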

etaf commented May 10, 2024

Hi @alexbaden , can you share the IGC ticket link?

vlad-penkin commented

@etaf you've been added as a watcher on the IGC ticket and should have received an email notification.

@vlad-penkin vlad-penkin changed the title [PyTorch Upstream] Stock pytorch fp16 e2e test got segmentfault from triton. [PyTorch Upstream][IGC] Stock pytorch fp16 e2e test got segmentfault from triton. May 20, 2024
@vlad-penkin vlad-penkin self-assigned this May 20, 2024
etaf commented May 30, 2024

Hi @alexbaden, @vlad-penkin, how can I get the fixed driver?
