Adds changes to address review comments
hmalgewatta committed Oct 7, 2024
1 parent f57f76b commit 6a23263
Showing 6 changed files with 115 additions and 68 deletions.
70 changes: 70 additions & 0 deletions test/Conversion/amd/invalid_viewslice_to_llvm.mlir
@@ -0,0 +1,70 @@
// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics

// Invalid size
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTA [256, 16]}}
%1 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
tt.return
}

// -----

// Invalid offset
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}}
%1 = amdgpu.view_slice %arg0[0,5] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
tt.return
}

// -----

// Invalid result layout
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{result layout must match source layout}}
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2>
tt.return
}

// -----

// Invalid result element type
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{result element type must match source element type}}
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1>
tt.return
}

// -----

// Invalid result rank
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{result rank must be equal to source rank}}
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1>
tt.return
}

// -----

// Invalid rank
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{currently only 2D tensors are supported}}
%1 = amdgpu.view_slice %arg0[0,0,0] [256,16,2] [1,1,1] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1>
tt.return
}

// -----

// Invalid stride
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{expected unit strides but found unsupported stride [1, 2]}}
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,2] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
tt.return
}
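
For quick reference, the scalar checks these negative tests exercise can be mirrored in a few lines of Python. This is a hedged sketch, not part of the commit: `check_view_slice` and its arguments are hypothetical, and the layout, element-type, and rank comparisons (which operate on MLIR types directly) are omitted.

```python
from typing import Sequence


def check_view_slice(src_shape: Sequence[int], shape_per_cta: Sequence[int],
                     offsets: Sequence[int], sizes: Sequence[int],
                     strides: Sequence[int]) -> None:
    """Raise ValueError if the slice parameters violate the rules above."""
    # Rule: currently only 2D tensors are supported.
    if len(src_shape) != 2:
        raise ValueError("currently only 2D tensors are supported")
    # Rule: sizes must be a multiple of shapePerCTA.
    if any(s % c != 0 for s, c in zip(sizes, shape_per_cta)):
        raise ValueError(f"sizes {list(sizes)} must be a multiple of "
                         f"shapePerCTA {list(shape_per_cta)}")
    # Rule: offsets must be a multiple of shapePerCTA.
    if any(o % c != 0 for o, c in zip(offsets, shape_per_cta)):
        raise ValueError(f"offset {list(offsets)} must be a multiple of "
                         f"shapePerCTA {list(shape_per_cta)}")
    # Rule: only unit strides are accepted.
    if any(s != 1 for s in strides):
        raise ValueError(f"expected unit strides but found unsupported "
                         f"stride {list(strides)}")


# Mirrors the first negative test: sizes [256, 2] would fail, [256, 16] passes.
check_view_slice([256, 128], [256, 16], [0, 0], [256, 16], [1, 1])
```
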
16 changes: 3 additions & 13 deletions test/TritonGPU/amd/amd-viewslice-op.mlir
@@ -1,24 +1,14 @@
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s

#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// CHECK: llvm.func @basic_insert_slice_async_1d
tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// CHECK: llvm.func @basic_insert_slice
// CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
%72 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
tt.return
}
}

module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// CHECK: llvm.func @basic_insert_slice_async_1d
// CHECK: error: sizes [256, 2] must be a multiple of shapePerCTA [256, 16]
// XFAIL: *
%72 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
tt.return
}
}
47 changes: 32 additions & 15 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
@@ -45,42 +45,59 @@ def TritonAMDGPU_ViewSliceOp
OffsetSizeAndStrideOpInterface, Pure]> {
let summary = "view slice operation";
let description = [{
The "view_slice" operation enables "viewing" a slice of a tensor in
The "view_slice" operation enables viewing a slice of a tensor in
registers without data exchange.

The "view_slice" operation supports the following arguments:

* source: the base tensor on which to create a "view" tensor
* offsets: offsets into the base tensor at which to create the "view"
* source: the base tensor on which to create a view tensor
* offsets: offsets into the base tensor at which to create the view
* size: size of the result "view" tensor
* strides: the stride for each dimension

Currently only 2D tensors are supported.

Example 1:

```mlir
%1 = triton_gpu.convert_layout %0 : tensor<128x128x!tt.ptr<f16>, #blocked>
-> tensor<128x128x!tt.ptr<f16>, #blocked2>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8],
threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8],
threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
%1 = triton_gpu.convert_layout %0 : tensor<128x128xf16, #blocked>
-> tensor<128x128xf16, #blocked1>
// create a slice of base tensor %1 with
// static offsets and sizes for each dimension
%2 = amdgpu.view_slice %0[0, 0] [128, 8] [1, 1] :
tensor<128x128x!tt.ptr<f16>, #blocked2> to
tensor<128x8x!tt.ptr<f16>, #blocked2>
%2 = amdgpu.view_slice %1[0, 0] [128, 32] [1, 1] :
tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1>
```

Example 1 shows how the "view_slice" operation may be used. In this
example a new 128x32 view is created. "view_slice" requires the desired
slice to have the same layout as the source tensor. "%0" cannot be
sliced directly because the resulting slice cannot keep the same layout
as "%0", so it must first be converted to a layout suitable for slicing.
The "#blocked1" layout is appropriate because it keeps sizePerThread the
same, preserving the coalescing properties. To utilize all threads in a
warp, "threadsPerWarp" is set to [16, 4] in the new layout. Performing
this layout conversion before "view_slice" ensures that slicing still
uses all threads efficiently.
}];

let arguments = (ins AnyRankedTensor:$source, Variadic<I32>:$offsets,
Variadic<I32>:$sizes, Variadic<I32>:$strides,
DenseI64ArrayAttr:$static_offsets, DenseI64ArrayAttr:$static_sizes,
let arguments = (ins AnyRankedTensor:$source,
Variadic<I32>:$offsets,
Variadic<I32>:$sizes,
Variadic<I32>:$strides,
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides);
let results = (outs AnyRankedTensor:$result);

let builders = [
// Build a ViewSliceOp with mixed static and dynamic entries and the same
// result type
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
OpBuilder<(ins "RankedTensorType":$resultType,
"Value":$source,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
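As a sanity check of the slicing semantics described in Example 1 above, here is a minimal NumPy sketch of the logical result only. It is an illustration, not part of the commit, and it deliberately ignores layouts and register distribution.

```python
import numpy as np

# Logical result of Example 1: a 128x32 slice at offset [0, 0] with unit
# strides, taken from a 128x128 source (element type is irrelevant here).
src = np.arange(128 * 128, dtype=np.float32).reshape(128, 128)

offsets, sizes = (0, 0), (128, 32)
view = src[offsets[0]:offsets[0] + sizes[0],
           offsets[1]:offsets[1] + sizes[1]]

assert view.shape == (128, 32)
# With unit strides this is a pure view: no element is copied or exchanged.
assert np.shares_memory(view, src)
```
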
@@ -1,5 +1,5 @@
#ifndef AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H
#define AMD_INCLUDE_TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H
#ifndef TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H
#define TRITONAMDGPU_TO_LLVM_PATTERNS_AMDGPU_OP_TO_LLVM_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"

2 changes: 1 addition & 1 deletion third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -97,7 +97,7 @@ LogicalResult ViewSliceOp::verify() {

if (!hasUnitStride()) {
return emitError("expected unit strides but found unsupported stride [")
<< getStrides() << "]";
<< getStaticStrides() << "]";
}

return success();
44 changes: 7 additions & 37 deletions third_party/amd/python/test/test_core.py
@@ -1,50 +1,20 @@
# flake8: noqa: F821,F841
import contextlib
import itertools
import re
from typing import Optional
import math
import textwrap
import tempfile

import numpy as np
import pytest
import torch
import os
import inspect
from numpy.random import RandomState

import triton
import triton.language as tl
from triton.language.extra import libdevice

from triton._internal_testing import (
is_interpreter,
is_hip,
get_arch,
torch_float8_dtypes,
torch_dtypes,
)
from triton._internal_testing import is_hip


@contextlib.contextmanager
def promotion_numpy_2_0():
state = np._get_promotion_state()
np._set_promotion_state("weak")
try:
yield
finally:
np._set_promotion_state(state)


# TODO: enable multiple cta cluster testing.
# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1]
num_ctas_list = [1]

GPU_DIALECT = "triton_gpu"
if is_interpreter():
THREADS_PER_WARP = 1
elif is_hip():

if is_hip():
THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
else:
THREADS_PER_WARP = 32
@@ -71,8 +41,8 @@ def __str__(self):

view_layout = [
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
]
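
A quick hedged check on the layouts above (the constant `WARP_SIZE` and this standalone snippet are not part of the test file): each `threadsPerWarp` entry should multiply out to the HIP wavefront size of 64 that the module attributes below assume.

```python
import math

# threadsPerWarp entries of the view_layout table above.
threads_per_warp = [[16, 4], [64, 1], [16, 4], [16, 4], [16, 4]]

WARP_SIZE = 64  # HIP wavefront size assumed by these tests
for tpw in threads_per_warp:
    assert math.prod(tpw) == WARP_SIZE, tpw
```
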
@@ -92,14 +62,14 @@ def __str__(self):
@pytest.mark.parametrize("blocked_layout", blocked_layout)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout,
device='cuda'):
if torch.version.hip is None:
if not is_hip():
pytest.skip("view_slice is an AMD-specific instruction.")

ir = f"""
#blocked = {blocked_layout}
#view_layout = {view_layout}
module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {str(64)} : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
tt.func public @kernel(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
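
The rest of the test body is collapsed above. A hedged sketch of how the sliced output could be compared against a plain PyTorch reference follows; the parameter values and the name `result` are hypothetical, not taken from the collapsed code.

```python
import torch

# Hypothetical parameter values for one parametrization of the test above.
M, N = 128, 128
M_tile_size, N_tile_size = 64, 32
M_tile_offset, N_tile_offset = 0, 32

x = torch.randn(M, N)
# Reference slice computed directly on the host tensor.
expected = x[M_tile_offset:M_tile_offset + M_tile_size,
             N_tile_offset:N_tile_offset + N_tile_size]
assert expected.shape == (M_tile_size, N_tile_size)

# `result` would be the tensor the compiled kernel writes through %arg1;
# the final check would then be, e.g., torch.testing.assert_close(result, expected).
```
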
