-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ROCm support #3462
ROCm support #3462
Changes from 11 commits
c9604bc
396635d
0fc58d2
b57b58b
b999e0e
4faa88b
626dc1a
9654632
fb2be88
9d2d750
922a4ed
67e7bfd
fd99f8c
7db1b3d
8cd59e4
0283aa9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -197,6 +197,16 @@ function(generate_ivf_interleaved_code) | |
"64|2048|8" | ||
) | ||
|
||
if(USE_ROCM) | ||
set(CU_OR_HIP "hip") | ||
else() | ||
set(CU_OR_HIP "cu") | ||
endif() | ||
|
||
if (USE_ROCM) | ||
list(TRANSFORM FAISS_GPU_SRC REPLACE cu$ hip) | ||
endif() | ||
|
||
# Traverse through the Cartesian product of X and Y | ||
foreach(sub_codec ${SUB_CODEC_TYPE}) | ||
foreach(metric_type ${SUB_METRIC_TYPE}) | ||
|
@@ -210,10 +220,10 @@ function(generate_ivf_interleaved_code) | |
set(filename "template_${sub_codec}_${metric_type}_${sub_threads}_${sub_num_warp_q}_${sub_num_thread_q}") | ||
# Remove illegal characters from filename | ||
string(REGEX REPLACE "[^A-Za-z0-9_]" "" filename ${filename}) | ||
set(output_file "${CMAKE_CURRENT_BINARY_DIR}/${filename}.cu") | ||
set(output_file "${CMAKE_CURRENT_BINARY_DIR}/${filename}.${CU_OR_HIP}") | ||
|
||
# Read the template file | ||
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/impl/scan/IVFInterleavedScanKernelTemplate.cu" template_content) | ||
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/impl/scan/IVFInterleavedScanKernelTemplate.${CU_OR_HIP}" template_content) | ||
|
||
# Replace the placeholders | ||
string(REPLACE "SUB_CODEC_TYPE" "${sub_codec}" template_content "${template_content}") | ||
|
@@ -290,6 +300,10 @@ if(FAISS_ENABLE_RAFT) | |
target_compile_definitions(faiss_gpu PUBLIC USE_NVIDIA_RAFT=1) | ||
endif() | ||
|
||
if (USE_ROCM) | ||
list(TRANSFORM FAISS_GPU_SRC REPLACE cu$ hip) | ||
endif() | ||
|
||
# Export FAISS_GPU_HEADERS variable to parent scope. | ||
set(FAISS_GPU_HEADERS ${FAISS_GPU_HEADERS} PARENT_SCOPE) | ||
|
||
|
@@ -304,6 +318,12 @@ foreach(header ${FAISS_GPU_HEADERS}) | |
) | ||
endforeach() | ||
|
||
if (USE_ROCM) | ||
find_package(HIP REQUIRED) | ||
find_package(hipBLAS REQUIRED) | ||
target_link_libraries(faiss_gpu PRIVATE hip::host roc::hipblas) | ||
target_compile_options(faiss_gpu PRIVATE) | ||
else() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
# Prepares a host linker script and enables host linker to support | ||
# very large device object files. | ||
# This is what CUDA 11.5+ `nvcc -hls=gen-lcs -aug-hls` would generate | ||
|
@@ -322,3 +342,4 @@ target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") | |
find_package(CUDAToolkit REQUIRED) | ||
target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::raft> $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::compiled> $<$<BOOL:${FAISS_ENABLE_RAFT}>:nvidia::cutlass::cutlass> $<$<BOOL:${FAISS_ENABLE_RAFT}>:OpenMP::OpenMP_CXX>) | ||
target_compile_options(faiss_gpu PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr $<$<BOOL:${FAISS_ENABLE_RAFT}>:-Xcompiler=${OpenMP_CXX_FLAGS}>>) | ||
endif() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -365,8 +365,8 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) { | |
|
||
// Our code is pre-built with and expects warpSize == 32, validate that | ||
FAISS_ASSERT_FMT( | ||
prop.warpSize == 32, | ||
"Device id %d does not have expected warpSize of 32", | ||
prop.warpSize == 32 || prop.warpSize == 64, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this ROCm specific? If so, can we allow 64 only for ROCm? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have both wavefront 32 (E.g. navi) and 64 (E.g. MI250) devices. So this offers support for both. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It sounds like Nvidia is 32 only and ROCm is 32 or 64. Should we lock it accordingly in code? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If that is desired, I could rework that assert using a ROCm flag to only allow a warpSize of 64 (and 32) on ROCm devices. It shouldn't be an issue at all! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think I would do that. |
||
"Device id %d does not have expected warpSize of 32 or 64", | ||
device); | ||
|
||
// Create streams | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
#!/bin/bash | ||
|
||
# go one level up from faiss/gpu | ||
top=$(dirname "${BASH_SOURCE[0]}")/.. | ||
echo "top=$top" | ||
cd $top | ||
echo "pwd=`pwd`" | ||
|
||
# create all destination directories for hipified files into sibling 'gpu-rocm' directory | ||
for src in $(find ./gpu -type d) | ||
do | ||
dst=$(echo $src | sed 's/gpu/gpu-rocm/') | ||
echo "Creating $dst" | ||
mkdir -p $dst | ||
done | ||
|
||
# run hipify-perl against all *.cu *.cuh *.h *.cpp files, no renaming | ||
# run all files in parallel to speed up | ||
for ext in cu cuh h cpp | ||
do | ||
for src in $(find ./gpu -name "*.$ext") | ||
do | ||
dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') | ||
hipify-perl -o=$dst.tmp $src & | ||
done | ||
done | ||
wait | ||
|
||
# rename all hipified *.cu files to *.hip | ||
for src in $(find ./gpu-rocm -name "*.cu.tmp") | ||
do | ||
dst=${src%.cu.tmp}.hip.tmp | ||
mv $src $dst | ||
done | ||
|
||
# replace header include statements "<faiss/gpu/" with "<faiss/gpu-rocm" | ||
# replace thrust::cuda::par with thrust::hip::par | ||
# adjust header path location for hipblas.h to avoid unnecessary deprecation warnings | ||
# adjust header path location for hiprand_kernel.h to avoid unnecessary deprecation warnings | ||
for ext in hip cuh h cpp | ||
do | ||
for src in $(find ./gpu-rocm -name "*.$ext.tmp") | ||
do | ||
sed -i 's@#include <faiss/gpu/@#include <faiss/gpu-rocm/@' $src | ||
sed -i 's@thrust::cuda::par@thrust::hip::par@' $src | ||
sed -i 's@#include <hipblas.h>@#include <hipblas/hipblas.h>@' $src | ||
sed -i 's@#include <hiprand_kernel.h>@#include <hiprand/hiprand_kernel.h>@' $src | ||
done | ||
done | ||
|
||
# hipify was run in parallel above | ||
# don't copy the tmp file if it is unchanged | ||
for ext in hip cuh h cpp | ||
do | ||
for src in $(find ./gpu-rocm -name "*.$ext.tmp") | ||
do | ||
dst=${src%.tmp} | ||
if test -f $dst | ||
then | ||
if diff -q $src $dst >& /dev/null | ||
then | ||
echo "$dst [unchanged]" | ||
rm $src | ||
else | ||
echo "$dst" | ||
mv $src $dst | ||
fi | ||
else | ||
echo "$dst" | ||
mv $src $dst | ||
fi | ||
done | ||
done | ||
|
||
# copy over CMakeLists.txt | ||
for src in $(find ./gpu -name "CMakeLists.txt") | ||
do | ||
dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') | ||
if test -f $dst | ||
then | ||
if diff -q $src $dst >& /dev/null | ||
then | ||
echo "$dst [unchanged]" | ||
else | ||
echo "$dst" | ||
cp $src $dst | ||
fi | ||
else | ||
echo "$dst" | ||
cp $src $dst | ||
fi | ||
done | ||
|
||
# Copy over other files | ||
for ext in py | ||
do | ||
for src in $(find ./gpu -name "*.$ext") | ||
do | ||
dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') | ||
if test -f $dst | ||
then | ||
if diff -q $src $dst >& /dev/null | ||
then | ||
echo "$dst [unchanged]" | ||
else | ||
echo "$dst" | ||
cp $src $dst | ||
fi | ||
else | ||
echo "$dst" | ||
cp $src $dst | ||
fi | ||
done | ||
done | ||
|
||
|
||
################################################################################### | ||
# C_API Support | ||
################################################################################### | ||
|
||
# Now get the c_api dir | ||
# This points to the faiss/c_api dir | ||
top_c_api=$(dirname "${BASH_SOURCE[0]}")/../../c_api | ||
echo "top=$top_c_api" | ||
cd ../$top_c_api | ||
echo "pwd=`pwd`" | ||
|
||
|
||
# create all destination directories for hipified files into sibling 'gpu-rocm' directory | ||
for src in $(find ./gpu -type d) | ||
do | ||
dst=$(echo $src | sed 's/gpu/gpu-rocm/') | ||
echo "Creating $dst" | ||
mkdir -p $dst | ||
done | ||
|
||
# run hipify-perl against all *.cu *.cuh *.h *.cpp files, no renaming | ||
# run all files in parallel to speed up | ||
for ext in cu cuh h cpp c | ||
do | ||
for src in $(find ./gpu -name "*.$ext") | ||
do | ||
dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') | ||
hipify-perl -o=$dst.tmp $src & | ||
done | ||
done | ||
wait | ||
|
||
# rename all hipified *.cu files to *.hip | ||
for src in $(find ./gpu-rocm -name "*.cu.tmp") | ||
do | ||
dst=${src%.cu.tmp}.hip.tmp | ||
mv $src $dst | ||
done | ||
|
||
# replace header include statements "<faiss/gpu/" with "<faiss/gpu-rocm" | ||
# replace thrust::cuda::par with thrust::hip::par | ||
# adjust header path location for hipblas.h to avoid unnecessary deprecation warnings | ||
# adjust header path location for hiprand_kernel.h to avoid unnecessary deprecation warnings | ||
for ext in hip cuh h cpp c | ||
do | ||
for src in $(find ./gpu-rocm -name "*.$ext.tmp") | ||
do | ||
sed -i 's@#include <faiss/gpu/@#include <faiss/gpu-rocm/@' $src | ||
sed -i 's@thrust::cuda::par@thrust::hip::par@' $src | ||
sed -i 's@#include <hipblas.h>@#include <hipblas/hipblas.h>@' $src | ||
sed -i 's@#include <hiprand_kernel.h>@#include <hiprand/hiprand_kernel.h>@' $src | ||
done | ||
done | ||
|
||
# hipify was run in parallel above | ||
# don't copy the tmp file if it is unchanged | ||
for ext in hip cuh h cpp c | ||
do | ||
for src in $(find ./gpu-rocm -name "*.$ext.tmp") | ||
do | ||
dst=${src%.tmp} | ||
if test -f $dst | ||
then | ||
if diff -q $src $dst >& /dev/null | ||
then | ||
echo "$dst [unchanged]" | ||
rm $src | ||
else | ||
echo "$dst" | ||
mv $src $dst | ||
fi | ||
else | ||
echo "$dst" | ||
mv $src $dst | ||
fi | ||
done | ||
done | ||
|
||
# copy over CMakeLists.txt | ||
for src in $(find ./gpu -name "CMakeLists.txt") | ||
do | ||
dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') | ||
if test -f $dst | ||
then | ||
if diff -q $src $dst >& /dev/null | ||
then | ||
echo "$dst [unchanged]" | ||
else | ||
echo "$dst" | ||
cp $src $dst | ||
fi | ||
else | ||
echo "$dst" | ||
cp $src $dst | ||
fi | ||
done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you name it something like
GPU_EXT_PREFIX
?