Skip to content

Commit

Permalink
[LAYOUTS] Cache LinearLayout creation (#5542)
Browse files Browse the repository at this point in the history
It was reported that Triton compilation times have increased heavily
lately. The cause is that we very often create the associated LinearLayout
(LL) to check properties of a given Layout. We do this thousands of times,
and it gets very expensive.

In this PR, we implement a thread-safe cache for LinearLayouts. We clear
this cache after we are done with the TTGIR -> LLVM conversion.

In the future, we will make `DistributedEncoding` inherit from
`LinearLayoutEncoding`, which will mean that `DistributedEncoding`s
will always have access to their associated LinearLayout. Even in this
scenario I still think that caching will be good, as there is no real
1-to-1 correspondence between `DistributedEncoding`s and `LinearLayout`s
due to broadcasting, where we tile a layout along the tensor or we make
it smaller. As such, I think this cache may be also useful in the
future.
  • Loading branch information
lezcano authored Jan 8, 2025
1 parent 51dddd3 commit 67f5707
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 15 deletions.
58 changes: 51 additions & 7 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,60 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"

// LinearLayoutCache Utils
// Memoization key for toLinearLayout: (tensor shape, layout attribute,
// optional element bit-width — only required by shared layouts with a
// leading offset; see TritonGPUDialect::toLinearLayout).
using CacheKey =
std::tuple<std::vector<int64_t>, mlir::Attribute, std::optional<int32_t>>;

namespace llvm {
template <typename T> size_t hash_value(const std::vector<T> &vec) {
return hash_combine_range(vec.begin(), vec.end());
}
} // namespace llvm

namespace std {
// Hash for CacheKey: fold each tuple element's llvm::hash_value into a
// single accumulator, left to right.
template <> struct hash<CacheKey> {
  size_t operator()(const CacheKey &key) const noexcept {
    using llvm::hash_value;
    size_t acc = 0;
    auto combineAll = [&acc](const auto &...parts) {
      ((acc = llvm::hash_combine(acc, hash_value(parts))), ...);
    };
    std::apply(combineAll, key);
    return acc;
  }
};
} // namespace std

namespace mlir::triton::gpu {

class LinearLayoutCache {
public:
std::optional<LinearLayout> get(const CacheKey &key) {
std::shared_lock lock(mutex);
auto it = cache.find(key);
if (it != cache.end()) {
return it->second;
}
return std::nullopt;
}

void set(CacheKey key, LinearLayout result) {
std::scoped_lock lock(mutex);
cache.emplace(std::move(key), std::move(result));
}

private:
std::unordered_map<CacheKey, LinearLayout> cache;
llvm::sys::SmartRWMutex<true> mutex;
};
} // namespace mlir::triton::gpu

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

namespace mlir {
namespace triton {
namespace gpu {

namespace mlir::triton::gpu {
// MLIR side-effect resource named "<SharedMemory>"; ops use it to declare
// memory effects against shared memory as a distinct resource.
struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
StringRef getName() final { return "<SharedMemory>"; }
};
Expand Down Expand Up @@ -240,8 +286,6 @@ llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

} // namespace gpu
} // namespace triton
} // namespace mlir
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def TritonGPU_Dialect : Dialect {
}
return cast<IntegerAttr>(threadsPerWarp).getInt();
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth);

private:
LinearLayoutCache llCache;
}];

let useDefaultTypePrinterParser = 1;
Expand Down
33 changes: 25 additions & 8 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,22 +875,39 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth /*= std::nullopt*/) {
// Layouts are distributed or shared
TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth) {
CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout,
elemBitWidth};
auto result = llCache.get(key);
if (result.has_value()) {
return result;
}

// Layouts are distributed or shared in triton core
if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
return distributed.toLinearLayout(shape);
result = distributed.toLinearLayout(shape);
} else if (auto shared = dyn_cast<SharedEncodingAttr>(layout)) {
if (shared.getHasLeadingOffset()) {
assert(elemBitWidth.has_value());
return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth);
result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth);
} else {
return sharedToLinearLayoutNoLeadingOffset(shape, shared);
result = sharedToLinearLayoutNoLeadingOffset(shape, shared);
}
}

// Third party layouts
return std::nullopt;
if (result.has_value()) {
llCache.set(std::move(key), *result);
}
return result;
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
               std::optional<int32_t> elemBitWidth /*= std::nullopt*/) {
  // Free-function convenience wrapper: delegate to the dialect instance,
  // which owns the LinearLayout cache.
  auto *dialect =
      layout.getContext()->getLoadedDialect<TritonGPUDialect>();
  return dialect->toLinearLayout(shape, layout, elemBitWidth);
}

LinearLayout getLayoutWithinBlock(const LinearLayout &layout) {
Expand Down

0 comments on commit 67f5707

Please sign in to comment.