refactor lm_head tensor parallel
dc3671 committed Oct 8, 2023
1 parent e3e60e5 commit fc64ef5
Showing 2 changed files with 24 additions and 29 deletions.
deepspeed/module_inject/auto_tp.py (31 changes: 14 additions, 17 deletions)
@@ -319,8 +319,10 @@ def _replace(self, child, name, conv_linear_layer):

             setattr(child, "replaced", True)
             if name == "lm_head" or name == 'embed_out':
-                return LmHeadLinearAllreduce(data_dc, dist.get_rank(), dist.get_world_size(), child.bias if child.bias is None else \
-                            torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), mp_group)
+                return LmHeadLinearAllreduce(
+                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
+                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
+                        child.bias.to(get_accelerator().current_device_name())), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
                         torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
         else:
@@ -441,19 +443,14 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
         return r_module

     def _replace_last_linear_module(self, r_module):
-        for name, child in r_module.named_children():
-            if name == "lm_head" or name == 'embed_out':
-                checking_key = name + '.'
-                if child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm] and state_dict != None:
-                    if any(checking_key in item for item in state_dict):
-                        load(child, state_dict, checking_key, mp_group)
-                    else:
-                        continue
-                if len(child._buffers) != 0 and state_dict != None:
-                    load_buffer(child, state_dict, checking_key)
-                if child.__class__ in linear_policies:
-                    setattr(r_module, name, linear_policies[child.__class__](child, name, conv_linear_layer))
-            else:
-                update_mp_params(child)
-                _replace_module(child, name)
+        if hasattr(r_module, "lm_head"):
+            name = "lm_head"
+            child = r_module.lm_head
+        elif hasattr(r_module, "embed_out"):
+            name = "embed_out"
+            child = r_module.embed_out
+        else:
+            return r_module
+        if child.__class__ in self.linear_policies:
+            setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
         return r_module
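The refactored `_replace` branch above now hands `LmHeadLinearAllreduce` the sharded weight wrapped as a `requires_grad=False` parameter, together with the rank, world size, optional bias, and `self.mp_group`. For context, here is a minimal, hypothetical sketch of the row-parallel pattern such an lm_head all-reduce layer typically follows; the class name and internals below are assumptions for illustration, not DeepSpeed's actual implementation:

import torch
import torch.distributed as dist


class RowParallelLmHeadSketch(torch.nn.Module):
    """Hypothetical stand-in for an lm_head all-reduce layer (illustration only)."""

    def __init__(self, weight, rank, world_size, bias=None, mp_group=None):
        super().__init__()
        self.weight = weight          # assumed shard of shape [vocab, hidden // world_size]
        self.bias = bias              # full bias, added once after the reduction
        self.rank = rank
        self.world_size = world_size
        self.mp_group = mp_group

    def forward(self, hidden_states):
        # each rank consumes only its slice of the hidden dimension
        # (assumes hidden size divisible by world_size)
        shard = hidden_states.shape[-1] // self.world_size
        local_input = hidden_states[..., self.rank * shard:(self.rank + 1) * shard]
        # partial logits from the local weight shard
        logits = torch.matmul(local_input, self.weight.transpose(-1, -2))
        # sum the partial logits across tensor-parallel ranks
        if self.mp_group is not None and self.world_size > 1:
            dist.all_reduce(logits, group=self.mp_group)
        if self.bias is not None:
            logits = logits + self.bias
        return logits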
deepspeed/module_inject/replace_module.py (22 changes: 10 additions, 12 deletions)
@@ -275,7 +275,7 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         _autotp.update_linear_policies()

         # 4. Replace modules
-        if "lm_head" in str(module) or 'embed_out' in str(module):
+        if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
             return _autotp._replace_last_linear_module(module)
         return _autotp._replace_module(module)

@@ -306,6 +306,13 @@ def set_lm_head(module):
         if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
                 module.lm_head, "weight") and module.lm_head.weight.is_meta:
             module.lm_head.weight = embedding_weight
+        # enable tensor parallel for the last linear
+        if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
+            module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
+        elif hasattr(module, "embed_out") and hasattr(module.embed_out,
+                                                      "weight") and not module.embed_out.weight.is_meta:
+            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
+        return module

     if checkpoint_dict is not None and not config.replace_with_kernel_inject:
         # AutoTP shard loading
@@ -320,7 +327,7 @@ def set_lm_head(module):
                                              checkpoint=checkpoint_file)
             pbar.update(1)
             gc.collect()
-        set_lm_head(replaced_module)
+        replaced_module = set_lm_head(replaced_module)
     else:
         replaced_module = replace_module(model=model,
                                          orig_class=orig_layer_impl,
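Since `set_lm_head` now returns the (possibly re-wrapped) module, the caller assigns its result back to `replaced_module`, as shown in the hunk above. This path only runs when kernel injection is disabled. A minimal usage sketch, assuming a standard AutoTP setup (the model name, TP size, and launch command are placeholders, not part of this commit):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# launch with e.g.: deepspeed --num_gpus 2 run_inference.py  (hypothetical script name)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},     # shard weights, including lm_head, across 2 ranks
    dtype=torch.float16,
    replace_with_kernel_inject=False,   # AutoTP path, where set_lm_head() re-wraps the last linear
)
model = engine.module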
@@ -553,8 +560,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
     policy = {}
     if orig_class is not None:
         policy.update({orig_class: (replace_fn, _replace_policy)})
-        origin_layer = torch.nn.modules.linear.Linear
-        policy.update({origin_layer: (replace_fn, (list(model.named_modules())[-1][0]))})
     else:
         for plcy in replace_policies:
             # instantiate a throw-away policy in order to populate the _orig_layer_class
@@ -603,14 +608,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
         Modified ``model``.
     """
     for name, child in model.named_children():
-        if name == "lm_head" or name == "embed_out":
-            if child.__class__ in policies:
-                replaced_module = policies[child.__class__][0](model,
-                                                               policies[child.__class__][-1],
-                                                               layer_id,
-                                                               prefix=prefix + name,
-                                                               state_dict=state_dict)
-        elif child.__class__ in policies:
+        if child.__class__ in policies:
             replaced_module = policies[child.__class__][0](child,
                                                            policies[child.__class__][-1],
                                                            layer_id,
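The special-casing of `lm_head`/`embed_out` is dropped from this recursive walk, since the last linear is now handled by `set_lm_head` above. For reference, a simplified, hypothetical sketch of the generic recursive replacement pattern this helper follows (the real `_replace_module` also threads layer ids, prefixes, and the state_dict through):

import torch.nn as nn

def replace_matching_children(model: nn.Module, policies: dict, prefix: str = "") -> nn.Module:
    # policies maps an original layer class to a callable that builds its replacement
    for name, child in model.named_children():
        if child.__class__ in policies:
            # splice the replacement into the parent in place of the original child
            setattr(model, name, policies[child.__class__](child))
        else:
            # recurse into containers that hold no directly replaceable layer
            replace_matching_children(child, policies, prefix + name + ".")
    return model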
