NCCL_PARAM per ncclComm #225

Open
pietern opened this issue May 16, 2019 · 2 comments

pietern commented May 16, 2019

Currently the NCCL_PARAM macro stores each parameter's value in a static variable and initializes it only once. This means parameters cannot be changed at runtime (e.g. if you want to run a sweep to find optimal settings). It would be great to be able to change them on the fly. Since creating a new communicator isn't in the fast path, it should be possible to make the value non-static and load it every time a new communicator is created.
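To make the difference concrete, here is a simplified, self-contained sketch (not NCCL's actual code) contrasting a value cached once in a function-local static with a value read each time a communicator is created. `NCCL_EXAMPLE_PARAM`, `readEnvInt`, `cachedParam`, `CommConfigSnapshot`, and `snapshotAtCommInit` are hypothetical names used only for illustration.

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical helper: read an integer environment variable, or return
// deftVal if it is unset, empty, or not a valid integer.
int64_t readEnvInt(const char* name, int64_t deftVal) {
  const char* s = std::getenv(name);
  if (s == nullptr || *s == '\0') return deftVal;
  char* end = nullptr;
  long long v = std::strtoll(s, &end, 0);
  return (*end == '\0') ? static_cast<int64_t>(v) : deftVal;
}

// Current behavior (simplified): the value is read on the first call and
// kept in a function-local static for the lifetime of the process.
int64_t cachedParam() {
  static const int64_t cache = readEnvInt("NCCL_EXAMPLE_PARAM", 4);
  return cache;
}

// Proposed behavior (sketch): each communicator takes its own snapshot of
// the parameters when it is created, so later changes are picked up.
struct CommConfigSnapshot {
  int64_t exampleParam;
};

CommConfigSnapshot snapshotAtCommInit() {
  return CommConfigSnapshot{readEnvInt("NCCL_EXAMPLE_PARAM", 4)};
}
```

Changing the environment variable between calls leaves `cachedParam()` unchanged, while each new `snapshotAtCommInit()` sees the new value. Because a snapshot is taken only at communicator creation, the per-call cost stays off the fast path.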


ljluestc commented Nov 3, 2024


#include <nccl.h>
#include <iostream>
#include <string>
#include <unordered_map>

// Error handling macro for NCCL (must be defined before its first use)
#define NCCLCHECK(func) do {                                              \
    ncclResult_t res = (func);                                            \
    if (res != ncclSuccess) {                                             \
        std::cerr << "NCCL error: " << ncclGetErrorString(res) << std::endl; \
        return res;                                                       \
    }                                                                     \
} while (0)

// Mock function to represent NCCL parameter fetching.
// Replace this with the actual logic to fetch NCCL parameters.
ncclResult_t getNCCLParam(const char* paramName, int* paramValue) {
    // Simulated parameter values; replace with the actual retrieval logic.
    static const std::unordered_map<std::string, int> params = {
        {"NCCL_PARAM1", 42},
        {"NCCL_PARAM2", 100}
    };

    auto it = params.find(paramName);
    if (it != params.end()) {
        *paramValue = it->second;
        return ncclSuccess;
    }
    return ncclInvalidArgument;
}

// Create a communicator, re-reading the parameters every time
ncclResult_t createCommunicator(ncclComm_t* comm) {
    int paramValue;

    // Dynamically fetch the NCCL parameters each time a communicator is created
    NCCLCHECK(getNCCLParam("NCCL_PARAM1", &paramValue));
    std::cout << "Using NCCL_PARAM1: " << paramValue << std::endl;

    NCCLCHECK(getNCCLParam("NCCL_PARAM2", &paramValue));
    std::cout << "Using NCCL_PARAM2: " << paramValue << std::endl;

    // Assumed single-process, single-rank setup for illustration
    // (rank 0 of 1 on the current CUDA device); replace nranks/id/rank
    // with your own values.
    ncclUniqueId id;
    NCCLCHECK(ncclGetUniqueId(&id));
    NCCLCHECK(ncclCommInitRank(comm, /*nranks=*/1, id, /*rank=*/0));

    return ncclSuccess;
}

int main() {
    // Example of creating multiple communicators with dynamic NCCL parameters
    for (int i = 0; i < 3; ++i) {
        ncclComm_t comm;
        std::cout << "Creating communicator " << i + 1 << std::endl;
        NCCLCHECK(createCommunicator(&comm));
        // Destroy each communicator so none of them is leaked
        NCCLCHECK(ncclCommDestroy(comm));
    }

    return 0;
}
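The mock above returns fixed values. A minimal sketch of an environment-backed replacement, assuming the parameters come from environment variables as NCCL's do, could look like the following. `getEnvParam` is a hypothetical name, and it returns `bool` rather than `ncclResult_t` so the sketch compiles without `nccl.h`. Note that this only changes what the application itself reads: NCCL's internal NCCL_PARAM getters still keep their cached values, which is the limitation this issue describes.

```cpp
#include <cerrno>
#include <climits>
#include <cstdlib>

// Hypothetical env-backed replacement for the mock getNCCLParam.
// The variable is read on every call, so a value changed between
// communicator creations is picked up. Returns false if the variable
// is unset, empty, not a valid base-10 integer, or out of int range.
bool getEnvParam(const char* paramName, int* paramValue) {
  const char* s = std::getenv(paramName);
  if (s == nullptr || *s == '\0') return false;
  errno = 0;
  char* end = nullptr;
  long v = std::strtol(s, &end, 10);
  if (errno != 0 || *end != '\0' || v < INT_MIN || v > INT_MAX) return false;
  *paramValue = static_cast<int>(v);
  return true;
}
```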


@pietern, can you just disable the caching in NCCL_PARAM for your testing? If not, perhaps add a global reset counter: whenever it is incremented, every NCCL_PARAM getter knows to discard its cached value and reload it from the environment.

Something like:

extern uint32_t globalReset;

#define NCCL_PARAM(name, env, deftVal) \
  int64_t ncclParam##name() { \
    constexpr int64_t uninitialized = INT64_MIN; \
    static_assert(deftVal != uninitialized, "default value cannot be the uninitialized value."); \
    static int64_t cache = uninitialized; \
    static uint32_t resetCount = 0; \
    if (resetCount != globalReset) { /* resetParams() was called since the last load */ \
      resetCount = globalReset; \
      __atomic_store_n(&cache, uninitialized, __ATOMIC_RELAXED); \
    } \
    if (__builtin_expect(__atomic_load_n(&cache, __ATOMIC_RELAXED) == uninitialized, false)) { \
      ncclLoadParam("NCCL_" env, deftVal, uninitialized, &cache); \
    } \
    return cache; \
  }

// inline, since this would live in a header next to NCCL_PARAM
inline void resetParams() {
    ++globalReset;
}

Then in src/misc/param.cc:

uint32_t globalReset = 0;

Note that ncclLoadParam also checks the cache, so if you simply disable caching in NCCL_PARAM, you'll need to disable it in ncclLoadParam as well.
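The reset-counter idea can be tried outside NCCL with a standalone sketch. It assumes parameters come from environment variables and uses `std::atomic` and `std::mutex` in place of the GCC builtins. The reload is re-checked under a lock, similar to how ncclLoadParam re-checks the cache. `gParamGeneration`, `CachedParam`, and `NCCL_EXAMPLE_PARAM` are hypothetical names, not part of NCCL.

```cpp
#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <mutex>

// Hypothetical global generation counter; bumping it invalidates every
// cached parameter at once.
std::atomic<uint32_t> gParamGeneration{0};

void resetParams() { gParamGeneration.fetch_add(1, std::memory_order_relaxed); }

// One cached parameter. `generation` records which reset epoch the cached
// value belongs to; a mismatch forces a reload from the environment.
struct CachedParam {
  CachedParam(const char* e, int64_t d) : env(e), deftVal(d) {}

  int64_t get() {
    uint32_t gen = gParamGeneration.load(std::memory_order_acquire);
    if (generation.load(std::memory_order_acquire) != gen) {
      std::lock_guard<std::mutex> lock(mu);
      // Re-check under the lock so only one thread performs the reload.
      if (generation.load(std::memory_order_relaxed) != gen) {
        const char* s = std::getenv(env);
        value.store(s ? static_cast<int64_t>(std::strtoll(s, nullptr, 0)) : deftVal,
                    std::memory_order_relaxed);
        generation.store(gen, std::memory_order_release);
      }
    }
    return value.load(std::memory_order_relaxed);
  }

  const char* env;
  int64_t deftVal;
  std::atomic<int64_t> value{0};
  std::atomic<uint32_t> generation{UINT32_MAX};  // never matches before the first load
  std::mutex mu;
};
```

After the first load, the steady-state cost of `get()` is two atomic loads and a comparison, and a single `resetParams()` call makes every parameter reload lazily on its next access.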
