
[Feat] Support three new evict strategies (lfu, epoch_lfu, epoch_lru) #152

Merged · 4 commits · Sep 4, 2023

Conversation

@rhdong (Member) commented Jul 5, 2023:

No description provided.

@rhdong rhdong requested review from evanzhen and Lifann July 5, 2023 04:05
@github-actions bot commented Jul 5, 2023

@rhdong rhdong force-pushed the rhdong/epoch-lru branch 2 times, most recently from 178ec1d to 931e10f Compare July 5, 2023 04:15
@rhdong rhdong force-pushed the rhdong/epoch-lru branch from 931e10f to 2937056 Compare July 18, 2023 11:18
@rhdong rhdong requested a review from jiashuy July 18, 2023 11:18
@rhdong rhdong force-pushed the rhdong/epoch-lru branch from 2937056 to 837d1f5 Compare July 18, 2023 11:29
@rhdong (Member Author) commented Jul 18, 2023:

/blossom-ci

@@ -431,6 +453,8 @@ class HashTable {
      score_type* evicted_scores,    // (n)
      size_type* d_evicted_counter,  // (1)
      cudaStream_t stream = 0,
      const score_type global_epoch =
@rhdong (Member Author), Jul 18, 2023:

Hi @Lifann, please help review the change to the API. This PR implements two new strategies that need global_epoch as the high 32 bits of the score; for more detail, please refer to https://github.com/rhdong/HierarchicalKV/tree/rhdong/epoch-lru#evict-strategy

Collaborator:

A fixed data-format parameter for a limited use case will make the API hard to extend in the future.

@rhdong (Member Author), Aug 25, 2023:

I agree with you, but there seems to be no better choice. Do you have a recommendation?
Basically, I want the caller to have full control over the epoch, though I agree a built-in epoch counter could give a clearer API definition. It's a hard trade-off...

@rhdong (Member Author), Aug 26, 2023:

Here is a possible choice:

  • add a std::atomic<S> global_epoch; member to the HashTable;
  • add two member functions to operate on it, and call them before calling the main API:

void global_epoch(const S epoch);
S global_epoch() const;

Collaborator:
/*
 * A prototype to show how to get a param inside a functor implicitly.
 */

#include "cuda_runtime.h"
#include <stdio.h>
#include <vector>
using namespace std;

namespace functor {

__device__ int param_;

void set_param(int score) {
  int ret = cudaMemcpyToSymbol(param_, &score, sizeof(int), 0, cudaMemcpyHostToDevice);
  printf("symbol op1 ret=%d\n", ret);
}

struct Functor {
 public:
  __device__ inline int internal_get_param() {
    return ::functor::param_;
  }
  __device__ void operator()(int& key, int& score) {
    score = internal_get_param();
  }
};

}  // namespace functor

template <typename ScorePred>
__global__ void gpu_any_kernel(int* keys, int* scores, size_t n) {
  size_t tid = blockDim.x * blockIdx.x + threadIdx.x;
  ScorePred pred;
  if (tid < n) {
    pred(keys[tid], scores[tid]);
  }
}

void get_data(vector<int>& data, int** d_data) {
  cudaMalloc(d_data, data.size() * sizeof(int));
  cudaMemcpy(*d_data, data.data(), data.size() * sizeof(int), cudaMemcpyHostToDevice);
}

int main() {
  cudaSetDevice(0);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  vector<int> keys = {1, 2, 3, 4, 5};
  vector<int> scores = {0, 0, 0, 0, 0};
  int* d_keys = nullptr;
  int* d_scores = nullptr;
  get_data(keys, &d_keys);
  get_data(scores, &d_scores);
  functor::set_param(5);
  cudaDeviceSynchronize();

  // Check whether set_param worked.
  gpu_any_kernel<functor::Functor><<<5, 1, 0, stream>>>(d_keys, d_scores, keys.size());
  int ret = cudaMemcpyAsync(scores.data(), d_scores, sizeof(int) * scores.size(),
                            cudaMemcpyDeviceToHost, stream);
  printf("copy d2h ret=%d\n", ret);
  cudaStreamSynchronize(stream);
  for (int i = 0; i < 5; i++) {
    printf("i=%d, score=%d\n", i, scores[i]);  // It will get 5 at every position.
  }
  cudaFree(d_keys);
  cudaFree(d_scores);
  cudaStreamDestroy(stream);
  return 0;
}

@Lifann (Collaborator), Aug 26, 2023:

(The same prototype code as above.)

In this case, it's possible to apply any parameter to the functor implicitly; the functor can then traverse the key-value-score pairs of the inputs in a device or global function. This makes it possible to abstract over different scoring strategies.

@rhdong (Member Author):

I merged our thoughts here; please help review again, thanks: afd5813

@rhdong (Member Author) commented Jul 18, 2023:

/blossom-ci

@rhdong rhdong force-pushed the rhdong/epoch-lru branch from 837d1f5 to bd985eb Compare July 19, 2023 03:20
@rhdong (Member Author) commented Jul 19, 2023:

/blossom-ci

1 similar comment
@EmmaQiaoCh (Collaborator):

/blossom-ci

@rhdong rhdong force-pushed the rhdong/epoch-lru branch from bd985eb to 2c24809 Compare July 25, 2023 05:40
@rhdong (Member Author) commented Jul 25, 2023:

/blossom-ci

@rhdong rhdong force-pushed the rhdong/epoch-lru branch from 2c24809 to 7719b3a Compare July 25, 2023 06:12
@rhdong (Member Author) commented Jul 25, 2023:

/blossom-ci

@rhdong (Member Author) commented Aug 24, 2023:

/blossom-ci

@rhdong rhdong requested a review from LinGeLin August 24, 2023 07:57
@rhdong (Member Author) commented Aug 24, 2023:

/blossom-ci

enum class EvictStrategy {
  kLru = 0,        ///< LRU mode.
  kCustomized = 1  ///< Customized mode.
struct EvictStrategy {
@Lifann (Collaborator), Aug 25, 2023:

Since EvictStrategy is a template parameter for HashTable, is it really necessary to use it as a constexpr static int instead of EvictStrategy::kEpochLru? If it is, maybe use a macro to make the template parameter meaningful.

@rhdong (Member Author):

Using the struct mainly serves to expose the API conveniently, so end-users can easily refer to it via EvictStrategy; a macro has no name scope and can pollute the user's namespace.

@rhdong (Member Author):

I have re-designed it, please refer to the latest code.

@@ -294,6 +308,9 @@ class HashTable {
*
* @param stream The CUDA stream that is used to execute the operation.
*
* @param global_epoch The global epoch for EpochLRU, EpochLFU, when it's set
* to `DEFAULT_GLOBAL_EPOCH`, insert score in @p scores directly.
Collaborator:

Does it mean "insert the score in @p scores"?

@rhdong (Member Author):

Yes

@rhdong (Member Author) commented Aug 25, 2023:

/blossom-ci

@rhdong rhdong requested review from Lifann and evanzhen August 25, 2023 10:47
@rhdong (Member Author) commented Aug 31, 2023:

/blossom-ci

@rhdong (Member Author) commented Sep 1, 2023:

/blossom-ci

@rhdong (Member Author) commented Sep 1, 2023:

/blossom-ci

@rhdong (Member Author) commented Sep 1, 2023:

/blossom-ci

@rhdong (Member Author) commented Sep 2, 2023:

Benchmark Comparison:

On pure HBM mode, thread_local param with EpochLru:

  • dim = 32, capacity = 128 Million-KV, HBM = 16 GB, HMEM = 0 GB

  λ     insert_or_assign  find   find_or_insert  assign  find*  find_or_insert*  insert_and_evict
  0.50  1.056             2.223  1.249           1.639   3.742  1.744            0.923
  0.75  0.859             2.202  0.642           1.634   1.851  1.295            0.875
  1.00  0.468             2.217  0.472           1.658   0.949  0.528            0.681

On pure HBM mode, No thread_local param with EpochLru:

  • dim = 32, capacity = 128 Million-KV, HBM = 16 GB, HMEM = 0 GB

  λ     insert_or_assign  find   find_or_insert  assign  find*  find_or_insert*  insert_and_evict
  0.50  1.060             2.254  1.251           1.658   3.850  1.748            0.924
  0.75  0.859             2.216  0.645           1.645   1.860  1.300            0.884
  1.00  0.467             2.269  0.472           1.660   0.952  0.529            0.684

The difference between the two settings should be less than 0.5%.

@rhdong (Member Author) commented Sep 2, 2023:

/blossom-ci

@rhdong (Member Author) commented Sep 4, 2023:

/blossom-ci

@Lifann (Collaborator) left a comment:

LGTM

@rhdong rhdong merged commit 65028ec into NVIDIA-Merlin:master Sep 4, 2023
          key_pos, key, score);
      ScoreFunctor::update_with_digest(
          bucket_keys_ptr, key_pos, scores, kv_idx, score, bucket_capacity,
          get_digest<K>(key), (occupy_result != OccupyResult::DUPLICATE));
Collaborator:

Actually, occupy_result is not updated here; its value is INITIAL.
Anyway, here, occupy_result != OccupyResult::DUPLICATE is always true.

@rhdong (Member Author), Sep 22, 2023:

> Actually, occupy_result is not updated here; its value is INITIAL. Anyway, here, occupy_result != OccupyResult::DUPLICATE is always true.

@jiashuy, thanks for pointing this out. I believe that if execution reaches here, the key is always missing, so this flag could be set to true directly. Anyway, the current logic looks correct.
