Skip to content

Commit

Permalink
feat(tracing): add nvtx provider
Browse files Browse the repository at this point in the history
Hook nvtx on existing lttng macros.

We figured out how to structure this in a way that
aligns the required usages of nvtx with cases
like NCCL_OFI_TRACE_SEND_WRITE_SEG COMPLETE/START. We use the NVTX
start/end API for ranges, and mark API for events.

Only supports RDMA protocol for now, SENDRECV protocol NVTX support will
be added in the future.

Signed-off-by: Eric Raut <[email protected]>
  • Loading branch information
rauteric committed Apr 17, 2024
1 parent 6a41fbb commit a3aea9e
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 21 deletions.
1 change: 1 addition & 0 deletions include/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ noinst_HEADERS = \
nccl_ofi_ofiutils.h \
nccl_ofi_tracepoint.h \
tracing_impl/lttng.h \
tracing_impl/nvtx.h \
nccl-headers/net.h \
nccl-headers/error.h \
nccl-headers/nvidia/err.h \
Expand Down
24 changes: 24 additions & 0 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ extern "C" {
#include "nccl_ofi_deque.h"
#include "nccl_ofi_freelist.h"
#include "nccl_ofi_idpool.h"
#include "nccl_ofi_tracepoint.h"

/* Maximum number of rails supported. This defines the size of
* messages exchanged during connection establishment (linear
Expand Down Expand Up @@ -170,6 +171,10 @@ typedef struct {
/* Total number of completions. Expect one completion for receiving the
* control message and one completion for each send segment. */
int total_num_compls;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
nvtxRangeId_t seg_trace_id[MAX_NUM_RAILS];
#endif
} rdma_req_send_data_t;

/*
Expand All @@ -184,6 +189,9 @@ typedef struct {
nccl_net_ofi_schedule_t *ctrl_schedule;
/* Pointer to recv parent request */
nccl_net_ofi_rdma_req_t *recv_req;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
#endif
} rdma_req_send_ctrl_data_t;

typedef struct {
Expand Down Expand Up @@ -224,6 +232,9 @@ typedef struct {
* For eager messages, the second completion will be received
* when the local read into the destination buffer is complete */
int total_num_compls;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
#endif
} rdma_req_recv_data_t;

/*
Expand Down Expand Up @@ -403,8 +414,13 @@ typedef struct nccl_net_ofi_rdma_send_comm {
* and `num_init_rails' is adjusted. */
int num_init_rails;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM];
#endif

/* Array of `num_rails` communicator rails */
nccl_net_ofi_rdma_send_comm_rail_t rails[];

} nccl_net_ofi_rdma_send_comm_t;

/*
Expand Down Expand Up @@ -465,6 +481,10 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
/* Free list to track control buffers, for sending RDMA control messages */
nccl_ofi_freelist_t *ctrl_buff_fl;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM];
#endif

/* Number of rails */
int num_rails;

Expand Down Expand Up @@ -662,6 +682,10 @@ typedef struct nccl_net_ofi_rdma_device {

/* Memory registration key pool */
nccl_ofi_idpool_t key_pool;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[MAX_NUM_RAILS];
#endif
} nccl_net_ofi_rdma_device_t;

/*
Expand Down
72 changes: 55 additions & 17 deletions include/nccl_ofi_tracepoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define NCCL_OFI_TRACEPOINT_H_

#include "config.h"
#include "tracing_impl/nvtx.h"
#include "tracing_impl/lttng.h"

/***** SENDRECV PROTOCOL *****/
Expand All @@ -27,52 +28,89 @@
} while(0)

/***** RDMA PROTOCL *****/

#define NCCL_OFI_TRACE_SEND(dev, size, comm, msg_seq_num, request, nccl_req) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send, dev, size, comm, msg_seq_num, request, nccl_req); \
} while(0)
NCCL_OFI_TRACE_SEND_NVTX(dev, size, comm, msg_seq_num, request, nccl_req); \
} while(0)

#define NCCL_OFI_TRACE_SEND_END(request) do { \
NCCL_OFI_TRACE_SEND_END_NVTX(request); \
} while(0)

#define NCCL_OFI_TRACE_EAGER_SEND_START(dev, rail_id, size, comm, msg_seq_num, request) do { \
/* TODO: use a better (LTTNG) trace for eager send? */ \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
NCCL_OFI_TRACE_EAGER_SEND_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request); \
} while(0)

#define NCCL_OFI_TRACE_EAGER_SEND_COMPLETE(dev, rail_id, comm, msg_seq_num, request) do { \
NCCL_OFI_TRACE_EAGER_SEND_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request); \
} while (0)

#define NCCL_OFI_TRACE_SEND_CTRL_RECV(dev, rail_id, comm, msg_seq_num) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_ctrl_recv, dev, rail_id, comm, msg_seq_num); \
} while (0)
lttng_ust_tracepoint(nccl_ofi_plugin, Send_ctrl_recv, dev, rail_id, comm, msg_seq_num); \
NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(dev, rail_id, comm, msg_seq_num); \
} while (0)

#define NCCL_OFI_TRACE_SEND_CTRL_START(dev, rail_id, comm, req, msg_seq_num) do { \
NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(dev, rail_id, comm, req, msg_seq_num); \
} while (0);

#define NCCL_OFI_TRACE_SEND_CTRL_END(dev, rail_id, comm, req, msg_seq_num) do { \
NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(dev, rail_id, comm, req, msg_seq_num); \
} while (0);

#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START(dev, rail_id, size, comm, msg_seq_num, request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
} while(0)
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request); \
} while(0)

#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(dev, rail_id, comm, msg_seq_num, request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_complete, dev, rail_id, comm, msg_seq_num, request); \
} while(0)
NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request); \
} while(0)

#define NCCL_OFI_TRACE_RECV(dev, tag, size, request, nccl_req) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, tag, size, request, nccl_req); \
} while(0)
NCCL_OFI_TRACE_RECV_NVTX(dev, tag, size, request, nccl_req); \
} while(0)

#define NCCL_OFI_TRACE_RECV_END(request) do { \
NCCL_OFI_TRACE_RECV_END_NVTX(request); \
} while(0)

#define NCCL_OFI_TRACE_RECV_CTRL_SEND_COMPLETE(request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv_ctrl_send_complete, request); \
} while(0)
lttng_ust_tracepoint(nccl_ofi_plugin, Recv_ctrl_send_complete, request); \
} while(0)

#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE(dev, rail_id, size, request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv_segment_complete, dev, rail_id, size, request); \
} while(0)
NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(dev, rail_id, size, request); \
} while(0)

#define NCCL_OFI_TRACE_EAGER_RECV(dev, rail_id, comm, msg_seq_num) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Eager_recv, dev, rail_id, comm, msg_seq_num); \
} while(0)
lttng_ust_tracepoint(nccl_ofi_plugin, Eager_recv, dev, rail_id, comm, msg_seq_num); \
NCCL_OFI_TRACE_EAGER_RECV_NVTX(dev, rail_id, comm, msg_seq_num); \
} while(0)

#define NCCL_OFI_TRACE_COMPLETIONS(request,ctx) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, ProcessCompletions, request,ctx); \
} while(0)
} while(0)

#define NCCL_OFI_TRACE_FLUSH(request, nccl_req) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Flush, request, nccl_req); \
} while(0)
NCCL_OFI_TRACE_FLUSH_NVTX(request, nccl_req); \
} while(0)

#define NCCL_OFI_TRACE_PENDING_INSERT(request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Pending_queue_insert, request); \
} while(0)
NCCL_OFI_TRACE_PENDING_INSERT_NVTX(request); \
} while(0)

#define NCCL_OFI_TRACE_PENDING_REMOVE(request) do { \
lttng_ust_tracepoint(nccl_ofi_plugin, Pending_queue_remove, request); \
} while(0)
NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(request); \
} while(0)

#endif /* NCCL_OFI_TRACEPOINT_H_ */
#endif /* NCCL_OFI_TRACEPOINT_H_ */
Loading

0 comments on commit a3aea9e

Please sign in to comment.