Skip to content

Commit

Permalink
Add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
derekelewis committed Aug 8, 2024
1 parent 14ba838 commit afe1989
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions cuda/torchMatrixMultiply.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ torch::Tensor cuda_matrixMultiply(const torch::Tensor &a, const torch::Tensor &b
float *b_ptr = b_contiguous.data_ptr<float>();
float *result_ptr = result.data_ptr<float>();

// Assumes square matrices and we cast to int for simplicity
// and compatibility with our existing kernel code. In practice,
// we would need to handle non-square matrices and use an unsigned long
// to match PyTorch's tensor sizes.
int dim{static_cast<int>(a.sizes()[0])};

dim3 blockSize(16, 16);
Expand Down

0 comments on commit afe1989

Please sign in to comment.