
optimize trace hang && fix event leak #58707

Merged: 34 commits into PaddlePaddle:incubate/new_frl on Nov 18, 2023

Conversation

@hitywt commented Nov 6, 2023

PR types: Others

PR changes: Others

Description:

  1. Fix a CUDA event leak (see the sketch below).
  2. Add NCCL comm init logs for use by the trace-hang feature.
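For context, a CUDA event leak usually means a cudaEventCreate without a matching cudaEventDestroy on every code path. A minimal sketch of an RAII wrapper that rules such a leak out (the CudaEventHolder name is illustrative, not this PR's actual code):

#include <cuda_runtime.h>

// Illustrative RAII wrapper: the event is destroyed exactly once, even on
// early returns, so repeated creations cannot leak.
class CudaEventHolder {
 public:
  CudaEventHolder() { cudaEventCreate(&event_); }
  ~CudaEventHolder() {
    if (event_ != nullptr) {
      cudaEventDestroy(event_);
    }
  }
  // Non-copyable: two holders must never destroy the same event.
  CudaEventHolder(const CudaEventHolder&) = delete;
  CudaEventHolder& operator=(const CudaEventHolder&) = delete;
  cudaEvent_t get() const { return event_; }

 private:
  cudaEvent_t event_{nullptr};
};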

@gongweibao (Contributor) left a comment:

Some comments.

// Convert a vector to a string: compress each continuous interval with `:`
// and join discontinuous intervals with `#`,
// e.g. [1,2,3,4,5,7,8,9] => 1:5#7:9
inline std::string VectorToString(const std::vector<int>& vec) {
Contributor:

Add a unit test for it.
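For illustration, a sketch of an implementation that matches the comment above, plus the kind of unit test being requested (only the VectorToString name comes from the diff; the rest is an assumption):

#include <string>
#include <vector>

#include "gtest/gtest.h"

inline std::string VectorToString(const std::vector<int>& vec) {
  std::string result;
  size_t i = 0;
  while (i < vec.size()) {
    // Extend j while the values stay consecutive.
    size_t j = i;
    while (j + 1 < vec.size() && vec[j + 1] == vec[j] + 1) {
      ++j;
    }
    if (!result.empty()) result += "#";
    result += std::to_string(vec[i]);
    if (j > i) result += ":" + std::to_string(vec[j]);
    i = j + 1;
  }
  return result;
}

TEST(VectorToStringTest, CompressesConsecutiveRanges) {
  EXPECT_EQ("1:5#7:9", VectorToString({1, 2, 3, 4, 5, 7, 8, 9}));
  EXPECT_EQ("4", VectorToString({4}));
  EXPECT_EQ("", VectorToString({}));
}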

",seq:" + std::to_string(seq_) +
",started:" + std::to_string(IsStarted()) +
",completed:" + std::to_string(IsCompleted()) +
auto global_ranks =
Contributor:

Add a unit test for it, and test the limit on the message length.
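A sketch of the requested length check, built from the status fields visible in the diff above (the field layout and the 4096-byte cap are assumptions for illustration):

#include <numeric>
#include <string>
#include <vector>

#include "gtest/gtest.h"

// Even for a large group, the compressed rank string should keep the
// status message within a bounded length.
TEST(TaskStatusMsgTest, StaysWithinLengthLimit) {
  std::vector<int> ranks(1024);
  std::iota(ranks.begin(), ranks.end(), 0);  // 0,1,...,1023 -> "0:1023"
  std::string msg = ",seq:" + std::to_string(123) +
                    ",started:" + std::to_string(true) +
                    ",completed:" + std::to_string(false) +
                    ",ranks:" + VectorToString(ranks);
  EXPECT_LT(msg.size(), 4096u);
}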

@@ -484,6 +484,7 @@ class ProcessGroup {
}

protected:
int global_rank_;
Contributor:

int global_rank_{-1};

Author:

int global_rank_{-1};

fixed
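For clarity on the suggestion: a plain int member without an initializer is left indeterminate when the object is default-constructed, so the brace-initialized sentinel makes a missed rank assignment detectable. A minimal sketch:

class ProcessGroup {
 protected:
  // -1 is a sentinel meaning "rank not yet assigned"; without the
  // initializer the member would hold an indeterminate value.
  int global_rank_{-1};
};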

@@ -860,6 +862,44 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
comm_ctx->set_nccl_comm(nccl_comm);

// gather global ranks in current group
int* gpu_global_rank = nullptr;
Contributor:

  1. Use a tensor instead of raw device memory, since the CUDA memory APIs are not efficient?
  2. If raw memory is used, check the result of each CUDA API call.

Author:

  1. Use a tensor instead of raw device memory, since the CUDA memory APIs are not efficient?
  2. If raw memory is used, check the result of each CUDA API call.

fixed
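For reference, a sketch of the kind of checked cudaMalloc/cudaMemcpy/allgather sequence the review asks for (the function, buffer names, and CHECK_CUDA macro are illustrative; Paddle's real code uses its own error-handling macros such as PADDLE_ENFORCE):

#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Illustrative check: real framework code would raise a typed error instead.
#define CHECK_CUDA(expr)                                  \
  do {                                                    \
    cudaError_t err__ = (expr);                           \
    if (err__ != cudaSuccess) {                           \
      std::fprintf(stderr, "CUDA error: %s\n",            \
                   cudaGetErrorString(err__));            \
      std::abort();                                       \
    }                                                     \
  } while (0)

// Gather every participant's global rank onto each process in the group.
void GatherGlobalRanks(int global_rank, int group_size, ncclComm_t comm,
                       cudaStream_t stream, std::vector<int>* out) {
  int* send = nullptr;
  int* recv = nullptr;
  CHECK_CUDA(cudaMalloc(&send, sizeof(int)));
  CHECK_CUDA(cudaMalloc(&recv, group_size * sizeof(int)));
  CHECK_CUDA(cudaMemcpy(send, &global_rank, sizeof(int),
                        cudaMemcpyHostToDevice));
  ncclAllGather(send, recv, 1, ncclInt, comm, stream);
  CHECK_CUDA(cudaStreamSynchronize(stream));
  out->resize(group_size);
  CHECK_CUDA(cudaMemcpy(out->data(), recv, group_size * sizeof(int),
                        cudaMemcpyDeviceToHost));
  CHECK_CUDA(cudaFree(send));
  CHECK_CUDA(cudaFree(recv));
}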

@hitywt changed the title from "fix event leak && add ncc comm group init info" to "optimize trace hang && fix event leak" on Nov 17, 2023
@gongweibao (Contributor) left a comment:

LGTM

@gongweibao gongweibao merged commit bcf9676 into PaddlePaddle:incubate/new_frl Nov 18, 2023
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Nov 21, 2023
hitywt pushed a commit to hitywt/Paddle that referenced this pull request Nov 24, 2023
ForFishes added a commit that referenced this pull request Nov 28, 2023
* add comm async trace module, (#56916)

* Fix trace hang (#57536)

* fix trace hang

* fix compile error

* fix code style

* tinyfix

* tiny update

* fix code style

---------

Co-authored-by: ForFishes <[email protected]>

* Fix nccl trace (#58338)

* fix nccl_async_trace destruct problem when train finished

* update

* format code style

* optimize trace hang && fix event leak (#58707)

* update

* fix compile problems

* fix code style

* fix logging

* fix code style

* remove useless

* add ut && tinyfix

* opt cudaMalloc and cudaMemcpy

update

* tinyfix

---------

Co-authored-by: ForFishes <[email protected]>