[MoE] Fix recompute & communication api (#3338)
* update moe recompute.
sljlp authored Sep 21, 2022
1 parent 7708822 commit 8dc205c
Showing 2 changed files with 19 additions and 15 deletions.
24 changes: 14 additions & 10 deletions examples/language_model/moe/dygraph/modeling.py
@@ -35,8 +35,6 @@
MoeLayer = moe.MoELayer
from utils import get_timers

from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _initialize_recompute_setting, _initialize_recompute_hcg

__all__ = [
'GPTModel',
"GPTPretrainedModel",
@@ -410,7 +408,9 @@ def __init__(self,
top_k=2,
hcg=None,
gate=None,
recompute_interval=0):
recompute_interval=0,
recompute_partition=False,
recompute_offload=False):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -454,12 +454,19 @@ def __init__(self,
"type": "gshard",
"top_k": top_k,
}

recompute_ctx = {
"mp_group": mp_group,
"offload": recompute_offload,
"partition": recompute_partition
}
self.moe_mlp = MoeLayer(d_model=d_model,
experts=experts_list,
gate=gate_config,
moe_group=moe_group,
mp_group=mp_group,
recompute_interval=self.recompute_interval)
recompute_interval=self.recompute_interval,
recompute_ctx=recompute_ctx)
else:
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
d_model,
@@ -769,11 +776,6 @@ def __init__(self,
self.hidden_size = hidden_size
self.vocab_size = vocab_size

if recompute_interval > 0:
_initialize_recompute_hcg(hcg)
_initialize_recompute_setting(recompute_offload,
recompute_partition)

self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
hidden_dropout_prob,
max_position_embeddings,
@@ -800,7 +802,9 @@ def __init__(self,
top_k=top_k,
hcg=hcg,
gate=gate,
recompute_interval=recompute_interval))
recompute_interval=recompute_interval,
recompute_partition=recompute_partition,
recompute_offload=recompute_offload))

self.decoder = TransformerDecoder(decoder_layers,
num_hidden_layers,
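For readers following the modeling.py side of the change: the recompute options that previously had to be installed globally through the now-removed _initialize_recompute_setting and _initialize_recompute_hcg helpers are instead packed into a per-layer recompute_ctx dict and handed to MoeLayer. The sketch below shows that plumbing in isolation; the build_moe_mlp helper and its placeholder arguments (experts_list, gate_config, moe_group, mp_group) are illustrative and not part of the repository, and MoeLayer is the alias defined near the top of modeling.py (MoeLayer = moe.MoELayer).

# Hypothetical helper sketching how the diff wires recompute settings into MoeLayer.
def build_moe_mlp(d_model, experts_list, gate_config, moe_group, mp_group,
                  recompute_interval=0, recompute_partition=False,
                  recompute_offload=False):
    # The per-layer recompute context replaces the removed global helpers.
    recompute_ctx = {
        "mp_group": mp_group,              # model-parallel group used during recompute
        "offload": recompute_offload,      # offload recomputed activations if True
        "partition": recompute_partition,  # partition activations across mp ranks if True
    }
    return MoeLayer(d_model=d_model,
                    experts=experts_list,
                    gate=gate_config,
                    moe_group=moe_group,
                    mp_group=mp_group,
                    recompute_interval=recompute_interval,
                    recompute_ctx=recompute_ctx)
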
10 changes: 5 additions & 5 deletions examples/language_model/moe/dygraph/run_moe_pretrain.py
@@ -143,12 +143,12 @@ def initialize_mp_dp_parameters(model, hcg):
paddle.distributed.broadcast(param.detach(),
src=mp_src_rank,
group=mp_group,
use_calc_stream=True)
sync_op=True)

paddle.distributed.broadcast(param.detach(),
src=dp_src_rank,
group=dp_group,
use_calc_stream=True)
sync_op=True)


def unscale_method(self, optimizer):
@@ -206,7 +206,7 @@ def all_reduce_parameters(params, group):
with paddle.framework.no_grad():
for p in params:
grad = p.grad.scale_(div_factor)
paddle.distributed.all_reduce(grad, use_calc_stream=True)
paddle.distributed.all_reduce(grad, sync_op=True)


def parameters_classify(model, use_sharding=False):
@@ -492,9 +492,9 @@ def do_train(args):
dist.broadcast(p,
src=sharding_group.ranks[0],
group=sharding_group,
use_calc_stream=True)
sync_op=True)
# Multi stream operation will be supported later
dist.wait(tensor=p, group=sharding_group, use_calc_stream=True)
dist.wait(tensor=p, group=sharding_group, sync_op=True)
else:
initialize_mp_dp_parameters(model, hcg)

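The run_moe_pretrain.py half of the commit is a rename in the collective-communication calls: the old use_calc_stream flag becomes sync_op. Below is a minimal sketch of the updated calling convention, assuming a Paddle build recent enough to accept sync_op; the two helper functions and their arguments are illustrative, with only the broadcast/all_reduce keywords taken from the diff.

# Hypothetical helpers showing sync_op in place of the old use_calc_stream flag.
import paddle
import paddle.distributed as dist

def broadcast_params(params, src_rank, group):
    # Broadcast each parameter synchronously, matching the old use_calc_stream=True behavior.
    with paddle.framework.no_grad():
        for p in params:
            dist.broadcast(p.detach(), src=src_rank, group=group, sync_op=True)

def all_reduce_grads(params, div_factor, group):
    # Scale gradients in place, then all-reduce them synchronously across the group.
    with paddle.framework.no_grad():
        for p in params:
            grad = p.grad.scale_(div_factor)
            dist.all_reduce(grad, group=group, sync_op=True)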
