Question about compute and communication overlap #205

Open · tashanzhishi opened this issue Apr 14, 2019 · 2 comments

tashanzhishi commented Apr 14, 2019

Hi, I use two streams (one compute stream and one communication stream) per GPU in a single process that manages 8 GPUs with 8 threads. The comm stream depends on the comp stream, and the comp stream depends on the comm stream (via cudaStreamWaitEvent). The comp stream runs multiple matrix multiplications with cublasSgemm, and the comm stream runs one allreduce per iteration. By varying the number of cublasSgemm calls and whether cudaStreamSynchronize is called after the allreduce, I found some problems.
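
In condensed form, the per-iteration dependency pattern of the two-stream case looks like this (extracted from the full test code at the end of this post; error checking and event creation omitted):

// One iteration on one GPU: compute on blas_stream, communication on nccl_stream.
if (i > 0) {
  // comp stream waits for the previous iteration's allreduce
  cudaStreamWaitEvent(blas_stream, nccl_events.back(), 0);
}
for (size_t j = 0; j < COMP_ITER; ++j) {
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha,
              d_A[dev], N, d_B[dev], N, &beta, d_C[dev], N);
  cudaEventRecord(event, blas_stream);
  // comm stream waits for every GEMM
  cudaStreamWaitEvent(nccl_stream, event, 0);
}
ncclAllReduce(d_C[dev], d_C[dev], n2, ncclFloat, ncclSum,
              *(comms.get() + dev), nccl_stream);
cudaEventRecord(nccl_events.back(), nccl_stream);
if (nccl_mode == SYNC) {
  cudaStreamSynchronize(nccl_stream);  // "sync" mode; "async" skips this call
}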

As shown in table 1, "comp number" is the number of cublasSgemm calls per iteration on the comp stream, "async" means cudaStreamSynchronize is not called after the allreduce, "sync" means it is, "comm" is the average CUDA kernel time of ncclAllReduce (measured with nvprof), and "total" is the total host time for comp plus comm (measured with std::chrono).
We can see that as the comp number grows, the sync comm time stays roughly constant while the async comm time grows, and the sync total time is lower than the async total time.

table1

comp number   async comm (ms)   sync comm (ms)   async total (ms)   sync total (ms)
0             4.7               4.5               98                 100
1             5.5               4.6               143                122
10            6.4               4.9               435                395
100           10                4.9               3060               2950

I also introduce a one-stream mode, in which compute and communication run on the same stream. "Sleep 2 ms" means the host does some CPU work (a 2 ms sleep) between iterations.
From table 2 and table 3 we can see that the one-stream mode is better than two streams, sync is better than async, and the sync comm kernel time is the best.

table2: comp x 50

            comm (ms)   total (ms)
async       9.4         1633
sync        4.9         1535
one stream  6           1553

table3: comp x 50, sleep 2 ms

            comm (ms)   total (ms)
async       9.3         1635
sync        5.0         1575
one stream  6.2         1560

I have two questions:

  1. Why does cudaStreamSynchronize reduce both the ncclAllReduce kernel time and the total time?
  2. Is one stream better than two streams when overlapping compute and communication?

My machine environment:
CPU: 2x Intel Xeon E5-2682 v4
Memory: 256 GB
GPU: 8x V100
CUDA: 9.0
NVIDIA driver: 390.30
Linux kernel: 3.10.0-327.36.3.el7.x86_64
OS: CentOS Linux release 7.2.1511
NCCL: 2.3.7-1
The GPU topology is:

GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 mlx4_0 CPU Affinity
GPU0 X PIX PHB PHB SYS SYS SYS SYS PHB 0-63
GPU1 PIX X PHB PHB SYS SYS SYS SYS PHB 0-63
GPU2 PHB PHB X PIX SYS SYS SYS SYS PHB 0-63
GPU3 PHB PHB PIX X SYS SYS SYS SYS PHB 0-63
GPU4 SYS SYS SYS SYS X PIX PHB PHB SYS 0-63
GPU5 SYS SYS SYS SYS PIX X PHB PHB SYS 0-63
GPU6 SYS SYS SYS SYS PHB PHB X PIX SYS 0-63
GPU7 SYS SYS SYS SYS PHB PHB PIX X SYS 0-63
mlx4_0 PHB PHB PHB PHB SYS SYS SYS SYS X

The test code is:

/* Includes, system */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <chrono>      // std::chrono used in timestamp()
#include <functional>  // std::bind used in main()
#include <memory>
#include <vector>
#include <thread>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <nccl.h>

/* Matrix size */
#define N (256*8)
#define GPUS (8)
#define COMP_ITER (50)
#define ALL_ITER (20)

//float *h_C_ref;
float *d_A[GPUS];
float *d_B[GPUS];
float *d_C[GPUS];
float alpha = 1.0f;
float beta = 0.0f;
int n2 = N * N;
int i;

enum NCCL_MODE {
  ASYNC = 0,
  SYNC = 1,
  ONE_STREAM = 2,
};


std::unique_ptr<ncclComm_t[]> comms = nullptr;
std::unique_ptr<cudaStream_t[]> nccl_streams = nullptr;
std::unique_ptr<cudaStream_t[]> blas_streams = nullptr;
size_t timestamp() {
  using namespace std::chrono;
  return duration_cast<microseconds>(
    high_resolution_clock::now().time_since_epoch()).count();
}

/* Create one NCCL communicator, one communication stream and one cuBLAS
   stream per GPU. */
void init_nccl() {
  comms.reset(new ncclComm_t[GPUS]);
  nccl_streams.reset(new cudaStream_t[GPUS]);
  blas_streams.reset(new cudaStream_t[GPUS]);
  ncclUniqueId nccl_id;
  ncclGetUniqueId(&nccl_id);
  ncclGroupStart();
  for (size_t i = 0; i < GPUS; ++i) {
    cudaSetDevice(i);
    cudaStreamCreate(nccl_streams.get()+i);
    ncclCommInitRank(comms.get() + i, GPUS, nccl_id, i);
    cudaStreamCreate(blas_streams.get()+i);
  }
  ncclGroupEnd();
}

/* Allocate device matrices on GPU `dev` and fill them with test data. */
int init_data(int dev) {
  float *ha;
  float *hb;
  float *hc;
  //float *h_C_ref;
  d_A[dev] = 0;
  d_B[dev] = 0;
  d_C[dev] = 0;
  //float *da = *d_A[dev] = 0;
  //float *db = *d_B[dev] = 0;
  //float *dc = *d_C[dev] = 0;
  cudaSetDevice(dev);
  /* Allocate host memory for the matrices */
  ha = reinterpret_cast<float *>(malloc(n2 * sizeof(ha[0])));
  hb = reinterpret_cast<float *>(malloc(n2 * sizeof(hb[0])));
  hc = reinterpret_cast<float *>(malloc(n2 * sizeof(hc[0])));

  /* Fill the matrices with test data */
  float e = rand() / static_cast<float>(RAND_MAX);
  for (i = 0; i < n2; i++) {
    ha[i] = hb[i] = hc[i] = e;
  }

  /* Allocate device memory for the matrices */
  if (cudaMalloc(reinterpret_cast<void **>(&d_A[dev]), n2 * sizeof(d_A[dev][0])) !=
      cudaSuccess) {
    fprintf(stderr, "!!!! device memory allocation error (allocate A)\n");
    return EXIT_FAILURE;
  }

  if (cudaMalloc(reinterpret_cast<void **>(&d_B[dev]), n2 * sizeof(d_B[dev][0])) !=
      cudaSuccess) {
    fprintf(stderr, "!!!! device memory allocation error (allocate B)\n");
    return EXIT_FAILURE;
  }

  if (cudaMalloc(reinterpret_cast<void **>(&d_C[dev]), n2 * sizeof(d_C[dev][0])) !=
      cudaSuccess) {
    fprintf(stderr, "!!!! device memory allocation error (allocate C)\n");
    return EXIT_FAILURE;
  }


  /* Initialize the device matrices with the host matrices */
  cublasSetVector(n2, sizeof(ha[0]), ha, 1, d_A[dev], 1);
  cublasSetVector(n2, sizeof(hb[0]), hb, 1, d_B[dev], 1);
  cublasSetVector(n2, sizeof(hc[0]), hc, 1, d_C[dev], 1);
  return 0;
}

int destroy_data(int dev) {
  //float *h_C_ref;
  float *da = d_A[dev];
  float *db = d_B[dev];
  float *dc = d_C[dev] ;
  /* Memory clean up */

  if (cudaFree(da) != cudaSuccess) {
    fprintf(stderr, "!!!! memory free error (A)\n");
    return EXIT_FAILURE;
  }

  if (cudaFree(db) != cudaSuccess) {
    fprintf(stderr, "!!!! memory free error (B)\n");
    return EXIT_FAILURE;
  }

  if (cudaFree(dc) != cudaSuccess) {
    fprintf(stderr, "!!!! memory free error (C)\n");
    return EXIT_FAILURE;
  }
  return 0;
}

/* Per-GPU worker: runs ALL_ITER iterations, each with COMP_ITER GEMMs and one
   allreduce, in the mode selected by nccl_mode. */
int worker(int dev, int nccl_mode) {
  cublasStatus_t status;

  cublasHandle_t handle;
  auto &blas_stream = *(blas_streams.get() + dev);
  cudaSetDevice(dev);

  status = cublasCreate(&handle);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! CUBLAS initialization error\n");
    return EXIT_FAILURE;
  }

  cublasSetStream(handle, blas_stream);

  /* Performs operation using cublas */
  auto &nccl_stream = *(nccl_streams.get() + dev);
  std::vector<cudaEvent_t> nccl_events;
  nccl_events.reserve(ALL_ITER);
  size_t start = timestamp();
  if (nccl_mode == NCCL_MODE::ONE_STREAM) {
    for (size_t i = 0; i < ALL_ITER; ++i) {
      // Compute and communication share blas_stream in this mode.
      for (size_t j = 0; j < COMP_ITER; ++j) {
        status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A[dev],
                             N, d_B[dev], N, &beta, d_C[dev], N);
      }
      ncclAllReduce(d_C[dev], d_C[dev], n2, ncclFloat, ncclSum, *(comms.get() + dev), blas_stream);
      usleep(2000);
    }
    cudaStreamSynchronize(blas_stream);
  } else {
    // nccl_mode is ASYNC_NCCL or SYNC_NCCL
    for (size_t i = 0; i < ALL_ITER; ++i) {
      if (i > 0) {
        cudaStreamWaitEvent(blas_stream, nccl_events.back(), 0);
      }

      // Record an event after every GEMM and make the NCCL stream wait on it,
      // so the allreduce starts only after all GEMMs of this iteration.
      // (Events are created per GEMM and never destroyed in this test.)
      for (size_t j = 0; j < COMP_ITER; ++j) {
        cudaEvent_t event;
        cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
        status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A[dev],
                             N, d_B[dev], N, &beta, d_C[dev], N);
        cudaEventRecord(event, blas_stream);
        cudaStreamWaitEvent(nccl_stream, event, 0);
      }

      // Allreduce on the NCCL stream; record an event so that the next
      // iteration's GEMMs can wait on it.
      ncclAllReduce(d_C[dev], d_C[dev], n2, ncclFloat, ncclSum, *(comms.get() + dev), nccl_stream);
      nccl_events.emplace_back();
      cudaEventCreateWithFlags(&nccl_events.back(), cudaEventDisableTiming);
      cudaEventRecord(nccl_events.back(), nccl_stream);
      if (nccl_mode == SYNC) {
        // SYNC mode: block the host until the allreduce has finished.
        cudaStreamSynchronize(nccl_stream);
      }
      usleep(2000);  // emulate ~2 ms of CPU work
    }
    cudaStreamSynchronize(nccl_stream);
  }
  fprintf(stderr, "device: [%d], %d iterations spent: [%d ms]\n", dev, ALL_ITER, ((timestamp()-start)/1000));

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! kernel execution error.\n");
    return EXIT_FAILURE;
  }

  /* Shutdown */
  status = cublasDestroy(handle);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! shutdown error (A)\n");
    return EXIT_FAILURE;
  }
  return 0;
}

int main(int argc, char** argv) {
  init_nccl();
  for (int i = 0; i < GPUS; ++i) {
    init_data(i);
  }
  std::vector<std::thread> threads;
  size_t start = timestamp();

  if (argc < 2) {
    fprintf(stderr, "usage: %s <mode: 0=async, 1=sync, 2=one stream>\n", argv[0]);
    return EXIT_FAILURE;
  }
  int nccl_mode = atoi(argv[1]);
  if (nccl_mode == NCCL_MODE::ONE_STREAM)
      fprintf(stderr, "mode: ONE STREAM\n");
  else if (nccl_mode == NCCL_MODE::SYNC)
      fprintf(stderr, "mode: SYNC\n");
  else if (nccl_mode == NCCL_MODE::ASYNC)
      fprintf(stderr, "mode: ASYNC\n");
  else
      fprintf(stderr, "unknown mode: %d\n", nccl_mode);

  for (int i = 0; i < GPUS; ++i) {
    std::thread t(std::bind(&worker, i, nccl_mode));
    threads.push_back(std::move(t));
  }
  for (auto &t : threads) {
    t.join();
  }
  //fprintf(stderr, "nccl mode: [%d], spent: [%d ms]\n", nccl_mode, (timestamp() - start)/1000);

  for (int i = 0; i < GPUS; ++i) {
    destroy_data(i);
  }
  return 0;
}
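
For reference, a possible way to build the test above (assuming the file is saved as overlap_test.cc and CUDA, cuBLAS and NCCL live in their default locations; the file name and paths are placeholders, not from the original post):

g++ -std=c++11 -O2 overlap_test.cc -o overlap_test \
    -I/usr/local/cuda/include -L/usr/local/cuda/lib64 \
    -lcudart -lcublas -lnccl -lpthread

The mode is then selected on the command line, e.g. ./overlap_test 0 for async, ./overlap_test 1 for sync, ./overlap_test 2 for one stream.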
kwen2501 (Contributor) commented:

In the two-stream case, I am not sure the computation and communication can really overlap, because the code asks the NCCL stream to wait for all cublasSgemm kernels to finish. That ends up being similar to the one-stream case, with additional event synchronization overhead.
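
To illustrate, here is a rough double-buffering sketch of a pattern that would allow actual overlap (this is not the code above; d_out, gemm_done, ar_done and comm are hypothetical names, and the events are assumed to be pre-created):

for (size_t i = 0; i < ALL_ITER; ++i) {
  float *cur  = d_out[i % 2];        // written by this iteration's GEMMs
  float *prev = d_out[(i + 1) % 2];  // written last iteration, reduced now

  if (i > 1) {
    // `cur` was all-reduced two iterations ago; do not overwrite it before
    // that allreduce has finished.
    cudaStreamWaitEvent(blas_stream, ar_done[i - 2], 0);
  }
  for (size_t j = 0; j < COMP_ITER; ++j) {
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha,
                d_A[dev], N, d_B[dev], N, &beta, cur, N);
  }
  cudaEventRecord(gemm_done[i], blas_stream);

  if (i > 0) {
    // Allreduce the previous iteration's output; it only waits for the
    // previous iteration's GEMMs, so it can run concurrently with the
    // current iteration's GEMMs on blas_stream.
    cudaStreamWaitEvent(nccl_stream, gemm_done[i - 1], 0);
    ncclAllReduce(prev, prev, n2, ncclFloat, ncclSum, comm, nccl_stream);
    cudaEventRecord(ar_done[i - 1], nccl_stream);
  }
}
// (A tail allreduce for the last iteration's buffer is omitted here.)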

Regarding the difference between sync mode and async mode in the NCCL kernel time: how is the kernel time measured?

tashanzhishi (Author) commented:

@kwen2501 Thanks for the reply. I got the kernel time from nvprof.
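
For example, something along these lines (the binary name is just a placeholder):

nvprof ./overlap_test 1    # read the avg time of the ncclAllReduce kernels from the GPU activities summary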
