Fix calls to sizes to enable dynamic shapes with sdpa (#96674)
Fixes part of #96414

Replaces any calls to `sizes` with `sym_sizes`. Still seeing an error with the repro script:
```text
Exception raised from sizes_default at /scratch/drisspg/work/pytorch/c10/core/TensorImpl.h:635 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x7d (0x7f697f4a141d in /scratch/drisspg/work/pytorch/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0xdd (0x7f697f49fbcd in /scratch/drisspg/work/pytorch/torch/lib/libc10.so)
frame #2: c10::TensorImpl::sizes_custom() const + 0x95 (0x7f697f4824c5 in /scratch/drisspg/work/pytorch/torch/lib/libc10.so)
frame #3: at::native::empty_like(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, c10::optional<c10::MemoryFormat>) + 0x92c (0x7f69809d18ac in /scratch/drisspg/work/pytorch/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x23f5ce7 (0x7f698193bce7 in /scratch/drisspg/work/pytorch/torch/lib/libtorch_cpu.so)
```

Still trying to track down this `empty` call.

From the looks of it, it might be coming from `at::layer_norm`? The backtrace from lldb is 221 frames, however, so there is a lot of noise.
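
For context, the core of the change is that `Tensor::size(d)`/`sizes()` return concrete `int64_t` values, while `Tensor::sym_size(d)`/`sym_sizes()` return `c10::SymInt`, whose arithmetic and comparisons can stay symbolic under dynamic shapes. Below is a minimal sketch of that pattern; `has_large_head_dim` is a hypothetical helper written in the style of the `sdp_utils.h` predicates, not code taken from this PR:

```cpp
// Illustrative sketch only: a hypothetical predicate in the style of the
// sdp_utils.h checks, using sym_size() so the comparison is expressed on
// c10::SymInt rather than a concrete int64_t.
#include <iostream>

#include <ATen/ATen.h>
#include <c10/core/SymInt.h>

// Hypothetical helper (not part of the PR) mirroring the head-dim check.
bool has_large_head_dim(const at::Tensor& query, const at::Tensor& key) {
  // sym_size(-1) returns c10::SymInt instead of int64_t.
  const c10::SymInt q_head_dim = query.sym_size(-1);
  const c10::SymInt k_head_dim = key.sym_size(-1);
  // SymInt::max and operator== keep the expression in SymInt land,
  // matching the large_head_dim check in the diff below.
  return q_head_dim.max(k_head_dim) == 128;
}

int main() {
  // Plain eager tensors just to make the sketch runnable; the point of the
  // commit is that the same code also works when the sizes are symbolic.
  auto q = at::randn({2, 8, 16, 128});
  auto k = at::randn({2, 8, 16, 128});
  std::cout << std::boolalpha << has_large_head_dim(q, k) << std::endl;
  return 0;
}
```

This is also why the diff replaces the old `std::floor`/`std::max` calls with plain `SymInt` division and `SymInt::max`: the standard-library calls expect concrete numeric values, while the `SymInt` operations keep the computation symbolic.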

Pull Request resolved: pytorch/pytorch#96674
Approved by: https://github.com/ezyang
drisspg authored and cyyever committed Mar 23, 2023
1 parent 8fe782c commit 7ca417d
Showing 2 changed files with 33 additions and 33 deletions.
62 changes: 32 additions & 30 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.h
```diff
@@ -16,6 +16,7 @@
 
 #include <functional>
 #include <cmath>
+#include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
 
 namespace sdp {
@@ -55,27 +56,28 @@ inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
   // FlashAttention parallelizes across "batch_size * num_heads"
   // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
   // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
-  if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) {
+  if (params.query.is_nested() || params.key.is_nested() ||
+      params.value.is_nested()) {
     // See check_for_nested_inputs for details
     return {
         SDPBackend::efficient_attention,
         SDPBackend::flash_attention,
         SDPBackend::math};
   }
-  const auto sizes = params.query.sizes();
   if (params.query.dim() != 4) {
     return default_order;
   }
-  const auto batch_size{sizes[0]}, num_heads{sizes[1]}, query_lengths{sizes[2]},
-      head_dim{sizes[3]};
+  const auto batch_size{params.query.sym_size(0)},
+      num_heads{params.query.sym_size(1)},
+      query_lengths{params.query.sym_size(2)},
+      head_dim{params.query.sym_size(3)};
   if (batch_size > 0) {
-    const int64_t threads_flash = batch_size * num_heads;
-    const int64_t threads_cutlass =
-        threads_flash * (int64_t)std::floor(query_lengths / 64);
-    bool more_threads_cutlass =
-        (int64_t)std::floor(threads_cutlass / 2) >= threads_flash;
+    const auto threads_flash = batch_size * num_heads;
+    const auto threads_cutlass =
+        threads_flash * (query_lengths / c10::SymInt(64));
+    bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
     bool small_threads_flash = threads_flash < 60;
-    bool large_head_dim = std::max(head_dim, params.key.sizes()[3]) == 128;
+    bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
     if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
       return {
           SDPBackend::efficient_attention,
@@ -130,9 +132,9 @@ inline bool check_for_nested_inputs(sdp_params params){
   return false;
 }
 
-inline bool try_broadcast_param_size(int64_t q_size,
-    int64_t k_size,
-    int64_t v_size,
+inline bool try_broadcast_param_size(const c10::SymInt q_size,
+    const c10::SymInt k_size,
+    const c10::SymInt v_size,
     c10::string_view param_name,
     bool debug) {
   auto max_size = std::max({q_size, k_size, v_size});
@@ -329,9 +331,9 @@ inline bool check_safe_kv_broadcast(at::Tensor param, bool debug){
 inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
   // This is expected to be called after check_tensor_shapes ensuring that the size()
   // calls won't error since the inputs are all 4 dimensional
-  auto q_batch_size = params.query.size(0);
-  auto k_batch_size = params.key.size(0);
-  auto v_batch_size = params.value.size(0);
+  auto q_batch_size = params.query.sym_size(0);
+  auto k_batch_size = params.key.sym_size(0);
+  auto v_batch_size = params.value.sym_size(0);
 
   bool has_nested_input = check_for_nested_inputs(params);
   bool same_batch_size = q_batch_size == k_batch_size && q_batch_size == v_batch_size;
@@ -362,9 +364,9 @@ inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
     return broadcastable_batch_size;
   }
 
-  auto q_num_heads = params.query.size(1);
-  auto k_num_heads = params.key.size(1);
-  auto v_num_heads = params.value.size(1);
+  auto q_num_heads = params.query.sym_size(1);
+  auto k_num_heads = params.key.sym_size(1);
+  auto v_num_heads = params.value.sym_size(1);
   bool same_num_heads = q_num_heads == k_num_heads && q_num_heads == v_num_heads;
 
   if (!(same_batch_size && same_num_heads)) {
@@ -385,9 +387,9 @@
 }
 
 inline bool check_head_dim_size(sdp_params params, bool debug) {
-  const int64_t query_size_last = params.query.size(-1);
-  const int64_t key_size_last = params.key.size(-1);
-  const int64_t value_size_last = params.value.size(-1);
+  const auto query_size_last = params.query.sym_size(-1);
+  const auto key_size_last = params.key.sym_size(-1);
+  const auto value_size_last = params.value.sym_size(-1);
   if (!(query_size_last == key_size_last &&
         query_size_last == value_size_last && query_size_last % 8 == 0 &&
         query_size_last <= 128 && value_size_last % 8 == 0 &&
@@ -398,9 +400,9 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
           " Got Query.size(-1): ",
           query_size_last,
           ", Key.size(-1): ",
-          params.key.size(-1),
+          params.key.sym_size(-1),
           ", Value.size(-1): ",
-          params.value.size(-1),
+          params.value.sym_size(-1),
           " instead.");
     }
     return false;
@@ -437,10 +439,10 @@ inline int64_t minimum_gemm_alignment(sdp_params params) {
 }
 
 inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
-  const int64_t query_size_last = params.query.size(-1);
-  const int64_t value_size_last = params.value.size(-1);
+  const auto query_size_last = params.query.sym_size(-1);
+  const auto value_size_last = params.value.sym_size(-1);
   const int64_t alignment = minimum_gemm_alignment(params);
-  if (!(query_size_last == params.key.size(-1) &&
+  if (!(query_size_last == params.key.sym_size(-1) &&
         query_size_last % alignment == 0 && query_size_last > 0 &&
         value_size_last % alignment == 0 && value_size_last > 0)) {
     if (debug) {
@@ -451,9 +453,9 @@ inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
           "Got Query.size(-1): ",
           query_size_last,
           ", Key.size(-1): ",
-          params.key.size(-1),
+          params.key.sym_size(-1),
           ", Value.size(-1): ",
-          params.value.size(-1),
+          params.value.sym_size(-1),
           " instead.");
     }
     return false;
@@ -527,7 +529,7 @@ inline bool check_gpu_sm86_head_dim_128(sdp_params params, bool debug) {
   // on sm86 when head_dim is 128.
   auto dprops = at::cuda::getCurrentDeviceProperties();
   bool is_sm86 = (dprops->major == 8) && (dprops->minor == 6);
-  if (is_sm86 && (params.query.size(-1) == 128)) {
+  if (is_sm86 && (params.query.sym_size(-1) == 128)) {
     if (debug) {
       TORCH_WARN(
           "Memory Efficient Attention does not currently support head_dim == 128 on sm86",
```
4 changes: 1 addition & 3 deletions test/dynamo/test_dynamic_shapes.py
```diff
@@ -25,9 +25,7 @@
 test_classes = {}
 
 ALL_DYNAMIC_XFAILS = {
-    "MiscTests": [
-        "test_parsing_sdpa",
-    ],
+    "MiscTests": [],
     "ReproTests": [
         # Could not infer dtype of torch._C.SymIntNode
         "test_convert_boxes_to_pooler_format",
```
