diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..6381b59d31fc --- /dev/null +++ b/BUILD @@ -0,0 +1,908 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python frontend, add your project to + # # google3/third_party/py/triton/BUILD instead. + # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # "//platforms/xla/experimental/gpu:__subpackages__", # csigg@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove.
+_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. 
+ strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + 
"include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + # "lib/Analysis/AxisInfo.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. 
+ # "include/triton/Analysis/AxisInfo.h", + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + ], + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + ]) + [ + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Analysis/AxisInfo.cpp", # Avoid circular dependency. + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + ]) + [ + "include/triton/Analysis/AxisInfo.h", # Avoid circular dependency. + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + ":triton_type_interfaces_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = 
glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + 
":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUDialectToLLVM", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 36344442bd3a..547d6a6cd659 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -df0864e761107b07e38f5503e0cbee0cebb4c5e8 +29b92d07746fac26cd64c914bc9c5c3833974f6d diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 3912191f4f3e..f91ceb3bbae3 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false. return get(context, vec, perPhase, maxPhase, order, CTALayout); } - // ---- begin Ampere ---- - if (mmaEnc.isAmpere()) { + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); perPhase = std::max(perPhase, 1); std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; @@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false. llvm_unreachable("invalid operand index"); } - // ---- begin version 3 ---- - if (mmaEnc.isHopper()) { - llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" - " is Hopper has not been implemented yet"); - return $_get(context, 1, 1, 1, order, CTALayout, true); - } - // ---- not implemented ---- llvm_unreachable("unsupported swizzling for provided MMA version"); }]>, @@ -1222,7 +1215,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2Rep(ArrayRef shape, + SmallVector getMMAv2OrV3Rep(ArrayRef shape, int bitwidth, int opIdx) const; bool supportReduction() const { @@ -1317,6 +1310,10 @@ The parent field is the layout of d. kWidth defines number of consecutive elements stored by one thread along k dimension. 
Some layouts do not use this parameter, either because they have a fixed number of elements along the K dim, or they use all elements of the tensor along the K dim. + +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. }]; let parameters = ( @@ -1327,16 +1324,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim ); let builders = [ - // Specially for MMAV1(Volta) AttrBuilder<(ins "unsigned":$opIdx, "Attribute":$parent, "Type":$eltTy), [{ NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); - if (!parentAttr || !parentAttr.isAmpere()) - return $_get(context, opIdx, parent, 0); + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) + return $_get(context, opIdx, parent, 0); // For MMAV1 + // For MMAV2 and V3 unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - unsigned MMAv2kWidth = 32 / bitwidth; - return $_get(context, opIdx, parent, MMAv2kWidth); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); }]> ]; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 787dee35fb25..ae1706a597ac 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -11,6 +11,24 @@ using namespace mlir::triton::gpu; namespace mlir::triton::gpu { +namespace { + +bool isDotOpTensorAndPacked(Type srcTy) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return false; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!encoding) + return false; + auto parentEnc = dyn_cast(encoding.getParent()); + // By code convention, values for Hopper's dotOp-encoded tensors are not packed + if (!parentEnc || parentEnc.isHopper()) + return false; + return true; +} + +} // namespace + Type getElementType(Value value) { auto type = value.getType(); if (auto tensorType = dyn_cast(type)) @@ -33,14 +51,15 @@ SmallVector reorderValues(const SmallVector &values, Type inType, // If the parent of the dot operand is in block encoding, we don't need to // reorder elements auto parentEncoding = dyn_cast(ouEncoding.getParent()); - if (!parentEncoding) + if (!parentEncoding || parentEncoding.isHopper()) return values; size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); auto ouEltTy = ouTensorTy.getElementType(); if (inBitWidth == ouBitWidth) return values; - if (inBitWidth == 16 && ouBitWidth == 32) { + if ((inBitWidth == 16 && ouBitWidth == 32) || + (inBitWidth == 32 && ouBitWidth == 16)) { SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); @@ -82,12 +101,10 @@ SmallVector reorderValues(const SmallVector &values, Type inType, SmallVector unpackI32(const SmallVector &inValues, Type srcTy, ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) + if (!isDotOpTensorAndPacked(srcTy)) return inValues; + auto tensorTy = cast(srcTy); + SmallVector outValues; for (auto v : inValues) { // cast i32 to appropriate eltType vector and extract elements @@ -104,12 +121,10 @@ SmallVector unpackI32(const SmallVector &inValues, Type srcTy, 
SmallVector packI32(const SmallVector &inValues, Type srcTy, ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) + if (!isDotOpTensorAndPacked(srcTy)) return inValues; + auto tensorTy = cast(srcTy); + SmallVector outValues; auto eltType = typeConverter->convertType(tensorTy.getElementType()); int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 48f31bdf2a9d..244da1770ddd 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1022,13 +1022,18 @@ LogicalResult DotOperandEncodingAttr::verify( return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; } if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !parentAttr.isAmpere()) + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter can only be " - "non-zero for Ampere MMA parent"; - if (kWidth == 0 && parentAttr.isAmpere()) + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter is mandatory for " - "Ampere MMA parent"; + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "triton_gpu.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; return success(); } @@ -1957,9 +1962,10 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, +SmallVector NvidiaMmaEncodingAttr::getMMAv2OrV3Rep(ArrayRef shape, int bitwidth, int opIdx) const { + assert(isAmpere() || isHopper()); auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; @@ -1967,7 +1973,6 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, rank == 3 ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) : 1; - assert(isAmpere()); if (opIdx == 0) return {numRepBatch, @@ -1982,6 +1987,7 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, warpsPerCTA[rank - 1]))}; } } + unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -1989,11 +1995,17 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( int warpsPerCTAN = getWarpsPerCTA()[1]; // H100 if (isHopper()) { - return getTotalElemsPerThread(shape, eltTy); + assert(opIdx == 0); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * warpsPerCTAM); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + // For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds + // kWidth elements for each quadrant. WGMMA is repeated repM * repK times. 
+ return 4 * kWidth * repM * repK; } // A100 if (isAmpere()) { - auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + auto rep = getMMAv2OrV3Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); if (opIdx == 0) return 4 * rep[0] * rep[1] * rep[2]; if (opIdx == 1) @@ -2720,6 +2732,11 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding, so we want to keep this layout conversion. + if (mlir::isa( + convert.getSrc().getType().getEncoding())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); return mlir::success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index d9bbd51bd9a1..7776a93305ff 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); + + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding. + if (auto dotOpEnc = mlir::dyn_cast( + argType.getEncoding())) { + // Create a layout conversion from DotOperandEncoding to BlockedEncoding + // then pass it to the LocalAllocOp. + auto newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), dotOpEnc.getParent()); + auto dotOperandToBlockedCvt = + rewriter.create(arg.getLoc(), newArgType, arg); + return rewriter.create(arg.getLoc(), newType, + dotOperandToBlockedCvt); + } + return rewriter.create(arg.getLoc(), newType, arg); } @@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern { mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { + // Dot operand layout assignment to predicates is not currently supported + // during lowering from TritonGPU to LLVM in Triton for MMA cases. This + // condition limits visibility of the original bit-width so that predicates + // are not considered; hence, kWidth can never be 32. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + return false; + } return op->getNumOperands() == 1 && (isa(op) || isPureUnaryInlineAsm(op) || diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..c55aea243f8b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -1,9 +1,11 @@ +#include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -15,6 +17,120 @@ namespace gpu { namespace { +// Helpers + +// Returns whether we can hoist DotOp Encoding through `op`. +// Roughly, whether op is elementwise and thus threads don't need +// to exchange elements.
But some ops are not currently supported even though +// they meet that criterion. +bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(op) && !isPureUnaryInlineAsm(op) && + !isa(op->getDialect())) + return false; + + // Quick handling to fix loading issues when computing the original + // bitwidth is unable to realize that there is a mixed-precision dot + // (hence kWidth = 1) but wants to hoist through the type conversion. + if (isa(op) && dotOpEnc.getKWidth() == 1) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op)) { + Type opType = getElementTypeOrSelf(op->getOperand(0)); + if (opType.isInteger(1)) + return false; + } + + return true; +} + +// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A +// is in registers). +bool canHoistDotOpEncV3(Operation* op) { + // Must have exactly one result and at least one operand + if (op->getNumOperands() == 0 || op->getNumResults() != 1) + return false; + + auto isBlockedOrDotOpRankedTensor = [](Type ty) { + auto tensorTy = dyn_cast(ty); + if (!tensorTy) + return false; + return isa(tensorTy.getEncoding()); + }; + + // Operands and results must be of RankedTensorType and Blocked or DotOp + if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) && + all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor))) + return false; + + // Only consider custom conversions or arith ops. + if (!isa(op) && !isPureUnaryInlineAsm(op) && + !isa(op->getDialect())) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Downcasting not currently supported; it will likely require minor + // adjustments in sharedToDotOperandMMv2 + auto oprType = getElementTypeOrSelf(op->getOperand(0)); + auto resType = getElementTypeOrSelf(op->getResult(0)); + if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth()) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op) && oprType.isInteger(1)) + return false; + + return true; +} + +// Helper to perform a "deep" clone of the given slice (i.e., set of ops), +// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice, +// and sliceMap the IRMapping that maps the ops and result values of the +// original slice to those in the cloned slice. 
+auto cloneSlice(PatternRewriter& rewriter, const SetVector& slice) { + IRMapping sliceMap; + SetVector newSlice; + + // First pass: clone ops; the result values are cloned as well, but the operands still + // refer to the original result values + for (Operation *op : slice) { + auto newOp = rewriter.clone(*op); + newSlice.insert(newOp); + sliceMap.map(op, newOp); + for (auto [result, newResult] : llvm::zip(op->getResults(), newOp->getResults())) { + assert(result != newResult); + sliceMap.map(result, newResult); + } + } + + // Second pass: replace operand references in cloned ops to point to cloned values + for (auto [op, newOp] : sliceMap.getOperationMap()) + for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) + continue; + + newOp->setOperand(oprIdx, sliceMap.lookup(operand)); + } + + return std::make_tuple(newSlice, sliceMap); +} + // Given // convert(trans(src)) #dot_operand -> // convert(local_load(trans(alloc(src)))) @@ -111,7 +227,8 @@ class HoistLayoutConversion : public OpRewritePattern { PatternRewriter &rewriter) const override { // Only consider conversions to dot operand. auto cvtTy = cast(cvt.getType()); - if (!isa(cvtTy.getEncoding())) + auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); + if (!dotOpEnc) return failure(); auto src = cvt.getSrc().getDefiningOp(); @@ -126,16 +243,7 @@ class HoistLayoutConversion : public OpRewritePattern { [](Type ty) { return isa(ty); })) return failure(); - // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? - if (!isa(src) && !isPureUnaryInlineAsm(src) && - src->getDialect()->getTypeID() != TypeID::get()) - return failure(); - - // Currently, these instructions are not supported during lowering of - // shared -> dot_operand layout. Not all types and type conversions are - // supported. - if (isa(src)) + if (!canHoistDotOpEncV2(src, dotOpEnc)) return failure(); // Check that the conversion is transitively dependent on a load, and all @@ -165,12 +273,7 @@ class HoistLayoutConversion : public OpRewritePattern { if (isa(currOp)) { foundLoad = true; } else if (foundLoad) { - // Bail out if there exists an op after Load that is not FpToFp, - // Bitcast, or Arith. - if (!isa(currOp) && - !isPureUnaryInlineAsm(currOp) && - currOp->getDialect()->getTypeID() != - TypeID::get()) + if (!canHoistDotOpEncV2(currOp, dotOpEnc)) return failure(); } } @@ -286,8 +389,9 @@ struct MMAV3UseRegOperand dstEnc.getVersionMajor() != 3) return failure(); auto srcTy = cast(alloc.getSrc().getType()); + auto kWidth = 32 / srcTy.getElementTypeBitWidth(); auto dotOperandEnc = DotOperandEncodingAttr::get( - dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + dotOp.getContext(), /*opIdx=*/0, srcEnc, kWidth); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); if (!isMmaToDotShortcut(srcTy, newTy)) @@ -300,6 +404,145 @@ struct MMAV3UseRegOperand } }; +// MMAV3's analog of HoistLayoutConversion, for operand A only; will make WarpGroupDot +// accept operand A in registers instead of shmem. +// +// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot +// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot +// +// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a time and +// requires multiple passes, this pattern will directly hoist the convert to the right +// place in one pass. 
+// +// Or, to be more precise, this pattern deletes the local_alloc op and inserts a +// convert_layout op after each load that warp_group_dot uses; so this is not simply hoisting +// a convert_layout op up as in V2, but can be considered as first changing local_alloc to +// convert_layout and then hoisting, which results in WGMMA now accepting operand A in DotOp +// layout rather than Shared. +struct MMAV3HoistLayoutConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, + PatternRewriter &rewriter) const override { + // Can only hoist operand 0 + auto alloc = dotOp.getOperand(0).getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return rewriter.notifyMatchFailure(dotOp, + "operand A must be produced by local_alloc"); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + if (!isa(getEncoding(dotOp.getOperand(0)))) + return rewriter.notifyMatchFailure(dotOp, + "requires Shared encoding for operand A"); + + // Step 1: Performs checks for early stop + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + if (!srcEnc) + return rewriter.notifyMatchFailure(alloc, + "requires src to have Blocked encoding"); + + auto dstEnc = dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 3) + return rewriter.notifyMatchFailure(dotOp, + "requires result in NvidiaMma encoding"); + + // Step 2: Obtain slice of ops between load/constant and local_alloc + SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&](Operation *op) { + // Stop before Load, ConstantOp, or LocalLoad + return (op->getParentRegion() == alloc->getParentRegion()) + && !isa(op) + && (op->getNumOperands() != 0); + }; + getBackwardSlice(alloc.getOperation(), &slice, opt); + + // Step 3: Verify slice can be hoisted through + if (slice.empty()) + return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through"); + + // We define frontierOp as an op outside this slice whose result is used by an op in + // this slice. We must eventually convert the result of all frontierOps to + // DotOperandEncoding. This is done via the insertion of ConvertLayout after each + // frontierOp. + // We currently support frontierOp to be load or constant. + for (Operation *currOp : slice) { + if (!canHoistDotOpEncV3(currOp)) + return rewriter.notifyMatchFailure(currOp, "cannot hoist through"); + + // We previously ensured that all ops in slice have at least one operand + for (auto operand : currOp->getOperands()) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) { + // ensure frontierOp is load or constant + if (!isa(defOp)) + return rewriter.notifyMatchFailure(defOp, "must be load or constant"); + } + } + } + + // Step 4: Clone slice + auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); + + // Step 5: Modify the cloned slice to have dotOp encoding. + // Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot + // After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot + // + // Specifically, this step will change all value types from #blocked to #dot_op + // encoding in the cloned slice, and for those values produced by frontierOps (i.e., + // outside the slice), we will insert convert_layout's after the frontierOp. 
+ auto srcTy = cast(alloc.getSrc().getType()); + Type inputEltTy = srcTy.getElementType(); + auto dotOperandEnc = DotOperandEncodingAttr::get( + dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy); + + for (auto op : newSlice) { + // Step 5a: If any operand is defined by a frontierOp, we must insert a + // convert_layout(#dot_op) after the frontierOp and before currOp + for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) { + + auto defOp = operand.getDefiningOp(); + + // defOp is not frontier (i.e. it's within slice); no need to convert the + // layout of its result + if (newSlice.contains(defOp)) + continue; + + // We checked earlier that all operands are ranked tensors + auto operandTy = cast(operand.getType()); + auto operandEltTy = operandTy.getElementType(); + + Type cvtTy = RankedTensorType::get( + operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); + rewriter.setInsertionPoint(op); + auto cvt = rewriter.create(defOp->getLoc(), cvtTy, operand); + + op->setOperand(oprIdx, cvt); + } + + // Step 5b: Change the result to have DotOp rather than Blocked encoding + auto resTy = dyn_cast(op->getResult(0).getType()); + op->getResult(0).setType(RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), dotOperandEnc)); + } + + // Step 6: replace LHS operand with alloc's parent in the cloned slice + // This changes the warpGroupDot to accept a DotOp tensor as operand A instead of + // a Shared memdesc. + auto newDotOperand = sliceMap.lookup(alloc.getSrc()); + rewriter.modifyOpInPlace(dotOp, [&]() { + dotOp.setOperand(0, newDotOperand); + }); + + return success(); + } +}; + } // namespace #define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS @@ -321,9 +564,11 @@ class TritonGPUOptimizeDotOperandsPass auto ret = pm.run(m); mlir::RewritePatternSet patterns(context); + patterns.add(context); patterns.add(context); - if (this->hoistLayoutConversion.getValue()) + if (this->hoistLayoutConversion.getValue()) { patterns.add(context); + } patterns.add(context); patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index dc5f395c6753..e920de798289 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -380,6 +380,14 @@ static bool loadIsMMAv3(Operation *loadOp) { if (!sharedEnc.getHasLeadingOffset()) return false; + // In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary? + auto op = *alloc->getUsers().begin(); + if (auto localLoad = dyn_cast(op)) { + auto resTy = cast(localLoad->getResultTypes()[0]); + if (!resTy || isa(resTy.getEncoding())) + return false; + } + // MMA V3 case. auto newOrder = sharedEnc.getOrder(); auto ty = cast(loadOp->getResultTypes()[0]); diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 2cbc00142b42..db71b3b82061 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, type.getMemorySpace()), v, offsetsVal); + // We need to assign kwidth to zero in the case where the parent layout is + // Blocked, otherwise the verifier emits a failure. The parent layout is + // Blocked only when Tensor Cores are disabled. 
+ int kwidth = dyn_cast(dotEncoding) + ? 0 + : prefetchWidth / 8; auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + builder.getContext(), opIdx, dotEncoding, kwidth); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -190,6 +196,15 @@ LogicalResult Prefetcher::initialize() { break; if (!op->getResult(0).hasOneUse()) break; + // Similar to issues faced in HoistLayoutConversion pattern in + // OptimizeDotOperands.cpp, we can't propagate through type casts from + // predicates as they aren't supported in Triton when encoded with dot_op + // layout. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + break; + } rets.push_back(op->getOperand(0)); if (auto cvt = dyn_cast(op)) { foundConvertFromShared = true; diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..334dd4aec41a --- /dev/null +++ b/python/BUILD @@ -0,0 +1,77 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 000000000000..a88f4eeae1f8 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,26 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = 
glob( + include = ["test_*.py"], + exclude = [ + "test_performance.py", #TODO(b/321005767): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py new file mode 100644 index 000000000000..7a02d322b49f --- /dev/null +++ b/python/test/regression/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='cuda') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 000000000000..f75527bab1f7 --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,180 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test") + +package( + default_applicable_licenses = ["//:license"], +) + +_requires_gpu_sm80 = [ + "config-cuda-only", + "requires-gpu-sm80", +] + +_requires_config_cuda = select( + {"@local_config_cuda//cuda:using_clang_allow_exec": []}, + no_match_error = "Requires --config=cuda", +) + +EXCLUDE_TESTS = [ + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "runtime/test_launch.py", # TODO(b/320226169): fix failing tests + "tools/test_aot.py", # TODO(b/320224484): fix failing test + "tools/test_disasm.py", # TODO(b/320224484): fix failing test + "hopper/test_persistent_warp_specialized_gemm.py", # TODO (b/342348738): fix failing test + "runtime/test_cublas.py", # TODO(b/346755023): fix failing test +] + +# Runs all python tests on H100 +pytest_multi_tests( + name = "hopper", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + name_suffix = "_h100", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["**/test_*.py"], + exclude = EXCLUDE_TESTS + [ + "language/test_core.py", + "language/test_pipeliner.py", # TODO(b/362458006): fix failing test + "hopper/test_experimental_tma.py", # TODO(b/362458006): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "hopper/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. 
+pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 039f7ac1ac4f..3d1cbc5a82f0 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2139,6 +2139,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] >= 9, + reason='Reduction test produces wrong results on H100, b/342347027') @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -3642,6 +3644,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file 
+../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..f9bab523bf6c 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..0379d89208e9 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,63 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:non_prod"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/dot-operands.mlir", # TODO: b/283035396 - broken by cl536931041.patch +# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index d44529966274..12bf0d8f4d43 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, 
kWidth = 2}>> %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } @@ -114,10 +114,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> tt.return } } @@ -193,7 +193,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: prmt.b32 // CHECK: prmt.b32 tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) { - %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 728fd8eadfd9..62a2d469996a 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -143,7 +143,6 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : } } - // ----- // Verify that we use mmav2 when the k dim is too small for mmav3. 
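Editorial note on the kWidth values added throughout these Hopper test cases: they follow one pattern, f16/bf16 dot operands get kWidth = 2 and f8E5M2 operands get kWidth = 4, i.e. each thread appears to hold 32 bits of K-contiguous operand data. A minimal sketch of that arithmetic, under the assumption that kWidth counts the contiguous elements held per thread along K (illustrative only, not code from this patch):

    // Illustration only: assumes kWidth == contiguous elements per thread
    // along K, packed into one 32-bit register slot.
    #include <cassert>

    static int kWidthForElementBits(int elementBits) {
      return 32 / elementBits;
    }

    int main() {
      assert(kWidthForElementBits(16) == 2); // f16/bf16 cases in these tests
      assert(kWidthForElementBits(8) == 4);  // f8E5M2 cases in these tests
      return 0;
    }
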
@@ -159,3 +158,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return %result : tensor<128x128xf32, #blocked> } } + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #triton_gpu.blocked +// CHECK-DAG: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3 +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc + // CHECK: %[[RHS_CVT:.*]] = triton_gpu.convert_layout {{.*}} #triton_gpu.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc %[[RHS_CVT]] + // CHECK: triton_nvidia_gpu.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ecee359cb19a..f015f9651065 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -133,3 +133,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> } } // end module + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #triton_gpu.blocked +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> !tt.memdesc<256x32xf32, #shared1> { + // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = triton_gpu.convert_layout %in : tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = triton_gpu.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !tt.memdesc<256x32xf32, #shared1> + // CHECK: %[[ALLOC:.*]] = triton_gpu.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !tt.memdesc<256x32xf32, #shared1> + } +} // end module + diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 82fc1ddf7b65..ab70a081a6bd 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -164,8 +164,8 @@ tt.func @update_kwidth_slice( #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> @@ -180,8 +180,8 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> @@ -211,3 +211,51 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return %td : tensor<128x128xf32, #mma> } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> + %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> + %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise_chained +// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %cst = arith.constant dense<0.000000e+00> 
: tensor<128x64xf16, #blocked> + %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> + %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> + %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> + %dota = triton_gpu.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> + %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index c8b3c2ef6b0b..26a7c0773b9f 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -2,7 +2,7 @@ // expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}} #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked}> +#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> // ----- @@ -12,19 +12,25 @@ // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> + +// ----- + +// expected-error@+2 {{triton_gpu.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d391be688c23..2c2182154d6a 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -398,8 +398,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to 
tensor<128x16xf16, #mma1> %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -481,7 +481,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 @@ -519,7 +519,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -624,7 +624,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 @@ -685,7 +685,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // This dot can be async even though %prev_dot2 is not used 
directly by an // async dot, because that use follows the synchronous dot above. %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 74fd2e05551b..a7064ea82204 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -113,7 +113,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> - %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
// CHECK: triton_nvidia_gpu.warp_group_dot @@ -121,7 +121,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.warp_group_dot // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield - %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 9fbc540b92a6..f178eb24050a 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -245,3 +245,20 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt } // end module // ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !tt.memdesc<16x512xf32, #shared, mutable>, %arg_rhs: !tt.memdesc<512x32xf32, #shared, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable>) : i32 { + %lhs_ll = triton_gpu.local_load %lhs : !tt.memdesc<16x512xf32, #shared, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = triton_gpu.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = triton_gpu.local_load %rhs : !tt.memdesc<512x32xf32, #shared, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = triton_gpu.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..bbdf7408f85e --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,250 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # 
"//third_party/tensorflow/compiler/xla/service/gpu/fusions/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob([ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ]) + ["include/TritonAMDGPUToLLVM/TargetUtils.h"], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPU", + srcs = glob([ + "lib/Dialect/TritonAMDGPU/**/*.h", + "lib/Dialect/TritonAMDGPU/**/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/TritonAMDGPU/**/*.h", + ]), + includes = [ + "..", + "include", + ], + deps = [ + ":triton_amdgpu_attr_def_inc_gen", + ":triton_amdgpu_dialect_inc_gen", + ":triton_amdgpu_ops_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TensorDialect", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPU", + ":TritonAMDGPUDialectToLLVM", + ":TritonAMDGPUTransforms", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonAMDGPUDialectToLLVM", + srcs = glob([ + "lib/TritonAMDGPUDialectToLLVM/**/*.h", + 
"lib/TritonAMDGPUDialectToLLVM/**/*.cpp", + ]), + includes = [ + "include", + ], + deps = [ + "//:TritonGPUToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_ops_inc_gen", + tbl_outs = [ + ( + [ + "--gen-llvmir-conversions", + ], + "include/Dialect/TritonAMDGPU/IR/OpsConversions.inc", + ), + ( + [ + "--gen-op-decls", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.h.inc", + ), + ( + [ + "--gen-op-defs", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_dialect_inc_gen", + tbl_outs = [ + ( + [ + "--gen-dialect-decls", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.h.inc", + ), + ( + [ + "--gen-dialect-defs", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_attr_def_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 000000000000..93829539e1b9 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..f062b61a9ee6 --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,306 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") 
+load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = 
["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + 
"@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..a5b34aa5c29b --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index bb0d86888120..19c732c354d1 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -154,6 +154,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +#if CUDA_VERSION >= 12000 typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, @@ -161,6 +162,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); +#endif #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ @@ -187,8 +189,10 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); +#if CUDA_VERSION >= 12000 defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, cuTensorMapEncodeTiled); +#endif static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, @@ -281,6 +285,9 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dim; uint32_t tensorDim; @@ -321,11 +328,15 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; +#endif } // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. 
static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dims[2]; uint32_t tensorDims[2]; @@ -384,6 +395,7 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; +#endif } static PyMethodDef ModuleMethods[] = { diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8de0efefca84..637071275e39 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -291,10 +291,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? + 64 : type.getIntOrFloatBitWidth(); + switch (bitwidth) { + case 1: + outputConstraints.push_back("=b"); + break; + case 16: + outputConstraints.push_back("=h"); + break; + case 32: + outputConstraints.push_back("=r"); + break; + case 64: + outputConstraints.push_back("=l"); + break; + default: + assert(false && "unsupported bitwidth"); + } + } + return outputConstraints; } OperandsAndConstraints diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index 197901d8555c..37c3bdc7d45c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,6 +1,6 @@ add_triton_library(TritonNVIDIAGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp ConvertLayoutOpToLLVM.cpp DotOpToLLVM/MMAv1.cpp DotOpToLLVM/MMAv2.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 8fb44ce644ba..2aebcb2cae28 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -34,13 +34,13 @@ Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj, } // namespace SharedToDotOperandMMAv1 -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread); -} +} // namespace SharedToDotOperandMMAv2OrV3 namespace { @@ -88,11 +88,20 @@ struct LocalLoadOpConversion auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( + + if (isOuter) { + assert(false && 
"MMA Layout does not support outer product"); + return res; + } + + if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3 + if (mmaLayout.isHopper()) + assert(dotOperandLayout.getOpIdx() == 0); + + res = SharedToDotOperandMMAv2OrV3::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); - } else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1 + } else if (mmaLayout.isVolta() && isMMA) { // tensor core v1 bool isMMAv1Row = mmaLayout.getMMAv1IsRow(dotOperandLayout.getOpIdx()); auto srcSharedLayout = cast(src.getType().getEncoding()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp similarity index 86% rename from third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp rename to third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index bf033bdd5322..9897e1b17e6e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -25,6 +25,7 @@ class MMA16816SmemLoader { ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, int elemBytes, + int mmaElemBytes, bool isHopper, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc); @@ -67,6 +68,8 @@ class MMA16816SmemLoader { int perPhase; int maxPhase; int elemBytes; + int mmaElemBytes; + bool isHopper; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; @@ -203,10 +206,10 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { // vecWidth // <-------> // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\ -// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | -// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height -// ... | -// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ +// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | +// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height +// ... | +// t28 ... t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ // --------------------------------------------- || -------------------------------------------- // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 // t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 @@ -364,6 +367,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { // base pointers + // ptrs[k][...] 
holds `vec` pointers each for (quadK == k) std::array, 2> ptrs; for (int i = 0; i < vecWidth; i++) ptrs[0][i] = getPtr(ptrIdx + i); @@ -383,11 +387,13 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, i0 = add(i0, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); i1 = add(i1, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); } + // ii[m] holds the offset for (quadM == m) std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) SmallVector> vptrs(4, SmallVector(vecWidth)); + // i iterates the 2x2 quads, m-first for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) { vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); @@ -402,7 +408,9 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, int canonWidth = (8 * elemBytes * inc) / canonBits; Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + // don't pack to 32b for Hopper + int vecSize = isHopper ? 1 : 32 / canonBits; + retElems.fill(undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { int e = em % vecWidth; @@ -421,8 +429,11 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, } if (isActualTrans) std::swap(retElems[1], retElems[2]); - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; + + auto iTy = isHopper ? int_ty(8 * elemBytes * inc) : i32_ty; + + return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), + bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; } } @@ -432,7 +443,8 @@ MMA16816SmemLoader::MMA16816SmemLoader( ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, - int elemBytes, ConversionPatternRewriter &rewriter, + int elemBytes, int mmaElemBytes, bool isHopper, + ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc) : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), @@ -441,17 +453,29 @@ MMA16816SmemLoader::MMA16816SmemLoader( matShape(matShape.begin(), matShape.end()), multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()), perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), + mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) { + // If the current elemType width is different from the MMA elemType width, i.e. + // width-changing casting is done later in DotOp Layout... then, in the case of + // Hopper, the number of bytes held by each thread after loading will no longer + // be 32B. Hence this flag is required to stipulate different logic. + bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); + contiguousMatShape = matShape[order[0]]; stridedMatShape = matShape[order[1]]; stridedSmemOffset = smemStrides[order[1]]; smemBatchOffset = smemStrides[order[2]]; - vecWidth = 4 / elemBytes; + if (isHopperWidthChange) { + vecWidth = 4 / mmaElemBytes; + } else { + vecWidth = 4 / elemBytes; + } // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; nonKOrder = (kOrder == 2) ? 
1 : 2; canUseLdmatrix = elemBytes == 2 || (!needTrans); canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth); + canUseLdmatrix = canUseLdmatrix && !isHopperWidthChange; if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed, @@ -504,21 +528,57 @@ Type getSharedMemTy(Type argType) { llvm::report_fatal_error("mma16816 data type not supported"); } +std::vector unpackInt(const std::vector &inValues, Type elTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + const int inBitWidth = inValues[0].getType().getIntOrFloatBitWidth(); + std::vector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(elTy); + auto vecType = vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < inBitWidth / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + Value composeValuesToDotOperandLayoutStruct( const ValueTable &vals, int batch, int n0, int n1, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, Type elTy, bool isHopper) { std::vector elems; + // Existing convention for the ordering of quad values in llvm.struct + // is m-major for Hopper and k-major for Ampere, even though both Ampere + // and Hopper MMA's expect m-major ordering in PTX. + // + // To unify the ordering conventions would potentially require touching + // `ConvertLayoutOpToLLVM.cpp`, `ElementwiseOpToLLVM.cpp`, `MMAv2.cpp`, + // `WGMMA.cpp`, and possibly others. For now, we are using an if-check + // here to route to the correct ordering. for (int b = 0; b < batch; ++b) for (int m = 0; m < n0; ++m) - for (int k = 0; k < n1; ++k) { - elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); - } + for (int k = 0; k < n1; ++k) + if (isHopper) { + elems.push_back(vals.at({b, 2 * m, 2 * k})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + } else { + elems.push_back(vals.at({b, 2 * m, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + } + assert(!elems.empty()); + if (isHopper) { + elems = unpackInt(elems, elTy, rewriter, loc, typeConverter); + } + Type elemTy = elems[0].getType(); MLIRContext *ctx = elemTy.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( @@ -544,18 +604,20 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, const int maxPhase = sharedLayout.getMaxPhase(); const int vecPhase = sharedLayout.getVec(); const int elemBytes = descTy.getElementTypeBitWidth() / 8; + const int mmaElemBytes = 4 / kWidth; + const bool isHopper = mmaLayout.getVersionMajor() == 3; auto order = sharedLayout.getOrder(); int nPerWarp = std::max(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8); - // (a, b) is the coordinate. 
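+  // Note on the two new loader parameters threaded through below:
+  //   - mmaElemBytes (= 4 / kWidth): the element size, in bytes, that the
+  //     MMA/WGMMA instruction itself expects, and
+  //   - isHopper (versionMajor == 3): selects the Hopper-specific vector
+  //     width and packing logic inside MMA16816SmemLoader.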
auto load = [=, &rewriter, &vals](int batch, int a, int b) { MMA16816SmemLoader loader( nPerWarp, warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides, shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId, - perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); + perPhase, maxPhase, elemBytes, mmaElemBytes, + isHopper, rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(lane, cSwizzleOffset); @@ -573,6 +635,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, auto [ha0, ha1, ha2, ha3] = loader.loadX4( batch, (kOrder == 2) ? a : b /*mat0*/, (kOrder == 2) ? b : a /*mat1*/, ptrs, matTy, getSharedMemTy(eltTy)); + if (!isA) std::swap(ha1, ha2); // the following is incorrect @@ -595,16 +658,21 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, MemDescType descTy, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread, bool isA) { + auto mmaLayout = mlir::cast(encoding.getParent()); + bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); - auto mmaLayout = mlir::cast(encoding.getParent()); + // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should add + // up to 32. We use kWidth to compute the element bitwidth of the input to WGMMA, + // which could be different from `bitwidth` due to later casting. + int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth; ValueTable vals; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; + int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; + int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; auto numRep = - mmaLayout.getMMAv2Rep(shapePerCTA, bitwidth, encoding.getOpIdx()); + mmaLayout.getMMAv2OrV3Rep(shapePerCTA, mmaBitwidth, encoding.getOpIdx()); int kWidth = encoding.getKWidth(); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); @@ -616,7 +684,6 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, delinearize(rewriter, loc, warp, warpsPerCTA, order); Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); int warpsPerTile; - auto rank = shapePerCTA.size(); Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); Value warpN = urem(multiDimWarpId[2], i32_val(shapePerCTA[2] / 8)); if (isA) @@ -652,7 +719,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, // Format the values to LLVM::Struct to passing to mma codegen. 
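+  // For Hopper (the trailing `isHopper` argument below), the values are
+  // additionally unpacked from the packed integers back to the tensor
+  // element type (see unpackInt) and the 2x2 quads are emitted
+  // quadM-fastest; for Ampere they remain packed 32-bit values emitted
+  // quadK-fastest.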
return composeValuesToDotOperandLayoutStruct( - vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter); + vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter, + descTy.getElementType(), /*unpack=*/isHopper); } template @@ -764,7 +832,7 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, return expandedSmemObj; } -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, @@ -785,4 +853,4 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, expandedSmemObj, typeConverter, thread, false); } } -} // namespace SharedToDotOperandMMAv2 +} // namespace SharedToDotOperandMMAv2OrV3 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index af897ef546dd..928b46cbbd90 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -318,10 +318,10 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); auto repA = cast(dotOpA.getParent()) - .getMMAv2Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); + .getMMAv2OrV3Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); auto repB = cast(dotOpB.getParent()) - .getMMAv2Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); + .getMMAv2OrV3Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 1bb55373e046..cfc487c59ecb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -264,6 +264,31 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. +// +// `elements` contains all loaded register values for operand A. +// This consists of operand A for possibly multiple wgmma instructions. +// For each wgmma, each warp in a warp group feeds a single "warp matrix" +// Each warp matrix consists of 2x2 "quads". +// Each thread holds several elements in each quad. Right before a wgmma, +// the sum of bitwidth of +// the elements in each quad should add up to 32. +// +// These values are stored unrolled in `elements`. +// The ordering of dimensions is as follows: +// batch (only 1 batch for Hopper currently) +// matM (m-index of the "warp matrix") +// matK (k-index of the "warp matrix") +// quadK (k-index of the "quad" in the core matrix) +// quadM (m-index of the "quad" in the core matrix) +// vecIdx (index of the element in the quad; this is always along the k-dim) +// +// This ordering is decided when a tensor in DotOpEnc is lowered into llvm. +// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. +// Thus, both lowerings must obey this above ordering for the below code to be correct. +// +// Additionally, note that WGMMA expects quadK ordered before quadM, i.e. the layout +// is quadM-major. 
This is opposite to Ampere's ordering for ldmatrix and dotOp. +// (see SharedToDotOperandMMAv2OrV3.cpp) llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements, diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 1269dcda00aa..3cccc5fb6a1c 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,4 +1,4 @@ -#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "NVGPUToLLVM/NVGPUToLLVMPass.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "cublas_instance.h" diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include index fe4f4a1aa9bd..4400934bdf78 120000 --- a/third_party/proton/proton/_C/include +++ b/third_party/proton/proton/_C/include @@ -1 +1 @@ -../../csrc/include/ \ No newline at end of file +../../csrc/include \ No newline at end of file diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 000000000000..4cbadcfa4655 --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,144 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTestCatchAll", + srcs = glob( + [ + "Dialect/**/*.cpp", + ], + exclude = [ + "Dialect/TritonGPU/DialectTest.cpp", + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = [ + "Dialect/TritonGPU/DialectTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "LinearLayoutConversionsTest", + srcs = [ + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "SwizzleTest", + srcs = [ + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + 
"//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ], +)