Add support for AMX instructions #5818
This will only be included if LLVM >= 12 is used to build Halide
(Synced to head to fix some irrelevant LLVM build issues)
That is a good excuse for me to set up LLVM 13 locally. I will try to see why it is not building with it.
Recent changes in LLVM trunk deprecated the previous calling convention, so compiling with it now produces a warning or error.
The OSX failure is unrelated (will be fixed by #5841); should be good to land.
You should sync this to master to force the bots to retry.
I'm not sure if the buildbot is still running, since there is a "cancelled" message there.
Please try syncing to master once again; hopefully the buildbots will finally be clean.
Thanks, I will do that; hopefully it will all be green now.
Failures are the unrelated cuda-hang failure that we still haven't diagnosed; OK to land.
src/CodeGen_X86.cpp
Outdated
@@ -190,7 +200,7 @@ const x86Intrinsic intrinsic_defs[] = {

    {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
    {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
    {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
    {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
nit: irrelevant whitespace?
Addressed in b1e1452
if (Halide_LLVM_VERSION VERSION_GREATER_EQUAL 12.0)
    # AMX instructions require LLVM 12 or newer
    list(APPEND RUNTIME_LL x86_amx)
endif ()
Does including this fail at build time or at runtime only?
This fails at build time with the following message:
[1/7] Generating initmod.x86_amx.bc
FAILED: src/runtime/initmod.x86_amx.bc /home/frederik/projects/halide/build-11/src/runtime/initmod.x86_amx.bc
cd /home/frederik/projects/halide/build-11/src/runtime && /usr/lib/llvm-11/bin/llvm-as /home/frederik/projects/halide/src/runtime/x86_amx.ll -o initmod.x86_amx.bc
/usr/lib/llvm-11/bin/llvm-as: /home/frederik/projects/halide/src/runtime/x86_amx.ll:3:18: error: expected type
%2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly
^
Just letting you know we haven't lost track of this and the TensorCore PRs. We had some different priorities and annual leave. I look forward to getting this merged soon.
aac8f78 to f0f9f3e
I don't think the test failures are related to anything in this PR.
LGTM so far -- aside from style nits, I think it would be good to split the new test into correctness and performance tests, as Halide does for virtually all other features.
src/ExtractTileOperations.cpp
Outdated
@@ -0,0 +1,414 @@
#include "ExtractTileOperations.h"

#include "IRMatch.h"  // expr_match
Nit: We don't usually add comments explaining why each header is included.
Speaking of which, it might be a good idea to run IWYU on our codebase...
SGTM
src/ExtractTileOperations.cpp
Outdated

enum class AMXOpType {
    Int8,
    Bf16,
Nit: I assume this is bfloat16? If so, spelling it out (e.g. `Bfloat16`) would be preferable.
case AMXOpType::Bf16:
    return Float(32, 256);
default:
    return Type();
I assume this is a should-never-happen case, so doing something like `internal_error << "Unexpected";` would be appropriate.
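For illustration, here is a minimal standalone sketch of the suggested pattern. It does not use Halide: `ElemType` and `amx_result_type` are hypothetical stand-ins, `std::logic_error` stands in for `internal_error`, and the `Int8` branch is an assumption (only the bfloat16 branch appears in the snippet under review).

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for illustration; these are not Halide's real types.
enum class AMXOpType { Int8, Bfloat16 };

struct ElemType {
    std::string name;
    int lanes;
};

// Maps each op type to the type of its result tile. Any value that is not
// handled explicitly throws instead of silently returning an empty type.
ElemType amx_result_type(AMXOpType op) {
    switch (op) {
    case AMXOpType::Int8:
        return {"int32", 256};  // assumed; not shown in the reviewed snippet
    case AMXOpType::Bfloat16:
        return {"float32", 256};
    }
    // Stand-in for `internal_error << "Unexpected";`.
    throw std::logic_error("Unexpected AMXOpType");
}
```

Failing loudly here means a newly added enum value that nobody handled shows up as an immediate error rather than as a default-constructed type further down the pipeline.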
src/ExtractTileOperations.cpp
Outdated
const auto wild_i32 = Variable::make(Int(32), "*");
const auto wild_i32x = Variable::make(Int(32, 0), "*");

Tile<2> is_2d_tile_index(const Expr &e) {
Nit: I'd expect a function named "is_whatever" to return bool, but this returns a struct. Something like `get_2d_tile_index` would be better.
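A small standalone sketch of the naming convention being suggested. `Tile2D` and its fields are hypothetical, and the matching logic is a placeholder (splitting a flat lane count into rows and columns), not the PR's actual pattern matching.

```cpp
#include <cassert>

// Hypothetical, simplified stand-in for a tile-match result struct.
struct Tile2D {
    bool result = false;  // true only when the index matched a 2D tile shape
    int rows = 0;
    int cols = 0;
};

// Named get_* because it returns a struct; callers check `.result`
// rather than treating the return value as a bool.
Tile2D get_2d_tile_index(int lanes, int row_width) {
    if (lanes <= 0 || row_width <= 0 || lanes % row_width != 0) {
        return {};  // no match
    }
    Tile2D tile;
    tile.result = true;
    tile.rows = lanes / row_width;
    tile.cols = row_width;
    return tile;
}
```

With this shape, a call site reads as "get the tile, then test whether it matched", which avoids the impression that the function itself is a predicate.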
src/ExtractTileOperations.cpp
Outdated
    return {};
}

Tile<3> is_3d_tile_index(const Expr &e) {
Same nit here.
src/ExtractTileOperations.cpp
Outdated
// 4 bytes for i32, f32
auto colbytes = tile_y * 4;
auto matmul =
    Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic);
No need to split this into two lines
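As a side note on the `colbytes` computation in the snippet above: an AMX tile register holds at most 16 rows of 64 bytes, so with 4-byte `i32`/`f32` result elements a full-width row has 16 columns. The sketch below works through that arithmetic; the function and constant names are hypothetical and not part of the PR.

```cpp
#include <cassert>

// AMX tile limits: each tile register holds at most 16 rows of 64 bytes.
constexpr int kMaxTileRows = 16;
constexpr int kMaxTileRowBytes = 64;

// Bytes per row of the result tile: i32 and f32 elements are 4 bytes wide,
// mirroring `colbytes = tile_y * 4` in the reviewed code.
constexpr int result_colbytes(int tile_cols) {
    return tile_cols * 4;
}

// Whether a rows x cols tile of 32-bit elements fits in one tile register.
constexpr bool fits_in_tile(int rows, int cols) {
    return rows > 0 && cols > 0 && rows <= kMaxTileRows &&
           result_colbytes(cols) <= kMaxTileRowBytes;
}
```

For example, a 16x16 result tile needs 64 bytes per row and fits exactly, while a 17th column or a 17th row would not.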
src/ExtractTileOperations.cpp
Outdated
    op_type = AMXOpType::Bf16;
}

user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation";
Would it be helpful to append amx_name or tile_name to the error message, for debugging purposes?
src/ExtractTileOperations.cpp
Outdated
}

auto alloc_type = amx_op_type_result_type(op_type);

No need for this blank line
src/ExtractTileOperations.cpp
Outdated
}

auto body = mutate(op->body);
return ProducerConsumer::make(amx_name, op->is_producer, body);
Nit: `std::move(body)`?
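To make the nit concrete, here is a standalone sketch of handing a local to a by-value parameter with `std::move`. `Node`, `make_node`, and `rewrite` are hypothetical stand-ins; the sketch assumes the factory takes its body by value, and uses `std::string` in place of an IR handle.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for an IR node that takes ownership of its body.
struct Node {
    std::string name;
    std::string body;
};

// Takes `body` by value, so callers can either copy or move into it.
Node make_node(const std::string &name, std::string body) {
    return Node{name, std::move(body)};
}

// `body` is a local that is not used again after the call, so moving it
// transfers ownership instead of making a copy.
Node rewrite(const std::string &name, const std::string &old_body) {
    std::string body = "mutated(" + old_body + ")";
    return make_node(name, std::move(body));
}
```

The saving is small in either case; the point of the nit is only that a local on its last use can be moved rather than copied.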
test/performance/tiled_matmul.cpp
Outdated
    .vectorize(mmyi);

Func result = mm.in();
//result.print_loop_nest();
Don't check in commented-out code (unless there is a comment explaining why, as is done elsewhere in this file)
When converting to correctness tests, the pattern for the rhs load changes slightly.
When using `Buffer` instead of `ImageParam`, the generated `Ramp` expression is 1D instead of 2D, so we recognize it with a special case. The lanes are still matched against the dimensions of the LHS 3D tile lanes.
I think the recent commits addressed all comments. Is there anything else that needs to be addressed?
lgtm. The pattern matching seems to be pretty ad-hoc and possibly brittle, but that can always be improved later. The checks for LLVM 12 will be removed pretty soon too.
This is failing for LLVM11 for Makefile-based builds. I'll see if I can prep a patch.
This pull request continues the work started by @jwlawson in #5780 with the objective of adding initial support for AMX instructions in Halide.
The main addition here is the fix for building Halide with LLVM 11. AMX support requires LLVM 12 or newer, so when building with LLVM 11 the unsupported instructions are not included.
A new LLVM module (x86_amx.ll) contains all the intrinsics required to support tile operations. This module is only included when building with LLVM >= 12.
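The version gate described above can be sketched in standalone C++ as follows. This only mirrors the logic of the CMake condition shown earlier in the thread; `amx_supported` and `runtime_modules` are hypothetical names, and `"x86"` is a placeholder for whatever modules are always included.

```cpp
#include <cassert>
#include <string>
#include <vector>

// AMX intrinsics need LLVM 12 or newer, per the PR description.
constexpr bool amx_supported(int llvm_major) {
    return llvm_major >= 12;
}

// Builds the list of runtime modules for a given LLVM major version.
// "x86" is a placeholder for the modules that are always present.
std::vector<std::string> runtime_modules(int llvm_major) {
    std::vector<std::string> modules = {"x86"};
    if (amx_supported(llvm_major)) {
        modules.push_back("x86_amx");  // only assembled with LLVM >= 12
    }
    return modules;
}
```

With LLVM 11 the list omits `x86_amx`, which avoids handing `x86_amx.ll` to an `llvm-as` that does not know the `x86_amx` type (the "expected type" error quoted earlier in the thread).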