-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds changes to address review comments
- Loading branch information
1 parent
f57f76b
commit 6a23263
Showing
6 changed files
with
115 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | ||
|
||
// Invalid size | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTA [256, 16]}} | ||
%1 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// Invalid offset | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}} | ||
%1 = amdgpu.view_slice %arg0[0,5] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// Invalid result layout | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{result layout must match source layout}} | ||
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// Invalid result element type | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{result element type must match source element type}} | ||
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// Invalid result rank | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{result rank must be equal to source rank}} | ||
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// Invalid rank | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{currently only 2D tensors are supported}} | ||
%1 = amdgpu.view_slice %arg0[0,0,0] [256,16,2] [1,1,1] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
// Invalid stride | ||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
tt.func @invalid_stride(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// expected-error @+1 {{expected unit strides but found unsupported stride [1, 2]}} | ||
%1 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,2] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> | ||
tt.return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,14 @@ | ||
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s | ||
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s | ||
|
||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> | ||
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { | ||
tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// CHECK: llvm.func @basic_insert_slice_async_1d | ||
tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// CHECK: llvm.func @basic_insert_slice | ||
// CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> | ||
// CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> | ||
// CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> | ||
%72 = amdgpu.view_slice %arg0[0,0] [256, 16] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> | ||
tt.return | ||
} | ||
} | ||
|
||
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { | ||
tt.func @basic_insert_slice_async_1d(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { | ||
// CHECK: llvm.func @basic_insert_slice_async_1d | ||
// CHECK: error: sizes [256, 2] must be a multiple of shapePerCTA [256, 16] | ||
// XFAIL: * | ||
%72 = amdgpu.view_slice %arg0[0,0] [256, 2] [1,1] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> | ||
tt.return | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 2 additions & 2 deletions
4
third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters