diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
index a56ce1cd3d4..c2f74942f7a 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -16,6 +16,7 @@
 #include
 #include
+#include
 #include
 
 namespace sdp {
@@ -55,27 +56,28 @@ inline std::array priority_order(sdp_params params) {
   // FlashAttention parallelizes across "batch_size * num_heads"
   // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
   // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
-  if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) {
+  if (params.query.is_nested() || params.key.is_nested() ||
+      params.value.is_nested()) {
     // See check_for_nested_inputs for details
     return {
         SDPBackend::efficient_attention,
         SDPBackend::flash_attention,
         SDPBackend::math};
   }
-  const auto sizes = params.query.sizes();
   if (params.query.dim() != 4) {
     return default_order;
   }
-  const auto batch_size{sizes[0]}, num_heads{sizes[1]}, query_lengths{sizes[2]},
-      head_dim{sizes[3]};
+  const auto batch_size{params.query.sym_size(0)},
+      num_heads{params.query.sym_size(1)},
+      query_lengths{params.query.sym_size(2)},
+      head_dim{params.query.sym_size(3)};
   if (batch_size > 0) {
-    const int64_t threads_flash = batch_size * num_heads;
-    const int64_t threads_cutlass =
-        threads_flash * (int64_t)std::floor(query_lengths / 64);
-    bool more_threads_cutlass =
-        (int64_t)std::floor(threads_cutlass / 2) >= threads_flash;
+    const auto threads_flash = batch_size * num_heads;
+    const auto threads_cutlass =
+        threads_flash * (query_lengths / c10::SymInt(64));
+    bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
     bool small_threads_flash = threads_flash < 60;
-    bool large_head_dim = std::max(head_dim, params.key.sizes()[3]) == 128;
+    bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
     if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
       return {
           SDPBackend::efficient_attention,
@@ -130,9 +132,9 @@ inline bool check_for_nested_inputs(sdp_params params){
   return false;
 }
 
-inline bool try_broadcast_param_size(int64_t q_size,
-    int64_t k_size,
-    int64_t v_size,
+inline bool try_broadcast_param_size(const c10::SymInt q_size,
+    const c10::SymInt k_size,
+    const c10::SymInt v_size,
     c10::string_view param_name,
     bool debug) {
   auto max_size = std::max({q_size, k_size, v_size});
@@ -329,9 +331,9 @@ inline bool check_safe_kv_broadcast(at::Tensor param, bool debug){
 inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
   // This is expected to be called after check_tensor_shapes ensuring that the size()
   // calls won't error since the inputs are all 4 dimensional
-  auto q_batch_size = params.query.size(0);
-  auto k_batch_size = params.key.size(0);
-  auto v_batch_size = params.value.size(0);
+  auto q_batch_size = params.query.sym_size(0);
+  auto k_batch_size = params.key.sym_size(0);
+  auto v_batch_size = params.value.sym_size(0);
   bool has_nested_input = check_for_nested_inputs(params);
   bool same_batch_size =
       q_batch_size == k_batch_size && q_batch_size == v_batch_size;
@@ -362,9 +364,9 @@ inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
     return broadcastable_batch_size;
   }
 
-  auto q_num_heads = params.query.size(1);
-  auto k_num_heads = params.key.size(1);
-  auto v_num_heads = params.value.size(1);
+  auto q_num_heads = params.query.sym_size(1);
+  auto k_num_heads = params.key.sym_size(1);
+  auto v_num_heads = params.value.sym_size(1);
   bool same_num_heads =
       q_num_heads == k_num_heads && q_num_heads == v_num_heads;
   if (!(same_batch_size && same_num_heads)) {
@@ -385,9 +387,9 @@ inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
 }
 
 inline bool check_head_dim_size(sdp_params params, bool debug) {
-  const int64_t query_size_last = params.query.size(-1);
-  const int64_t key_size_last = params.key.size(-1);
-  const int64_t value_size_last = params.value.size(-1);
+  const auto query_size_last = params.query.sym_size(-1);
+  const auto key_size_last = params.key.sym_size(-1);
+  const auto value_size_last = params.value.sym_size(-1);
   if (!(query_size_last == key_size_last &&
         query_size_last == value_size_last && query_size_last % 8 == 0 &&
         query_size_last <= 128 && value_size_last % 8 == 0 &&
@@ -398,9 +400,9 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
           " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
-          params.key.size(-1),
+          params.key.sym_size(-1),
          ", Value.size(-1): ",
-          params.value.size(-1),
+          params.value.sym_size(-1),
          " instead.");
     }
     return false;
@@ -437,10 +439,10 @@ inline int64_t minimum_gemm_alignment(sdp_params params) {
 }
 
 inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
-  const int64_t query_size_last = params.query.size(-1);
-  const int64_t value_size_last = params.value.size(-1);
+  const auto query_size_last = params.query.sym_size(-1);
+  const auto value_size_last = params.value.sym_size(-1);
   const int64_t alignment = minimum_gemm_alignment(params);
-  if (!(query_size_last == params.key.size(-1) &&
+  if (!(query_size_last == params.key.sym_size(-1) &&
         query_size_last % alignment == 0 && query_size_last > 0 &&
         value_size_last % alignment == 0 && value_size_last > 0)) {
     if (debug) {
@@ -451,9 +453,9 @@ inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
          "Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
-          params.key.size(-1),
+          params.key.sym_size(-1),
          ", Value.size(-1): ",
-          params.value.size(-1),
+          params.value.sym_size(-1),
          " instead.");
     }
     return false;
@@ -527,7 +529,7 @@ inline bool check_gpu_sm86_head_dim_128(sdp_params params, bool debug) {
   // on sm86 when head_dim is 128.
   auto dprops = at::cuda::getCurrentDeviceProperties();
   bool is_sm86 = (dprops->major == 8) && (dprops->minor == 6);
-  if (is_sm86 && (params.query.size(-1) == 128)) {
+  if (is_sm86 && (params.query.sym_size(-1) == 128)) {
     if (debug) {
       TORCH_WARN(
           "Memory Efficient Attention does not currently support head_dim == 128 on sm86",
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 7b6b60781b6..1295fdcc67a 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -25,9 +25,7 @@
 test_classes = {}
 
 ALL_DYNAMIC_XFAILS = {
-    "MiscTests": [
-        "test_parsing_sdpa",
-    ],
+    "MiscTests": [],
     "ReproTests": [
         # Could not infer dtype of torch._C.SymIntNode
        "test_convert_boxes_to_pooler_format",
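
For context, here is a minimal sketch (not part of the diff) of the dynamic-shapes path that the sym_size() changes are meant to unblock, and which the removed test_parsing_sdpa xfail relates to. The shapes, dtype, and the attn helper below are illustrative assumptions rather than code from the test suite, and the example uses CPU tensors, so it only demonstrates the symbolic tracing side; the CUDA backend-selection checks patched above run when the inputs are CUDA tensors.

    # Sketch: compile a scaled_dot_product_attention call with dynamic shapes,
    # so tensor sizes are traced symbolically instead of as fixed integers.
    import torch
    import torch.nn.functional as F

    def attn(query, key, value):
        return F.scaled_dot_product_attention(query, key, value)

    compiled_attn = torch.compile(attn, dynamic=True)

    # Layout is (batch, num_heads, seq_len, head_dim); seq_len varies across calls.
    for seq_len in (64, 128):
        q = torch.randn(2, 8, seq_len, 64)
        k = torch.randn(2, 8, seq_len, 64)
        v = torch.randn(2, 8, seq_len, 64)
        print(compiled_attn(q, k, v).shape)

The intent of the header changes is that checks such as priority_order, check_batch_size_and_num_heads, and check_head_dim_size compare c10::SymInt values instead of requiring concrete int64_t sizes, which is what allows them to run while tracing with symbolic shapes.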