Assertion failure in Linear Layouts when num_warps = 8, but passes with num_warps = 4 #5265

Closed
Moerafaat opened this issue Nov 27, 2024 · 6 comments

@Moerafaat (Contributor)

Describe the bug

To reproduce the issue, run the following Python test:

import torch
import triton
import triton.language as tl


@triton.jit
def repro_kernel(q_ref,
                 k_ref,
                 v_ref,
                 output_ptr,
                 ):
    # q is 64x128, k is 128x64, v is 64x128; all contiguous row-major bf16.
    offsets64 = tl.arange(0, 64)
    offsets128 = tl.arange(0, 128)
    q = tl.load(q_ref + (offsets64[:, None] * 128 + offsets128[None, :]))
    k = tl.load(k_ref + (offsets128[:, None] * 64 + offsets64[None, :]))
    # Cast the f32 result of the first dot to bf16 before feeding it to the second dot.
    qk = tl.dot(q, k).to(tl.bfloat16)
    v = tl.load(v_ref + (offsets64[:, None] * 128 + offsets128[None, :]))
    o = tl.dot(qk, v)
    tl.store(output_ptr + (offsets64[:, None] * 128 + offsets128[None, :]), o.to(tl.bfloat16))


def repro(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    output = torch.empty((64, 128), dtype=torch.bfloat16, device='cuda')
    grid = lambda meta: (1, 1)
    # num_warps=8 triggers the assertion failure; num_warps=4 does not.
    kernel = repro_kernel[grid](q, k, v, output, num_warps=8, num_ctas=1, num_stages=3)
    # print(kernel.asm['ttir'])
    return output

torch.manual_seed(0)
q = torch.ones((64, 128), dtype=torch.bfloat16, device='cuda')
k = torch.ones((128, 64), dtype=torch.bfloat16, device='cuda')
v = torch.ones((64, 128), dtype=torch.bfloat16, device='cuda')
output_torch = (q @ k) @ v
output_triton = repro(q, k, v)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

You will encounter the following error:

python3: /tmp/triton/lib/Tools/LinearLayout.cpp:526: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeOuts(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalOutDimSize() == std::accumulate( newOutDims.begin(), newOutDims.end(), 1, [&](int32_t acc, auto &outDim) { return acc * outDim.second; })' failed.
Aborted
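
For context, the failing assertion in LinearLayout::reshapeOuts checks that the product of the requested new output-dimension sizes equals the layout's current total output size. Below is a minimal Python sketch of that consistency check; check_reshape_outs, total_out_dim_size, and new_out_dims are illustrative names only, not part of the Triton API:

from functools import reduce
import operator

def check_reshape_outs(total_out_dim_size, new_out_dims):
    # Mirrors the assert in LinearLayout::reshapeOuts: reshaping the output
    # dimensions must preserve the total number of output elements.
    product = reduce(operator.mul, (size for _, size in new_out_dims), 1)
    assert total_out_dim_size == product, (
        f"total out-dim size {total_out_dim_size} != "
        f"product of requested out dims {product}")

# Consistent reshape, e.g. out dims [offset (size 4096), iteration (size 1)]:
check_reshape_outs(4096, [("offset", 4096), ("iteration", 1)])  # passes
# A mismatched reshape (hypothetical sizes) would trip the assert:
# check_reshape_outs(8192, [("offset", 4096), ("iteration", 1)])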

I noticed that there was a similar report in #4727 before that issue was re-opened. Interestingly, the failure actually started with the commit linked to that issue; the culprit commit is 49266aa.

The test passes if num_warps is set to 4 instead of 8, and it worked properly before the culprit commit.
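
As a temporary workaround until a fix lands, the same repro runs correctly when the kernel is launched with num_warps=4; only the launch line changes:

# Workaround sketch: identical launch to the repro above, but with num_warps=4
# the kernel compiles and runs without hitting the assertion.
kernel = repro_kernel[grid](q, k, v, output, num_warps=4, num_ctas=1, num_stages=3)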

Environment details

The issue reproduces on H100 with the latest Triton main: commit 8b29bb7

Moerafaat added the bug label on Nov 27, 2024
@Jokeren (Contributor) commented Nov 27, 2024

Interesting. Taking a look now.

#4727 is about TMA, so it's not related.

Jokeren self-assigned this on Nov 27, 2024
@Jokeren (Contributor) commented Nov 27, 2024

FYI, I have a solution that works for it now using stmatrix; will upstream it soon.

With the fix, the out dims are [offset (size 4096), iteration (size 1)], and the repro prints matching outputs:
tensor([[8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        ...,
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.]], device='cuda:0',
       dtype=torch.bfloat16)
tensor([[8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        ...,
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.]], device='cuda:0',
       dtype=torch.bfloat16)

@Moerafaat (Contributor, Author)

Thanks! Really appreciate the fast reply on this and looking forward to your fix 🙏

@Jokeren (Contributor) commented Nov 28, 2024

#5277 is a partial fix. More general fixes will be pushed next week.

@Moerafaat (Contributor, Author)

> #5277 is a partial fix.

Tested it and it works great! Thanks for the fast turnaround!

@Moerafaat (Contributor, Author)

Marking this fixed. Thanks for the assistance!
