BMM into BMM triggers internal assert #5211

Closed
falkaer opened this issue Nov 21, 2024 · 4 comments · Fixed by #5285
falkaer commented Nov 21, 2024

Describe the bug

Performing two back-to-back bmm calls on an NVIDIA GPU triggers an internal assert. On Triton 3.1.0:

python3: /project/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp:84: mlir::LogicalResult {anonymous}::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, const mlir::LLVMTypeConverter*, mlir::ConversionPatternRewriter&) const: Assertion `dstShape.size() <= 2 && "Unexpected rank of ConvertLayout(shared->blocked)"' failed.

and on main (d5ba6ac):

python3: /project/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:210: llvm::LogicalResult {anonymous}::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, const mlir::LLVMTypeConverter*, mlir::ConversionPatternRewriter&) const: Assertion `(dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"' failed.
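
To make the pattern concrete, here is a stripped-down sketch of what "BMM into BMM" means at the Triton level: two chained 3D tl.dot calls, where the result of the first feeds the second. The kernel name, shapes, and the contiguous-offset addressing are illustrative assumptions only; the full WIP kernel below is the actual reproducer.

import torch
import triton
import triton.language as tl


@triton.jit
def _double_bmm_kernel(a_ptr, b_ptr, c_ptr, out_ptr,
                       B: tl.constexpr, M: tl.constexpr, K: tl.constexpr,
                       N: tl.constexpr, P: tl.constexpr):
    # Load whole (batch, rows, cols) blocks using contiguous offsets.
    ob = tl.arange(0, B)[:, None, None]
    a = tl.load(a_ptr + ob * M * K
                + tl.arange(0, M)[None, :, None] * K
                + tl.arange(0, K)[None, None, :])  # (B, M, K)
    b = tl.load(b_ptr + ob * K * N
                + tl.arange(0, K)[None, :, None] * N
                + tl.arange(0, N)[None, None, :])  # (B, K, N)
    c = tl.load(c_ptr + ob * N * P
                + tl.arange(0, N)[None, :, None] * P
                + tl.arange(0, P)[None, None, :])  # (B, N, P)
    acc = tl.dot(a, b)                    # first batched matmul -> (B, M, N)
    out = tl.dot(acc.to(tl.float16), c)   # second batched matmul, fed by the first
    tl.store(out_ptr + ob * M * P
             + tl.arange(0, M)[None, :, None] * P
             + tl.arange(0, P)[None, None, :], out)


if __name__ == "__main__":
    B, M, K, N, P = 2, 16, 16, 16, 16
    a = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(B, K, N, device="cuda", dtype=torch.float16)
    c = torch.randn(B, N, P, device="cuda", dtype=torch.float16)
    out = torch.empty(B, M, P, device="cuda", dtype=torch.float32)
    _double_bmm_kernel[(1,)](a, b, c, out, B=B, M=M, K=K, N=N, P=P)

Whether this exact sketch hits the same assert depends on the Triton version; it is only meant to illustrate the shape of the pattern.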

Below is my WIP kernel that triggers the issue. The kernel is supposed to perform two batched matrix multiplies and some reshapes to compute the forward pass of a block tensor train, as in the reference einsum below. The code gives the correct result when run with the interpreter. If this is a known issue with a workaround, I would appreciate some help :)

import os

import torch
import triton
import triton.language as tl
from einops import einsum, rearrange


def btt_fwd_ref(x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
    g, _, _, _, j = W1.shape
    y = rearrange(x, "... (g i j) -> ... g i j", g=g, j=j)
    out = einsum(y, W1, W2, "... g i j, g k r i j, g k o r i -> ... g k o")
    return rearrange(out, "... g k o -> ... (g k o)")


@triton.jit
def _btt_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    w1_ptr,
    w2_ptr,
    out_ptr,
    # Matrix dimensions
    B,
    G,
    K,
    I,
    J,
    O,
    # Strides
    # x
    stride_xb,
    stride_xg,
    stride_xi,
    stride_xj,
    # W1
    stride_w1g,
    stride_w1k,
    stride_w1r,
    stride_w1i,
    stride_w1j,
    # W2
    stride_w2g,
    stride_w2k,
    stride_w2o,
    stride_w2r,
    stride_w2i,
    # out
    stride_ob,
    stride_og,
    stride_ok,
    stride_oo,
    # Meta-parameters
    RANK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_J: tl.constexpr,
):
    # Program ID
    pid1 = tl.program_id(0)
    pid2 = tl.program_id(1)

    # Compute batch and group indices
    bb = pid1 // G * BLOCK_B
    g = pid1 % G

    o_blocks = tl.cdiv(O, BLOCK_K)
    bk = pid2 // o_blocks * BLOCK_K
    bo = pid2 % o_blocks * BLOCK_K
    acc_out = tl.zeros((BLOCK_K, BLOCK_B, BLOCK_K), dtype=tl.float32)  # g: k b o

    for r in tl.range(0, RANK):
        indsb = tl.arange(0, BLOCK_B)
        indsj = tl.arange(0, BLOCK_J)
        indsk = tl.arange(0, BLOCK_K)
        for bi in tl.range(0, I, BLOCK_J):
            # First matmul: x × W1
            acc_inner = tl.zeros(
                (BLOCK_J, BLOCK_B, BLOCK_K), dtype=tl.float32
            )  # g r: i b k
            for bj in tl.range(0, J, BLOCK_J):
                # Load block of x
                x_block = tl.load(
                    x_ptr
                    + g * stride_xg
                    + (bb + indsb[:, None, None]) * stride_xb
                    + (bi + indsj[None, :, None]) * stride_xi
                    + (bj + indsj[None, None, :]) * stride_xj,
                    mask=(
                        bb + indsb[:, None, None] < B and bi + indsj[None, :, None] < I
                    )
                    and bj + indsj[None, None, :] < J,
                    other=0,
                )  # g: i b j
                x_block = tl.trans(x_block, 1, 0, 2)  # g: b i j -> i b j

                # Load block of W1 (k j)
                w1_block = tl.load(
                    w1_ptr
                    + g * stride_w1g
                    + r * stride_w1r
                    + (bk + indsk[:, None, None]) * stride_w1k
                    + (bi + indsj[None, :, None]) * stride_w1i
                    + (bj + indsj[None, None, :]) * stride_w1j,
                    mask=(
                        bk + indsk[:, None, None] < K and bi + indsj[None, :, None] < I
                    )
                    and bj + indsj[None, None, :] < J,
                    other=0,
                )  # g r: k i j
                w1_block = tl.trans(w1_block, 1, 2, 0)  # g r: k i j -> i j k

                # Accumulate
                acc_inner += tl.dot(
                    x_block, w1_block
                )  # g r: (i b j) x (i j k) -> i b k

            acc_inner = tl.trans(acc_inner, 2, 1, 0)  # g r: i b k -> k b i

            # Second matmul: acc × W2
            w2_block = tl.load(
                w2_ptr
                + g * stride_w2g
                + r * stride_w2r
                + (bk + indsk[:, None, None]) * stride_w2k
                + (bo + indsk[None, :, None]) * stride_w2o
                + (bi + indsj[None, None, :]) * stride_w2i,
                mask=(bk + indsk[:, None, None] < K and bo + indsk[None, :, None] < O)
                and bi + indsj[None, None, :] < I,
                other=0,
            )  # g r: k o i
            w2_block = tl.trans(w2_block, 0, 2, 1)  # g r: k o i -> k i o

            # Accumulate
            acc_out += tl.dot(acc_inner, w2_block)  # g r: (k b i) x (k i o) -> k b o

    indsb = tl.arange(0, BLOCK_B)
    indsk = tl.arange(0, BLOCK_K)

    # Store result
    tl.store(
        out_ptr
        + g * stride_og
        + (bb + indsb[:, None, None]) * stride_ob
        + (bk + indsk)[None, :, None] * stride_ok
        + (bo + indsk)[None, None, :] * stride_oo,
        tl.trans(acc_out, 1, 0, 2),  # g: k b o -> b k o
        mask=(bb + indsb[:, None, None] < B and bk + indsk[None, :, None] < K)
        and (bo + indsk)[None, None, :] < O,
    )


def _btt_fwd_triton(
    x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor
) -> torch.Tensor:
    G, K, R, I, J = W1.shape
    _, _, O, _, _ = W2.shape
    out = x.new_empty(x.shape[:-1] + (G * K * O,))

    # ... (g i j) -> (...) g i j
    x_view = x.view(-1, G, I, J)
    out_view = out.view(-1, G, K, O)
    B = x.shape[0]

    BLOCK_B, BLOCK_K, BLOCK_J = 16, 16, 16
    grid = (
        triton.cdiv(B, BLOCK_B) * G,
        triton.cdiv(K, BLOCK_K) * triton.cdiv(O, BLOCK_K),
    )
    _btt_fwd_kernel[grid](
        x_view,
        W1,
        W2,
        out_view,
        B,
        G,
        K,
        I,
        J,
        O,
        *x_view.stride(),
        *W1.stride(),
        *W2.stride(),
        *out_view.stride(),
        BLOCK_B=BLOCK_B,
        BLOCK_K=BLOCK_K,
        BLOCK_J=BLOCK_J,
        RANK=R,
    )
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    triton_interpreting = os.environ.get("TRITON_INTERPRET", "0") == "1"
    print("Interpreting:", triton_interpreting, triton.__version__)
    device = "cpu" if triton_interpreting else "cuda"
    G, K, R, I, J, O = 2, 16, 1, 16, 16, 16

    bsize, in_features, out_features = 128, 256, 256
    x, W1, W2 = (
        torch.randn(bsize, G * in_features, device=device),
        torch.randn(G, K, R, I, J, device=device),
        torch.randn(G, K, O, R, I, device=device),
    )

    y = _btt_fwd_triton(x, W1, W2)
    y_ref = btt_fwd_ref(x, W1, W2)
    torch.testing.assert_close(y, y_ref)
    print("Succeeded")

Environment details

Triton: Tested on 3.1.0 and main (d5ba6ac)
GPU: A100 and 4070 Ti

@falkaer falkaer added the bug label Nov 21, 2024
@Jokeren Jokeren self-assigned this Nov 21, 2024
Jokeren commented Nov 21, 2024

It's an interesting case that we'll investigate

@Jokeren Jokeren assigned ThomasRaoux and unassigned Jokeren Nov 21, 2024
lezcano commented Nov 23, 2024

Hot take: we should implement implicit broadcasting of layouts à la NumPy and support ND dots in one go.
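
The NumPy behaviour presumably being referenced is that of numpy.matmul, where leading batch dimensions broadcast implicitly and only the trailing two dimensions are contracted. A small sketch of that model (plain NumPy, not Triton):

import numpy as np

a = np.random.randn(2, 16, 16)   # (batch, M, K)
b = np.random.randn(16, 16)      # (K, N): no batch dim, broadcast across batches
c = np.matmul(a, b)              # leading dims broadcast -> result shape (2, 16, 16)
assert c.shape == (2, 16, 16)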

@lezcano lezcano moved this from In Progress to Done in Linear Layout Dec 2, 2024
Edenzzzz commented Dec 7, 2024

I also ran into the same issue in my code, where a tl.store following a 3D matmul throws this assert; even the Triton nightly version with the new fix (triton_nightly-3.0.0.post20240716052845) doesn't work.
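
The pattern described above would look roughly like the following hedged sketch (hypothetical kernel name and shapes; essentially a single 3D tl.dot whose result is written back with tl.store):

import triton
import triton.language as tl


@triton.jit
def _bmm_store_kernel(a_ptr, b_ptr, out_ptr,
                      B: tl.constexpr, M: tl.constexpr,
                      K: tl.constexpr, N: tl.constexpr):
    # Contiguous 3D offsets; whole tensors are loaded as single blocks.
    ob = tl.arange(0, B)[:, None, None]
    a = tl.load(a_ptr + ob * M * K
                + tl.arange(0, M)[None, :, None] * K
                + tl.arange(0, K)[None, None, :])  # (B, M, K)
    b = tl.load(b_ptr + ob * K * N
                + tl.arange(0, K)[None, :, None] * N
                + tl.arange(0, N)[None, None, :])  # (B, K, N)
    c = tl.dot(a, b)                               # 3D (batched) matmul
    tl.store(out_ptr + ob * M * N                  # tl.store directly after the 3D dot
             + tl.arange(0, M)[None, :, None] * N
             + tl.arange(0, N)[None, None, :], c)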

Edenzzzz commented Dec 7, 2024

Tried to install the fix from the latest main but failed to build from source 😂
#4272 (comment)
