
checkpoint function #25
Merged: 11 commits, Apr 13, 2022
53 changes: 53 additions & 0 deletions energon/communication/collective.py
@@ -127,3 +127,56 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
return tensor, work
else:
return tensor


def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
"""
if dist._rank_not_in_group(group):
return

if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.")

# set tensor device to cuda if backend is nccl
device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")

my_rank = dist.get_rank() # use global rank
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
tensor_list = list(map(lambda x: x.to(device), tensor_list))
tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes))

# Src rank broadcasts the maximum tensor size. This is because all ranks are
# expected to call into scatter() with equal-sized tensors.
if my_rank == src:
max_tensor_size = max(tensor_sizes)
for tensor in tensor_list:
tensor.resize_(max_tensor_size)
else:
max_tensor_size = torch.tensor([0], dtype=torch.long).to(device)

dist.broadcast(max_tensor_size, src=src, group=group)

# Scatter actual serialized objects
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device)
dist.scatter(
output_tensor,
scatter_list=None if my_rank != src else tensor_list,
src=src,
group=group,
)

# Scatter per-object sizes to trim tensors when deserializing back to object
obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device)
dist.scatter(
obj_tensor_size,
scatter_list=None if my_rank != src else tensor_sizes,
src=src,
group=group,
)

output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu()
# Deserialize back to object
scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size)
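
For reference, a minimal usage sketch of the helper above. This is illustrative only: it assumes torch.distributed has already been initialized and that the function is imported straight from energon.communication.collective; the payload contents are made up.

import torch.distributed as dist

from energon.communication.collective import scatter_object_list

# Every rank receives exactly one object into a single-element output list.
output = [None]
payload = None
if dist.get_rank() == 0:
    # On the source rank, supply one picklable object per rank, in rank order.
    payload = [{"stage": r} for r in range(dist.get_world_size())]
scatter_object_list(output, payload, src=0)
received = output[0]  # the object destined for this rank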
37 changes: 32 additions & 5 deletions energon/context/parallel_context.py
@@ -44,6 +44,7 @@ def __init__(self):
self._local_ranks = dict()
self._world_sizes = dict()
self._groups = dict()
self._cpu_groups = dict()
self._ranks_in_group = dict()

# load config from file
@@ -277,6 +278,32 @@ def rm_group(self, parallel_mode: ParallelMode):
self._check_parallel_mode(parallel_mode)
self._groups.pop(parallel_mode)

def get_cpu_group(self, parallel_mode: ParallelMode):
"""Returns the Gloo group of the current device for `parallel_mode`.

:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
"""
self._check_parallel_mode(parallel_mode)
return self._cpu_groups[parallel_mode]

def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the Gloo group of the current device for `parallel_mode`.

:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._cpu_groups[parallel_mode] = group

def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
@@ -332,17 +359,17 @@ def init_global_dist(self,
world_size=world_size,
backend=backend,
init_method=init_method)

ranks = list(range(world_size))
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
self._register_dist(rank, world_size, None, cpu_group, ranks, ParallelMode.GLOBAL)
self.add_global_rank(ParallelMode.GLOBAL, rank)

def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
self.add_local_rank(mode, local_rank)
self.add_world_size(mode, world_size)
self.add_group(mode, process_group)
self.add_cpu_group(mode, cpu_group)
self.add_ranks_in_group(mode, ranks_in_group)

def _deregister_dist(self, mode):
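To show what the newly registered CPU groups are for, here is a hedged sketch of a host-side collective that prefers the Gloo group. The helper name and the fallback behaviour are assumptions for illustration, not part of this PR; gpc stands for the process's ParallelContext instance.

import torch
import torch.distributed as dist

def cpu_all_reduce_scalar(gpc, parallel_mode, value: float) -> float:
    # Prefer the Gloo group so the tensor can stay on the CPU; fall back to the
    # device group (None means the default group) if no CPU group was registered.
    group = gpc.get_cpu_group(parallel_mode)
    if group is None:
        group = gpc.get_group(parallel_mode)
    tensor = torch.tensor([value], dtype=torch.float64)  # CPU tensor for Gloo
    dist.all_reduce(tensor, group=group)
    return tensor.item()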
5 changes: 4 additions & 1 deletion energon/context/process_group_initializer/initializer_1d.py
@@ -31,18 +31,21 @@ def init_dist_group(self):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
os.environ[PARALLEL_INPUT_1D] = ''

for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
cpu_group = group_cpu

return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
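
The same backend check is repeated in every initializer touched by this PR; a small helper along these lines would capture the pattern (an illustrative refactor, not something the PR introduces):

import torch.distributed as dist

def new_group_with_cpu(ranks):
    # Create the device group for `ranks` plus a Gloo group for CPU collectives.
    # When the default backend is already gloo, reuse the same group for both.
    group = dist.new_group(ranks)
    cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
    return group, cpu_group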
19 changes: 14 additions & 5 deletions energon/context/process_group_initializer/initializer_data.py
@@ -12,8 +12,13 @@
class Initializer_Data(ProcessGroupInitializer):
"""A ProcessGroupInitializer for data parallelism.

:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of the current process.
world_size (int): Size of the whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallelism.
pipeline_parallel_size (int): Size of pipeline parallelism.
tensor_parallel_size (int): Size of tensor parallelism.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -22,23 +27,27 @@ def __init__(self, *args, **kwargs):
def init_dist_group(self):
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.

:return: Data parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
A data parallelism information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.DATA

for i in range(self.num_data_parallel_group):
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
20 changes: 15 additions & 5 deletions energon/context/process_group_initializer/initializer_model.py
@@ -14,8 +14,13 @@ class Initializer_Model(ProcessGroupInitializer):
"""A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
groups).

:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of the current process.
world_size (int): Size of the whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallelism.
pipeline_parallel_size (int): Size of pipeline parallelism.
tensor_parallel_size (int): Size of tensor parallelism.
"""

def __init__(self, *args, **kwargs):
@@ -26,22 +31,27 @@ def __init__(self, *args, **kwargs):
def init_dist_group(self):
"""Initialize model parallel groups, and assign local_ranks and groups to each gpu.

:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple
Returns:
Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
A model parallelism information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.MODEL

for i in range(self.num_group):
ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
24 changes: 15 additions & 9 deletions energon/context/process_group_initializer/initializer_pipeline.py
@@ -12,9 +12,15 @@
class Initializer_Pipeline(ProcessGroupInitializer):
"""A ProcessGroupInitializer for pipeline parallelism.

:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of the current process.
world_size (int): Size of the whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallelism.
pipeline_parallel_size (int): Size of pipeline parallelism.
tensor_parallel_size (int): Size of tensor parallelism.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_group_size = self.world_size // self.data_parallel_size
@@ -23,27 +29,27 @@ def __init__(self, *args, **kwargs):
def init_dist_group(self):
"""Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.

:return: Pipeline parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
List[Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode)]:
A list of pipeline parallelism information tuples.
"""
dist_settings = list()
for i in range(self.data_parallel_size):
for j in range(self.pipeline_stage_size):
pipe_ranks = list(
range(i * self.data_group_size + j,
(i + 1) * self.data_group_size,
self.pipeline_stage_size))
range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size))
pipe_group_size = len(pipe_ranks)
pipe_group = dist.new_group(pipe_ranks)
group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group

if self.rank in pipe_ranks:
local_rank = pipe_ranks.index(self.rank)
group_world_size = pipe_group_size
process_group = pipe_group
cpu_group = group_cpu
ranks_in_group = pipe_ranks
dist_settings.append(
tuple((local_rank, group_world_size,
process_group, ranks_in_group,
tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group,
ParallelMode.PIPELINE)))

return dist_settings
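
For context, the widened return values are presumably unpacked straight into ParallelContext._register_dist. A hedged sketch of that glue code follows; it is not part of this diff, and the function name is made up.

def register_from_initializer(gpc, initializer):
    # The pipeline initializer returns a list of tuples, the others a single tuple of
    # (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode).
    parallel_setting = initializer.init_dist_group()
    settings = parallel_setting if isinstance(parallel_setting, list) else [parallel_setting]
    for args in settings:
        gpc._register_dist(*args)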
19 changes: 14 additions & 5 deletions energon/context/process_group_initializer/initializer_tensor.py
@@ -12,8 +12,13 @@
class Initializer_Tensor(ProcessGroupInitializer):
"""A ProcessGroupInitializer for tensor parallelism.

:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of the current process.
world_size (int): Size of the whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallelism.
pipeline_parallel_size (int): Size of pipeline parallelism.
tensor_parallel_size (int): Size of tensor parallelism.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -22,23 +27,27 @@ def __init__(self, *args, **kwargs):
def init_dist_group(self):
"""Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.

:return: Tensor parallelism's information
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
Returns:
Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
A tensor parallelism information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.TENSOR

for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode