Skip to content

Commit

Permalink
Remove maxpool2d preshard workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Feb 7, 2025
1 parent 2778366 commit 80d6a8d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 65 deletions.
15 changes: 4 additions & 11 deletions runtime/include/tt/runtime/detail/workarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@ struct Env {
#else
constexpr static Env
#endif
get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true,
get(bool swapBinaryOperands = true,
bool readUpdateIndexFromDeviceForKVCache = true,
bool toLayoutAPIAssumeSingleChip = true,
bool usePaddingPairSignatureWithQueueId = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true, true, true);
return Env(true, true, true, true);
}
#endif
// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
// instead of adding a method in runtime
bool maxpool2dPreshard;

// TODO(bug #1124): We're currently swapping the operands for binary ops
// in runtime if the lhs operand is smaller (and requires broadcast onto the
// rhs operand). We should add this check in the compiler.
Expand Down Expand Up @@ -60,12 +56,11 @@ struct Env {
bool usePaddingPairSignatureWithQueueId;

private:
constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands,
constexpr Env(bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache,
bool toLayoutAPIAssumeSingleChip,
bool usePaddingPairSignatureWithQueueId)
: maxpool2dPreshard(maxpool2dPreshard),
swapBinaryOperands(swapBinaryOperands),
: swapBinaryOperands(swapBinaryOperands),
readUpdateIndexFromDeviceForKVCache(
readUpdateIndexFromDeviceForKVCache),
toLayoutAPIAssumeSingleChip(toLayoutAPIAssumeSingleChip),
Expand All @@ -75,8 +70,6 @@ struct Env {

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
os << "workaround::Env{\n";
os << "\t"
<< "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
os << "\t"
<< "swapBinaryOperands: " << env.swapBinaryOperands << ",\n";
os << "\t"
Expand Down
9 changes: 4 additions & 5 deletions runtime/lib/common/workarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands,
const Env &Env::get(bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache,
bool toLayoutAPIAssumeSingleChip,
bool usePaddingPairSignatureWithQueueId) {
static const Env config(maxpool2dPreshard, swapBinaryOperands,
readUpdateIndexFromDeviceForKVCache,
toLayoutAPIAssumeSingleChip,
usePaddingPairSignatureWithQueueId);
static const Env config(
swapBinaryOperands, readUpdateIndexFromDeviceForKVCache,
toLayoutAPIAssumeSingleChip, usePaddingPairSignatureWithQueueId);
return config;
}
#endif
Expand Down
42 changes: 1 addition & 41 deletions runtime/lib/ttnn/operations/pool/maxpool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,13 @@
#include "operations/pool/maxpool2d.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/detail/workarounds.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/types.hpp"
#include <optional>

namespace tt::runtime::ttnn::operations::pool {

// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
// instead of adding a method in runtime
template <typename DeviceType>
static ::ttnn::Tensor
preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op,
DeviceType &device, const ::ttnn::Tensor &input) {
const ::ttnn::Shape inputShape =
::tt::runtime::ttnn::operations::utils::toTTNNShape(
*op->in()->desc()->shape());
uint32_t output_height =
1 + (op->input_height() + 2 * op->padding_height() -
op->dilation_height() * (op->kernel_height() - 1) - 1) /
op->stride_height();
uint32_t output_width =
1 + (op->input_width() + 2 * op->padding_width() -
op->dilation_width() * (op->kernel_width() - 1) - 1) /
op->stride_width();

constexpr bool en_ch_padding = false;

auto parallel_config = ::ttnn::operations::conv::determine_parallel_config(
::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(),
op->channels(), output_height, output_width, op->channels(),
device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR,
en_ch_padding);
auto sharded_memory_config = ::ttnn::operations::conv::
create_sharded_memory_config_from_parallel_config(inputShape,
parallel_config, 1);
return ::ttnn::to_memory_config(input, sharded_memory_config, std::nullopt);
}

void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::operations::pool::Pool2DOp<
Expand All @@ -53,15 +21,7 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {

::ttnn::Tensor input = tensorPool.at(op->in()->global_id());
DEBUG_ASSERT(input.is_allocated());
if (workaround::Env::get().maxpool2dPreshard) {
DeviceVariant targetDevice =
context.getTargetDevice(op->device()->global_id());
input = std::visit(
[&](auto &&targetDevice) -> ::ttnn::Tensor {
return preshardForMaxPool2d(op, targetDevice.get(), input);
},
targetDevice);
}

::ttnn::MemoryConfig outMemConfig =
::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
::ttnn::Tensor out = operation.invoke(
Expand Down
8 changes: 0 additions & 8 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,6 @@ def initialize_api():
choices=[True, False],
help="enable async mode device execution for TTNN runtime",
)
Run.register_arg(
name="--disable-maxpool2d-preshard",
type=bool,
default=False,
choices=[True, False],
help="disable maxpool2d preshard workaround",
)
Run.register_arg(
name="--disable-swap-binary-operands",
type=bool,
Expand Down Expand Up @@ -421,7 +414,6 @@ def convert_input_layouts(device, inputs, fbb, program_index):
debug_env = ttrt.runtime.DebugEnv.get(self["--load-kernels-from-disk"])
self.logging.debug(f"setting tt runtime debug env={debug_env}")
workaround_env = ttrt.runtime.WorkaroundEnv.get(
not self["--disable-maxpool2d-preshard"],
not self["--disable-swap-binary-operands"],
not self["--disable-read-update-index-for-kv-cache"],
not self["--disable-to-layout-api-assume-single-chip"],
Expand Down

0 comments on commit 80d6a8d

Please sign in to comment.