From c55ab6a41edd8685f66ce9d7ce710ddd69678645 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Tue, 28 May 2024 12:32:44 +0200 Subject: [PATCH 1/2] fix position bias in tensor parallel --- src/devices.cc | 2 +- src/layers/attention.cc | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/devices.cc b/src/devices.cc index 47582f8be..a2936e0a6 100644 --- a/src/devices.cc +++ b/src/devices.cc @@ -196,7 +196,7 @@ namespace ctranslate2 { for (auto* comm : _nccl_comms) { //finalizing NCCL if (*comm) { - NCCL_CHECK(ncclCommAbort(*comm)); + NCCL_CHECK(ncclCommFinalize(*comm)); NCCL_CHECK(ncclCommDestroy(*comm)); } } diff --git a/src/layers/attention.cc b/src/layers/attention.cc index f340c44f9..18e2710f7 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -1,5 +1,7 @@ #include "ctranslate2/layers/attention.h" #include "ctranslate2/ops/split.h" +#include "ctranslate2/utils.h" + #include #include @@ -210,11 +212,20 @@ namespace ctranslate2 { is_decoder, with_cache ? key_length - 1 : 0); } + StorageView* position_bias_per_gpu = position_bias; + StorageView position_bias_tmp(position_bias->dtype(), position_bias->device()); + if (ScopedMPISetter::getCurRank() != 0) { + const dim_t num_head_per_gpu = SAFE_DIVIDE(position_bias->dim(0), ScopedMPISetter::getNRanks()); + ops::Slide slide_ops(0, num_head_per_gpu * ScopedMPISetter::getCurRank(), + num_head_per_gpu, true); + slide_ops(*position_bias, position_bias_tmp); + position_bias_per_gpu = &position_bias_tmp; + } DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(), - primitives::add_batch_broadcast(position_bias->data(), + primitives::add_batch_broadcast(position_bias_per_gpu->data(), output.data(), - position_bias->size(), + position_bias_per_gpu->size(), output.size())); } From 99af848c8cb6a54f42464dee7361dc1086189433 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Wed, 29 May 2024 10:40:20 +0200 Subject: [PATCH 2/2] add symbol ncclCommFinalize --- src/cuda/nccl_stub.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cuda/nccl_stub.cc b/src/cuda/nccl_stub.cc index 669518cb2..8a782f23d 100644 --- a/src/cuda/nccl_stub.cc +++ b/src/cuda/nccl_stub.cc @@ -69,9 +69,9 @@ extern "C" { return func(comm); } - ncclResult_t ncclCommAbort(ncclComm_t comm) { + ncclResult_t ncclCommFinalize(ncclComm_t comm) { using Signature = ncclResult_t(*)(ncclComm_t comm); - static auto func = ctranslate2::load_symbol("ncclCommAbort"); + static auto func = ctranslate2::load_symbol("ncclCommFinalize"); return func(comm); }