Assertion error from linear layouts #4727

Open
peterbell10 opened this issue Sep 13, 2024 · 6 comments · Fixed by #4731

@peterbell10
Contributor

I am running into an assertion error in the codegen for local_load that comes from the linear layouts code. Here is a minimized reproducer:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 2056 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @test_fn() attributes {noinline = false} {
    %0 = triton_gpu.local_alloc  { allocation.offset = 0 : i32} : () -> !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable>
    %1 = triton_gpu.local_load %0 : !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<4x128xf32, #blocked>
    tt.return
  }
}

When lowering to LLVM IR, it fails with the following error:

$ triton-opt --convert-triton-gpu-to-llvm repro.ttgir

triton-opt: /root/code/triton/lib/Tools/LinearLayout.cpp:512: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: /root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt --convert-triton-gpu-to-llvm repro.ttgir
 #0 0x00005621e7032ff7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c94ff7)
 #1 0x00005621e7030b1e llvm::sys::RunSignalHandlers() (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c92b1e)
 #2 0x00005621e70336af SignalHandler(int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c956af)
 #3 0x00007f20a484c420 __restore_rt (/usr/lib/x86_64-linux-gnu/libpthread.so.0+0x14420)
 #4 0x00007f20a431900b raise /build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/raise.c:51:1
 #5 0x00007f20a42f8859 abort /build/glibc-LcI20x/glibc-2.31/stdlib/abort.c:81:7
 #6 0x00007f20a42f8729 get_sysdep_segment_value /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:509:8
 #7 0x00007f20a42f8729 _nl_load_domain /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:970:34
 #8 0x00007f20a4309fd6 (/usr/lib/x86_64-linux-gnu/libc.so.6+0x33fd6)
 #9 0x00005621e4aac52a mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const /root/code/triton/lib/Tools/LinearLayout.cpp:520:37
#10 0x00005621e46b26dd mlir::emitTransferBetweenRegistersAndShared(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, std::optional<int>, mlir::Value, llvm::ArrayRef<mlir::Value>, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&, std::function<void (mlir::VectorType, mlir::Value)>) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:307:61
#11 0x00005621e46b31f3 mlir::loadSharedToDistributed(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, mlir::LLVM::SharedMemoryObject, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:386:55
#12 0x00005621e47a4185 (anonymous namespace)::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:172:69
#13 0x00005621e47a3d05 (anonymous namespace)::LocalLoadOpConversion::matchAndRewrite(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:124:47
#14 0x00005621e47ac85d mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::LocalLoadOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/mlir/Conversion/LLVMCommon/Pattern.h:166:77
#15 0x00005621e6b3bd10 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279dd10)
#16 0x00005621e6b7a65b mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27dc65b)
#17 0x00005621e6b771df mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27d91df)
#18 0x00005621e6b3cca1 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279eca1)
#19 0x00005621e6b3bdb4 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279ddb4)
#20 0x00005621e6b3d1bf mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279f1bf)
#21 0x00005621e6b438fb mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27a58fb)
#22 0x00005621e4d6e312 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation() /root/code/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp:178:15
#23 0x00005621e6081996 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce3996)
#24 0x00005621e6082140 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce4140)
#25 0x00005621e60845f5 mlir::PassManager::run(mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce65f5)
#26 0x00005621e607dccf performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdfccf)
#27 0x00005621e607d8fd llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdf8fd)
#28 0x00005621e6fb2656 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c14656)
#29 0x00005621e6078721 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda721)
#30 0x00005621e60789d3 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda9d3)
#31 0x00005621e6078da6 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdada6)
#32 0x00005621e4e72ad0 main /root/code/triton/bin/triton-opt.cpp:9:0
#33 0x00007f20a42fa083 __libc_start_main /build/glibc-LcI20x/glibc-2.31/csu/../csu/libc-start.c:342:3
#34 0x00005621e468707e _start (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2e907e)
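The assertion message spells out the precondition of reshapeIns: the product of the requested input-dimension sizes must equal getTotalInDimSize(). Below is a minimal Python sketch of that arithmetic. The (name, size) pairs, the dimension names, and the sizes are hypothetical. They only illustrate one reshape that passes and one that fails.

```python
from functools import reduce
from operator import mul

# Hypothetical stand-in for a LinearLayout's input dimensions: (name, size) pairs.
InDims = list[tuple[str, int]]


def total_in_dim_size(in_dims: InDims) -> int:
    """Product of all input-dimension sizes (the role of getTotalInDimSize())."""
    return reduce(mul, (size for _, size in in_dims), 1)


def reshape_ins(in_dims: InDims, new_in_dims: InDims) -> InDims:
    """Check the precondition from the failed assertion, then adopt the new dims."""
    old_total = total_in_dim_size(in_dims)
    new_total = total_in_dim_size(new_in_dims)
    if old_total != new_total:
        raise ValueError(
            f"reshapeIns: layout has {old_total} inputs, new dims hold {new_total}"
        )
    return list(new_in_dims)


# Hypothetical sizes: 512 inputs fit 4 x 128, while 1024 inputs do not.
print(reshape_ins([("offset", 512)], [("row", 4), ("col", 128)]))
try:
    reshape_ins([("offset", 1024)], [("row", 4), ("col", 128)])
except ValueError as err:
    print(err)
```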

cc @Jokeren @jlebar

@Jokeren
Contributor

Jokeren commented Sep 13, 2024

Just to confirm: was this TritonGPU IR generated from valid Triton Python code?

@peterbell10
Contributor Author

It came from lowering a new operator I'm adding, but I'll see if I can reproduce it with an existing operator.

@peterbell10
Contributor Author

This produces the same error on the current master branch:

import triton.language as tl
import triton
import torch

@triton.jit
def test_fn(out_ptr, a_ptr, workspace, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Write a 2D TMA descriptor for a_ptr into the workspace buffer.
    desc_ptr = workspace
    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=desc_ptr, global_address=a_ptr, load_size=[4, N_BLOCK], global_size=[M, N], element_ty=a_ptr.dtype.element_ty)
    # Fence so the descriptor written above is visible to the TMA (tensormap) proxy.
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_ptr)

    # Load a 4 x N_BLOCK box through TMA (only 4 rows), then store it to out_ptr.
    gather = tl._experimental_descriptor_load(desc_ptr, [0, 0], [4, N_BLOCK], a_ptr.dtype.element_ty)
    tl.store(out_ptr + tl.arange(0, 4)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :], gather)


# 4 x 128 f32 input and output; 128 bytes of workspace hold the TMA descriptor.
out = torch.empty((4, 128), dtype=torch.float32, device="cuda")
inp = torch.arange(4 * 128, dtype=torch.float32, device="cuda").reshape(4, 128)
workspace = torch.empty(128, dtype=torch.uint8, device="cuda")
test_fn[(1,)](out, inp, workspace, 4, 128, 4, 128)

@Jokeren
Contributor

Jokeren commented Sep 13, 2024

I'll take a look today

@peterbell10
Contributor Author

Reopening this, as it seems the TMA hardware does support swizzling with only 4 rows of data.

In case it's helpful, this is the result I get:

unswizzled:
tensor([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
          12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,
          24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,
          36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
          48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,
          60.,  61.,  62.,  63.],
        [128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138., 139.,
         140., 141., 142., 143., 144., 145., 146., 147., 148., 149., 150., 151.,
         152., 153., 154., 155., 156., 157., 158., 159., 160., 161., 162., 163.,
         164., 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,
         176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187.,
         188., 189., 190., 191.],
        [256., 257., 258., 259., 260., 261., 262., 263., 264., 265., 266., 267.,
         268., 269., 270., 271., 272., 273., 274., 275., 276., 277., 278., 279.,
         280., 281., 282., 283., 284., 285., 286., 287., 288., 289., 290., 291.,
         292., 293., 294., 295., 296., 297., 298., 299., 300., 301., 302., 303.,
         304., 305., 306., 307., 308., 309., 310., 311., 312., 313., 314., 315.,
         316., 317., 318., 319.],
        [384., 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,
         396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406., 407.,
         408., 409., 410., 411., 412., 413., 414., 415., 416., 417., 418., 419.,
         420., 421., 422., 423., 424., 425., 426., 427., 428., 429., 430., 431.,
         432., 433., 434., 435., 436., 437., 438., 439., 440., 441., 442., 443.,
         444., 445., 446., 447.]], device='cuda:0', dtype=torch.float16)

swizzled:
tensor([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
          12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,
          24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,
          36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
          48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,
          60.,  61.,  62.,  63.],
        [136., 137., 138., 139., 140., 141., 142., 143., 128., 129., 130., 131.,
         132., 133., 134., 135., 152., 153., 154., 155., 156., 157., 158., 159.,
         144., 145., 146., 147., 148., 149., 150., 151., 168., 169., 170., 171.,
         172., 173., 174., 175., 160., 161., 162., 163., 164., 165., 166., 167.,
         184., 185., 186., 187., 188., 189., 190., 191., 176., 177., 178., 179.,
         180., 181., 182., 183.],
        [272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283.,
         284., 285., 286., 287., 256., 257., 258., 259., 260., 261., 262., 263.,
         264., 265., 266., 267., 268., 269., 270., 271., 304., 305., 306., 307.,
         308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318., 319.,
         288., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299.,
         300., 301., 302., 303.],
        [408., 409., 410., 411., 412., 413., 414., 415., 400., 401., 402., 403.,
         404., 405., 406., 407., 392., 393., 394., 395., 396., 397., 398., 399.,
         384., 385., 386., 387., 388., 389., 390., 391., 440., 441., 442., 443.,
         444., 445., 446., 447., 432., 433., 434., 435., 436., 437., 438., 439.,
         424., 425., 426., 427., 428., 429., 430., 431., 416., 417., 418., 419.,
         420., 421., 422., 423.]], dtype=torch.float16)
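The swizzled rows above are consistent with a 128-byte swizzle. Under that reading, each 64-element fp16 row is split into eight 16-byte chunks of eight values. Chunk position p of row r then holds source chunk p XOR (r mod 8):

* Row 1 swaps adjacent chunks.
* Row 2 swaps adjacent pairs of chunks.
* Row 3 reverses each group of four chunks.

The Python sketch below assumes exactly that rule and regenerates the swizzled output from the unswizzled data. The names swizzle_128b, CHUNK_BYTES, and SWIZZLE_BYTES are illustrative, not Triton or CUDA APIs.

```python
CHUNK_BYTES = 16     # assumed swizzle granularity: 16-byte chunks
SWIZZLE_BYTES = 128  # assumed swizzle span: one 128-byte row = 8 chunks


def swizzle_128b(rows: list[list[float]], elem_bytes: int) -> list[list[float]]:
    """Permute 16-byte chunks of each row: position p receives chunk p ^ (row % 8)."""
    elems_per_chunk = CHUNK_BYTES // elem_bytes
    chunks_per_row = SWIZZLE_BYTES // CHUNK_BYTES
    out = []
    for r, row in enumerate(rows):
        if len(row) != elems_per_chunk * chunks_per_row:
            raise ValueError("each row must span exactly 128 bytes")
        phase = r % chunks_per_row
        new_row = []
        for pos in range(chunks_per_row):
            src = pos ^ phase  # XOR is its own inverse, so this also unswizzles
            new_row.extend(row[src * elems_per_chunk:(src + 1) * elems_per_chunk])
        out.append(new_row)
    return out


# Rebuild the 4 x 64 fp16 "unswizzled" tile shown above: row r starts at r * 128.
unswizzled = [[float(r * 128 + c) for c in range(64)] for r in range(4)]
swizzled = swizzle_128b(unswizzled, elem_bytes=2)
print(swizzled[1][:16])
```

Under this model only the row index modulo 8 enters the permutation, so nothing requires 8 rows to be present. A 4-row box simply uses phases 0 through 3, which matches the observation that TMA swizzles a 4-row load.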

peterbell10 reopened this on Sep 30, 2024
@Jokeren
Contributor

Jokeren commented Sep 30, 2024

I think the problem is this line: int tileRows = 8;
I'll try to address it tomorrow.
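Here is a hypothetical Python model of how a fixed tile height could trip the reshapeIns size check for the reproducer's tensor<4x128xf32>. The model makes three assumptions:

* The layout is built from 128-byte swizzle tiles that are tileRows rows tall.
* Each tile is 8 * 128 / elemBits elements wide.
* The tensor is padded up to whole tiles.

The functions tile_shape and covered_elements and the clamp_rows switch are illustrative. They are not Triton's code, and not necessarily what #4731 does. With tileRows = 8, the modeled layout describes 1024 elements for a tensor that has only 512. Clamping the tile height to the tensor's row count makes the two agree.

```python
SWIZZLE_BYTES = 128  # assumed 128-byte swizzle


def tile_shape(rows: int, elem_bits: int, clamp_rows: bool) -> tuple[int, int]:
    """Return (tile_rows, tile_cols) for one hypothetical swizzle tile."""
    tile_rows = 8  # the hard-coded height under suspicion
    if clamp_rows:
        # Hypothetical fix: never make the tile taller than the tensor itself.
        tile_rows = min(tile_rows, rows)
    tile_cols = 8 * SWIZZLE_BYTES // elem_bits  # elements per 128-byte row
    return tile_rows, tile_cols


def covered_elements(rows: int, cols: int, elem_bits: int, clamp_rows: bool) -> int:
    """Number of elements a layout made of whole tiles would describe."""
    tile_rows, tile_cols = tile_shape(rows, elem_bits, clamp_rows)
    # Round each dimension up to a whole number of tiles (ceiling division).
    padded_rows = -(-rows // tile_rows) * tile_rows
    padded_cols = -(-cols // tile_cols) * tile_cols
    return padded_rows * padded_cols


rows, cols, elem_bits = 4, 128, 32  # the reproducer's tensor<4x128xf32>
for clamp in (False, True):
    print(clamp, covered_elements(rows, cols, elem_bits, clamp), rows * cols)
```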
