Skip to content

Commit

Permalink
[FIX] replace device_count by addressable_device_count (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZYHowell authored Feb 5, 2023
1 parent fb3c056 commit c079cc2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/gpu/alpa_nccl_group_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ CUstream default_stream = NULL;

CommGroup::CommGroup(PjRtStreamExecutorClient *client) : client_(client) {
if (client != nullptr) {
for (int device_id = 0; device_id < client->device_count(); ++device_id) {
for (int device_id = 0; device_id < client->addressable_device_count();
++device_id) {
auto executor = client->device_state(device_id).executor();
auto i_stream = std::make_unique<se::Stream>(executor);
auto o_stream = std::make_unique<se::Stream>(executor);
Expand Down
13 changes: 8 additions & 5 deletions tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ void AddCallBackReleasingBuffer(se::Stream *stream, PyBuffer::object &buf_obj) {

// Picks the stream to use for `device_id` on `client`.
// When `is_compute` is true the result is that device's compute stream;
// otherwise it is the stream returned by GetLastDeviceToDeviceStream().
// NOTE(review): this listing is a rendered diff with the +/- markers stripped,
// so both the removed and the added return statement appear in the body below
// (per the hunk header: 8 lines before the change, 11 lines after).
se::Stream *GetXlaStream(PjRtStreamExecutorClient *client, bool is_compute,
int device_id) {
// Removed by this commit: the non-compute case always read the
// device-to-device stream from device_state(0).
return is_compute ? client->device_state(device_id).compute_stream()
: client->device_state(0).GetLastDeviceToDeviceStream();
// Added by this commit: the non-compute case reads from device_state(0) only
// when client->EnqueueD2DTransfersOnSrcStream() is true, and from
// device_state(device_id) otherwise.
return is_compute
? client->device_state(device_id).compute_stream()
: client->device_state(
client->EnqueueD2DTransfersOnSrcStream() ? 0 : device_id)
.GetLastDeviceToDeviceStream();
}
}; // namespace

Expand Down Expand Up @@ -191,10 +194,10 @@ Status ComputationWaitEvents(const AlpaUuids &uuids,
std::vector<se::Stream *> streams;
PjRtStreamExecutorClient *pjrt_client =
tensorflow::down_cast<PjRtStreamExecutorClient *>(client->pjrt_client());
int num_devices = pjrt_client->device_count();
for (int device_ordinal = 0; device_ordinal < num_devices; ++device_ordinal) {
int num_local_devices = pjrt_client->addressable_device_count();
for (int device_id = 0; device_id < num_local_devices; ++device_id) {
streams.push_back(
pjrt_client->device_state(device_ordinal).compute_stream());
pjrt_client->device_state(device_id).compute_stream());
}
for (int uuid : uuids) {
TF_RETURN_IF_ERROR(WaitEventOnStreams(uuid, streams));
Expand Down

0 comments on commit c079cc2

Please sign in to comment.