Skip to content

Commit

Permalink
Add start time information to CollTrace (NVIDIA#46)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#46

Differential Revision: D55168758
  • Loading branch information
YulunW authored and facebook-github-bot committed Mar 21, 2024
1 parent dfd19ca commit 3bb6a27
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 22 deletions.
49 changes: 31 additions & 18 deletions src/colltrace/CollTrace.cc
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include "CollTrace.h"
#include "ExtChecks.h"
#include "FbInternal.h"
#include "bootstrap.h"
#include "comm.h"
#include "nccl.h"
#include "ExtChecks.h"

#include "ExtUtils.h"
#include <unistd.h>
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <fstream>
#include <memory>
#include <mutex>
#include <string>
#include <unistd.h>
#include <chrono>
#include <fstream>
#include <sstream>
#include <string>
#include "ExtUtils.h"

/*
=== BEGIN_NCCL_CVAR_INFO_BLOCK ===
Expand Down Expand Up @@ -124,8 +124,7 @@ CollTrace::~CollTrace() {
comm_->commHash,
comm_->rank,
e.what());
}
catch(...) {
} catch (...) {
WARN(
"COLLTRACE: comm %p commHash %lx rank %d - Destroy FAILED: Unkown exception",
comm_,
Expand Down Expand Up @@ -174,7 +173,8 @@ CollTrace::Dump CollTrace::dump() {
std::lock_guard<std::mutex> lock(workerMutex_);
CollTrace::Dump dump{};

if (curCollState_ == CurrentCollState::IN_PROGRESS) {
if (curCollState_ == CurrentCollState::IN_PROGRESS ||
curCollState_ == CurrentCollState::WAIT_START) {
// copy contents
dump.currentColl =
std::unique_ptr<CollTraceColl>(new CollTraceColl(curEvent_->coll));
Expand Down Expand Up @@ -234,13 +234,20 @@ void* CollTrace::collTraceThreadFnImpl() {
} else if (curEvent_->eventType == CollTraceEvent::EventType::WAKE_UP) {
continue;
}
curCollState_ = CurrentCollState::WAIT_START;
cudaError_t res = cudaEventSynchronize(curEvent_->start.get());
{
std::lock_guard<std::mutex> lock(workerMutex_);
curEvent_->coll.startTs = std::chrono::high_resolution_clock::now();
}
curCollState_ = CurrentCollState::IN_PROGRESS;
cudaError_t res = cudaEventSynchronize(curEvent_->stop.get());
res = cudaEventSynchronize(curEvent_->stop.get());
curCollState_ = CurrentCollState::DONE;
float latency = -1;

if (res == cudaSuccess) {
res = cudaEventElapsedTime(&latency, curEvent_->start.get(), curEvent_->stop.get());
res = cudaEventElapsedTime(
&latency, curEvent_->start.get(), curEvent_->stop.get());
}

{
Expand Down Expand Up @@ -271,8 +278,7 @@ void* CollTrace::collTraceThreadFnImpl() {
// FIXME: we should revisit bootstrapAllGather() here since commAbort
// may be called either on local rank or a remote rank causing socket
// failure
if (comm_->tuner != NULL &&
features & CollTrace::Features::ONLINE_TUNING) {
if (comm_->tuner != NULL && features & CollTrace::Features::ONLINE_TUNING) {
// Online tuning - average latencies across ranks & send to tuner
float* latencies = NULL;
NCCLCHECKIGNORE(
Expand Down Expand Up @@ -356,7 +362,9 @@ bool CollTrace::logCollSample(CollTraceColl& coll) {
intMap["nChannels"] = coll.info.nChannels;
intMap["nThreads"] = coll.info.nThreads;
intMap["latency (microseconds)"] = 1000 * coll.latency;

intMap["startTs"] = std::chrono::duration_cast<std::chrono::microseconds>(
coll.startTs.time_since_epoch())
.count();
ncclFbLogSample("nccl_coll_trace", normalMap, intMap);
return true;
}
Expand All @@ -376,7 +384,8 @@ static std::vector<std::string> collKeys = {
"channelId",
"nChannels",
"nThreads",
"latencyUs"};
"latencyUs",
"startTs"};

std::unordered_map<std::string, std::string> CollTraceColl::retrieveMap(
bool quoted) {
Expand Down Expand Up @@ -406,7 +415,11 @@ std::unordered_map<std::string, std::string> CollTraceColl::retrieveMap(
infoMap["channelId"] = std::to_string(info.channelId);
infoMap["nChannels"] = std::to_string(info.nChannels);
infoMap["nThreads"] = std::to_string(info.nThreads);
infoMap["latencyUs"] = std::to_string(latency < 0 ? -1: latency * 1000);
infoMap["latencyUs"] = std::to_string(latency < 0 ? -1 : latency * 1000);
infoMap["startTs"] =
std::to_string(std::chrono::duration_cast<std::chrono::microseconds>(
startTs.time_since_epoch())
.count());
return infoMap;
}

Expand Down Expand Up @@ -442,7 +455,7 @@ ncclResult_t collTraceDestroy(ncclComm* comm) {
comm->collTrace.reset();
}
// Try catch clause here is not going to be useful as destructors are noexcept
// by default. Instead of throwing an exception it will just crash the program.
// We need to think about a better way to handle this.
// by default. Instead of throwing an exception it will just crash the
// program. We need to think about a better way to handle this.
return ncclSuccess;
}
9 changes: 6 additions & 3 deletions src/colltrace/CollTrace.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ struct CollTraceColl {
ncclInfo info;
int64_t iteration;
cudaStream_t stream;
float latency {-1};
float latency{-1};
// This is achieved by waiting for the start event. We can only guarantee
// before this time point kernel has already started, but we cannot guarantee
// kernel started exactly at this time point.
std::chrono::time_point<std::chrono::high_resolution_clock> startTs{};

// serialize the entry to a json format string
std::string serialize(bool quoted = false);
Expand Down Expand Up @@ -76,13 +80,12 @@ class CollTrace {

enum class CurrentCollState {
PENDING,
WAIT_START,
IN_PROGRESS,
DONE,
};

struct Dump {
// Fixme: use a dedicated class to keep the information of collectives
// instead of reusing CollTraceColl and CollTraceEvent
std::deque<CollTraceColl> pastColls;
std::deque<CollTraceColl> pendingColls;
std::unique_ptr<CollTraceColl> currentColl;
Expand Down
55 changes: 54 additions & 1 deletion src/colltrace/tests/CollTraceDistTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#include <fstream>
#include <iostream>
#include <sstream>
#include <string_view>
#include <rfe/scubadata/ScubaData.h>
#include "Ctran.h"
#include "checks.h"
#include "gmock/gmock.h"
#include "json/json.h"
#include "nccl_cvars.h"
#include "tests_common.cuh"
Expand Down Expand Up @@ -305,9 +307,60 @@ TEST_F(CollTraceTest, DumpWithUnfinished) {
auto dump = comm->collTrace->dump();

EXPECT_TRUE(dump.pastColls.size() >= nColl);
// +1 for the extra wakeup event that might be created by dump() function
EXPECT_TRUE(dump.pendingColls.size() <= nColl);
int hasPending = dump.currentColl == nullptr ? 0 : 1;
EXPECT_TRUE(dump.pendingColls.size() + dump.pendingColls.size() == nColl * 2 - hasPending);
for (auto& coll: dump.pastColls) {
EXPECT_GT(coll.startTs.time_since_epoch().count(), 0);
}
for (auto& coll: dump.pendingColls) {
EXPECT_EQ(coll.startTs.time_since_epoch().count(), 0);
}
NCCLCHECK_TEST(ncclCommDestroy(comm));

NCCL_COLLTRACE.clear();
}

TEST_F(CollTraceTest, TestSerializedDump) {
// overwrite CollTrace features before creating comm
NCCL_COLLTRACE.push_back("trace");
ncclComm_t comm =
createNcclComm(this->globalRank, this->numRanks, this->localRank);
const int count = 1048576;
const int nColl = 10;

prepareAllreduce(count);
for (int i = 0; i < nColl; i++) {
NCCLCHECK_TEST(
ncclAllReduce(sendBuf, recvBuf, count, ncclInt, ncclSum, comm, stream));
}

EXPECT_TRUE(comm->collTrace != nullptr);
comm->collTrace->waitForWorkerFinishQueue();

// schedule more after the first 10 coll are finished
for (int i = 0; i < nColl; i++) {
NCCLCHECK_TEST(
ncclAllReduce(sendBuf, recvBuf, count, ncclInt, ncclSum, comm, stream));
}

auto dump = comm->collTrace->dump();

EXPECT_TRUE(dump.pastColls.size() >= nColl);
EXPECT_TRUE(dump.pendingColls.size() <= nColl);
int hasPending = dump.currentColl == nullptr ? 0 : 1;
EXPECT_TRUE(dump.pendingColls.size() + dump.pendingColls.size() == nColl * 2 - hasPending);
constexpr std::string_view startTsStr = "\"startTs\": ";
for (auto& coll: dump.pastColls) {
auto serialized = coll.serialize(true);
EXPECT_THAT(serialized, testing::HasSubstr(startTsStr));
EXPECT_GT(coll.startTs.time_since_epoch().count(), 0);
}
for (auto& coll: dump.pendingColls) {
auto serialized = coll.serialize(true);
EXPECT_THAT(serialized, testing::HasSubstr(startTsStr));
EXPECT_EQ(coll.startTs.time_since_epoch().count(), 0);
}
NCCLCHECK_TEST(ncclCommDestroy(comm));

NCCL_COLLTRACE.clear();
Expand Down

0 comments on commit 3bb6a27

Please sign in to comment.