From 8cc003becd7f391d98fca26a398590876bf0155e Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 11 Sep 2023 16:51:45 +0800 Subject: [PATCH] [Opt] Support Sm70 on V100 --- benchmark/merlin_hashtable_benchmark.cc.cu | 2 +- include/merlin/core_kernels/lookup.cuh | 6 ++++++ include/merlin/core_kernels/update.cuh | 15 ++++++++++----- include/merlin/core_kernels/update_score.cuh | 7 +------ .../merlin/core_kernels/upsert_and_evict.cuh | 17 ++++++++++++----- include/merlin_hashtable.cuh | 5 ++--- 6 files changed, 32 insertions(+), 20 deletions(-) diff --git a/benchmark/merlin_hashtable_benchmark.cc.cu b/benchmark/merlin_hashtable_benchmark.cc.cu index 7e9637b29..a0a807c32 100644 --- a/benchmark/merlin_hashtable_benchmark.cc.cu +++ b/benchmark/merlin_hashtable_benchmark.cc.cu @@ -480,7 +480,7 @@ void test_main(std::vector& apis, const size_t dim, options.dim = dim; options.max_hbm_for_vectors = nv::merlin::GB(hbm4values); options.io_by_cpu = io_by_cpu; - using Table = nv::merlin::HashTable; + using Table = nv::merlin::HashTable; std::shared_ptr table = std::make_shared
(); table->init(options); diff --git a/include/merlin/core_kernels/lookup.cuh b/include/merlin/core_kernels/lookup.cuh index 007573d84..b6665d60c 100644 --- a/include/merlin/core_kernels/lookup.cuh +++ b/include/merlin/core_kernels/lookup.cuh @@ -667,6 +667,12 @@ struct LookupValueBufConfig { static constexpr uint32_t size_pipeline_v2 = 128 * sizeof(float); }; +template <> +struct LookupValueBufConfig { + static constexpr uint32_t size_pipeline_v1 = 112 * sizeof(float); + static constexpr uint32_t size_pipeline_v2 = 64 * sizeof(float); +}; + template struct SelectPipelineLookupKernelWithIO { diff --git a/include/merlin/core_kernels/update.cuh b/include/merlin/core_kernels/update.cuh index eb50fe84a..2cec35ed0 100644 --- a/include/merlin/core_kernels/update.cuh +++ b/include/merlin/core_kernels/update.cuh @@ -567,12 +567,17 @@ struct ValueConfig_Update { static constexpr uint32_t size_pipeline = 128 * sizeof(byte4); }; -template -struct KernelSelector_Update; +template <> +struct ValueConfig_Update { + // Value size greater than it will bring poor performance for TLP. + static constexpr uint32_t size_tlp = 8 * sizeof(byte4); + // Value size greater than it will reduce the occupancy for Pipeline. + // When the value is very high, the kernel will fail to launch. + static constexpr uint32_t size_pipeline = 64 * sizeof(byte4); +}; -template -struct KernelSelector_Update { - using ArchTag = Sm80; +template +struct KernelSelector_Update { using ValueConfig = ValueConfig_Update; using Params = Params_Update; diff --git a/include/merlin/core_kernels/update_score.cuh b/include/merlin/core_kernels/update_score.cuh index 7293bfa36..aaa441533 100644 --- a/include/merlin/core_kernels/update_score.cuh +++ b/include/merlin/core_kernels/update_score.cuh @@ -502,13 +502,8 @@ struct Launch_Pipeline_UpdateScore { } }; -/// TODO: support more arch. -template -struct KernelSelector_UpdateScore; - template -struct KernelSelector_UpdateScore { - using ArchTag = Sm80; +struct KernelSelector_UpdateScore { using Params = Params_UpdateScore; static bool callable(bool unique_key, uint32_t bucket_size) { diff --git a/include/merlin/core_kernels/upsert_and_evict.cuh b/include/merlin/core_kernels/upsert_and_evict.cuh index 44bc20c14..88c523bef 100644 --- a/include/merlin/core_kernels/upsert_and_evict.cuh +++ b/include/merlin/core_kernels/upsert_and_evict.cuh @@ -1234,12 +1234,19 @@ struct ValueConfig_UpsertAndEvict { static constexpr uint32_t size_pipeline = 128 * sizeof(byte4); }; -template -struct KernelSelector_UpsertAndEvict; +template <> +struct ValueConfig_UpsertAndEvict { + // Value size greater than it will bring poor performance for TLPv1. + static constexpr uint32_t size_tlp_v1 = 16 * sizeof(byte4); + // Value size greater than it will bring wrong result for TLPv2. + static constexpr uint32_t size_tlp_v2 = 32 * sizeof(byte4); + // Value size greater than it will reduce the occupancy for Pipeline. + // When the value is very high, the kernel will fail to launch. + static constexpr uint32_t size_pipeline = 64 * sizeof(byte4); +}; -template -struct KernelSelector_UpsertAndEvict { - using ArchTag = Sm80; +template +struct KernelSelector_UpsertAndEvict { using ValueConfig = ValueConfig_UpsertAndEvict; using Params = Params_UpsertAndEvict; diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 718901fb0..44e4b6fa8 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -1072,9 +1072,8 @@ class HashTable { if (((step_counter++) % kernel_select_interval_) == 0) { load_factor = fast_load_factor(0, stream, false); } - using Selector = - KernelSelector_UpdateScore; + using Selector = KernelSelector_UpdateScore; if (Selector::callable(unique_key, static_cast(options_.max_bucket_size))) { typename Selector::Params kernelParams(