Skip to content

Commit

Permalink
[Opt] Support Sm70 on V100
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Sep 12, 2023
1 parent 1152d4f commit 36d8bd0
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 20 deletions.
2 changes: 1 addition & 1 deletion benchmark/merlin_hashtable_benchmark.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ void test_main(std::vector<API_Select>& apis, const size_t dim,
options.dim = dim;
options.max_hbm_for_vectors = nv::merlin::GB(hbm4values);
options.io_by_cpu = io_by_cpu;
using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kLru>;
using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kLru, Sm80>;

std::shared_ptr<Table> table = std::make_shared<Table>();
table->init(options);
Expand Down
6 changes: 6 additions & 0 deletions include/merlin/core_kernels/lookup.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,12 @@ struct LookupValueBufConfig<Sm80> {
static constexpr uint32_t size_pipeline_v2 = 128 * sizeof(float);
};

/// Value-buffer sizes used by the lookup pipeline kernels when the
/// architecture tag is Sm70. Each constant is a byte count written as a
/// number of `float` elements times sizeof(float).
// NOTE(review): how the v1/v2 sizes are consumed is not visible here —
// presumably as per-kernel value-size thresholds; confirm against the
// kernel selector.
template <>
struct LookupValueBufConfig<Sm70> {
  // 64 floats worth of bytes for pipeline v2.
  static constexpr uint32_t size_pipeline_v2 =
      static_cast<uint32_t>(sizeof(float) * 64);
  // 112 floats worth of bytes for pipeline v1.
  static constexpr uint32_t size_pipeline_v1 =
      static_cast<uint32_t>(sizeof(float) * 112);
};

template <typename K, typename V, typename S = uint64_t,
typename ArchTag = Sm80>
struct SelectPipelineLookupKernelWithIO {
Expand Down
15 changes: 10 additions & 5 deletions include/merlin/core_kernels/update.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,17 @@ struct ValueConfig_Update<Sm80> {
static constexpr uint32_t size_pipeline = 128 * sizeof(byte4);
};

template <typename K, typename V, typename S, int Strategy, typename ArchTag>
struct KernelSelector_Update;
/// Value-size thresholds (in bytes, as multiples of sizeof(byte4)) for the
/// update kernels when the architecture tag is Sm70.
template <>
struct ValueConfig_Update<Sm70> {
  // Upper bound for the Pipeline kernel: larger values lower its occupancy,
  // and a sufficiently large value makes the kernel fail to launch.
  static constexpr uint32_t size_pipeline =
      static_cast<uint32_t>(sizeof(byte4) * 64);
  // Upper bound for the TLP kernel: larger values perform poorly with TLP.
  static constexpr uint32_t size_tlp =
      static_cast<uint32_t>(sizeof(byte4) * 8);
};

template <typename K, typename V, typename S, int Strategy>
struct KernelSelector_Update<K, V, S, Strategy, Sm80> {
using ArchTag = Sm80;
template <typename K, typename V, typename S, int Strategy, typename ArchTag>
struct KernelSelector_Update {
using ValueConfig = ValueConfig_Update<ArchTag>;
using Params = Params_Update<K, V, S>;

Expand Down
7 changes: 1 addition & 6 deletions include/merlin/core_kernels/update_score.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -502,13 +502,8 @@ struct Launch_Pipeline_UpdateScore {
}
};

/// TODO: support more arch.
template <typename K, typename V, typename S, int Strategy, typename ArchTag>
struct KernelSelector_UpdateScore;

template <typename K, typename V, typename S, int Strategy>
struct KernelSelector_UpdateScore<K, V, S, Strategy, Sm80> {
using ArchTag = Sm80;
struct KernelSelector_UpdateScore {
using Params = Params_UpdateScore<K, V, S>;

static bool callable(bool unique_key, uint32_t bucket_size) {
Expand Down
17 changes: 12 additions & 5 deletions include/merlin/core_kernels/upsert_and_evict.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1234,12 +1234,19 @@ struct ValueConfig_UpsertAndEvict<Sm80> {
static constexpr uint32_t size_pipeline = 128 * sizeof(byte4);
};

template <typename K, typename V, typename S, int Strategy, typename ArchTag>
struct KernelSelector_UpsertAndEvict;
/// Value-size thresholds (in bytes, as multiples of sizeof(byte4)) for the
/// upsert-and-evict kernels when the architecture tag is Sm70.
template <>
struct ValueConfig_UpsertAndEvict<Sm70> {
  // Upper bound for the Pipeline kernel: larger values lower its occupancy,
  // and a sufficiently large value makes the kernel fail to launch.
  static constexpr uint32_t size_pipeline =
      static_cast<uint32_t>(sizeof(byte4) * 64);
  // Upper bound for TLPv2: larger values produce wrong results.
  static constexpr uint32_t size_tlp_v2 =
      static_cast<uint32_t>(sizeof(byte4) * 32);
  // Upper bound for TLPv1: larger values perform poorly.
  static constexpr uint32_t size_tlp_v1 =
      static_cast<uint32_t>(sizeof(byte4) * 16);
};

template <typename K, typename V, typename S, int Strategy>
struct KernelSelector_UpsertAndEvict<K, V, S, Strategy, Sm80> {
using ArchTag = Sm80;
template <typename K, typename V, typename S, int Strategy, typename ArchTag>
struct KernelSelector_UpsertAndEvict {
using ValueConfig = ValueConfig_UpsertAndEvict<ArchTag>;
using Params = Params_UpsertAndEvict<K, V, S>;

Expand Down
5 changes: 2 additions & 3 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1072,9 +1072,8 @@ class HashTable {
if (((step_counter++) % kernel_select_interval_) == 0) {
load_factor = fast_load_factor(0, stream, false);
}
using Selector =
KernelSelector_UpdateScore<key_type, value_type, score_type,
evict_strategy, ArchTag>;
using Selector = KernelSelector_UpdateScore<key_type, value_type,
score_type, evict_strategy>;
if (Selector::callable(unique_key,
static_cast<uint32_t>(options_.max_bucket_size))) {
typename Selector::Params kernelParams(
Expand Down

0 comments on commit 36d8bd0

Please sign in to comment.