From ebd5efeedf86d865e93c85abfdf9fd8c8f5b3037 Mon Sep 17 00:00:00 2001
From: Justine Tunney
Date: Thu, 23 May 2024 02:09:25 -0700
Subject: [PATCH] Add basic bf16 support to ggml-cuda

---
 ggml-cuda/common.cuh |  7 +++++--
 ggml-cuda/convert.cu |  2 ++
 ggml-cuda/dmmv.cu    | 20 ++++++++++++++++++++
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 8f6fd71cfea35..2b15106e36a02 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -25,10 +25,12 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
+#define __nv_bfloat16 hip_bfloat16
 #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -38,8 +40,8 @@
 #define CUBLAS_OP_T HIPBLAS_OP_T
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
-#define CUDA_R_16F HIPBLAS_R_16F
-#define CUDA_R_32F HIPBLAS_R_32F
+#define CUDA_R_16F  HIPBLAS_R_16F
+#define CUDA_R_32F  HIPBLAS_R_32F
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
@@ -123,6 +125,7 @@
 #include <cuda.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index c0a4447075c6e..736b7091162fb 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<__nv_bfloat16>;
         default:
             return nullptr;
     }
diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu
index 47d4d5d9e91da..33c4e5ed16ea6 100644
--- a/ggml-cuda/dmmv.cu
+++ b/ggml-cuda/dmmv.cu
@@ -422,6 +422,14 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     v.y = x[ib + iqs + 1];
 }
 
+static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const __nv_bfloat16 * x = (const __nv_bfloat16 *) vx;
+
+    // automatic __nv_bfloat16 -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
@@ -584,6 +592,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
+static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_bf16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
 void ggml_cuda_op_dequantize_mul_mat_vec(
     ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@@ -649,6 +666,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
         case GGML_TYPE_F16:
             convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_BF16:
+            convert_mul_mat_vec_bf16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
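Illustrative note (not part of the patch): the convert_bf16 helper above leans on CUDA's
implicit __nv_bfloat16 -> float conversion. A minimal self-contained sketch of the same
widening, written with the __bfloat162float intrinsic from <cuda_bf16.h>, is shown below;
the kernel name and launch shape are made up for illustration only.

    #include <cuda_bf16.h>

    // Widen each bf16 element to float, mirroring what convert_bf16 does per element pair.
    __global__ void bf16_to_f32(const __nv_bfloat16 * x, float * y, int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = __bfloat162float(x[i]);  // explicit form of the implicit cast used above
        }
    }

    // Example launch: bf16_to_f32<<<(n + 255) / 256, 256>>>(x, y, n);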