Skip to content

Commit

Permalink
Add memory index guard in wmma device ops (#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
aska-0096 authored Apr 11, 2023
1 parent f532988 commit e85178b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}

// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);

if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}

return true;
}

Expand Down
7 changes: 7 additions & 0 deletions include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}

// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);

if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
{
return false;
}
return true;
}

Expand Down

0 comments on commit e85178b

Please sign in to comment.