From 8dc205c918108e0de3cd5bffff48aa8a394af78d Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Wed, 21 Sep 2022 17:33:17 +0800
Subject: [PATCH] [MoE] Fix recompute & communication api (#3338)

* update moe recompute.
---
 .../language_model/moe/dygraph/modeling.py    | 24 +++++++++++--------
 .../moe/dygraph/run_moe_pretrain.py           | 10 ++++----
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/examples/language_model/moe/dygraph/modeling.py b/examples/language_model/moe/dygraph/modeling.py
index 64c1f220ca1d..71b1172d16fd 100644
--- a/examples/language_model/moe/dygraph/modeling.py
+++ b/examples/language_model/moe/dygraph/modeling.py
@@ -35,8 +35,6 @@
 MoeLayer = moe.MoELayer
 from utils import get_timers
 
-from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _initialize_recompute_setting, _initialize_recompute_hcg
-
 __all__ = [
     'GPTModel',
     "GPTPretrainedModel",
@@ -410,7 +408,9 @@ def __init__(self,
                  top_k=2,
                  hcg=None,
                  gate=None,
-                 recompute_interval=0):
+                 recompute_interval=0,
+                 recompute_partition=False,
+                 recompute_offload=False):
         self._config = locals()
         self._config.pop("self")
         self._config.pop("__class__", None)  # py3
@@ -454,12 +454,19 @@ def __init__(self,
                 "type": "gshard",
                 "top_k": top_k,
             }
+
+            recompute_ctx = {
+                "mp_group": mp_group,
+                "offload": recompute_offload,
+                "partition": recompute_partition
+            }
             self.moe_mlp = MoeLayer(d_model=d_model,
                                     experts=experts_list,
                                     gate=gate_config,
                                     moe_group=moe_group,
                                     mp_group=mp_group,
-                                    recompute_interval=self.recompute_interval)
+                                    recompute_interval=self.recompute_interval,
+                                    recompute_ctx=recompute_ctx)
         else:
             self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
                 d_model,
@@ -769,11 +776,6 @@ def __init__(self,
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
 
-        if recompute_interval > 0:
-            _initialize_recompute_hcg(hcg)
-            _initialize_recompute_setting(recompute_offload,
-                                          recompute_partition)
-
         self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
                                         hidden_dropout_prob,
                                         max_position_embeddings,
@@ -800,7 +802,9 @@ def __init__(self,
                     top_k=top_k,
                     hcg=hcg,
                     gate=gate,
-                    recompute_interval=recompute_interval))
+                    recompute_interval=recompute_interval,
+                    recompute_partition=recompute_partition,
+                    recompute_offload=recompute_offload))
 
         self.decoder = TransformerDecoder(decoder_layers,
                                           num_hidden_layers,
diff --git a/examples/language_model/moe/dygraph/run_moe_pretrain.py b/examples/language_model/moe/dygraph/run_moe_pretrain.py
index cabeeb926473..183a96f39f69 100644
--- a/examples/language_model/moe/dygraph/run_moe_pretrain.py
+++ b/examples/language_model/moe/dygraph/run_moe_pretrain.py
@@ -143,12 +143,12 @@ def initialize_mp_dp_parameters(model, hcg):
             paddle.distributed.broadcast(param.detach(),
                                          src=mp_src_rank,
                                          group=mp_group,
-                                         use_calc_stream=True)
+                                         sync_op=True)
 
             paddle.distributed.broadcast(param.detach(),
                                          src=dp_src_rank,
                                          group=dp_group,
-                                         use_calc_stream=True)
+                                         sync_op=True)
 
 
 def unscale_method(self, optimizer):
@@ -206,7 +206,7 @@ def all_reduce_parameters(params, group):
     with paddle.framework.no_grad():
         for p in params:
             grad = p.grad.scale_(div_factor)
-            paddle.distributed.all_reduce(grad, use_calc_stream=True)
+            paddle.distributed.all_reduce(grad, sync_op=True)
 
 
 def parameters_classify(model, use_sharding=False):
@@ -492,9 +492,9 @@ def do_train(args):
                 dist.broadcast(p,
                                src=sharding_group.ranks[0],
                                group=sharding_group,
-                               use_calc_stream=True)
+                               sync_op=True)
                 # Multi stream operation will be supported later
-                dist.wait(tensor=p, group=sharding_group, use_calc_stream=True)
+                dist.wait(tensor=p, group=sharding_group, sync_op=True)
         else:
             initialize_mp_dp_parameters(model, hcg)
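
Note: a minimal, illustrative sketch (not part of the patch) of the renamed collective-communication keyword applied above, assuming Paddle 2.4+, where sync_op replaces use_calc_stream; the tensor x and the single-process-group setup are hypothetical.

import paddle
import paddle.distributed as dist

# Initialize the default process group (launch with paddle.distributed.launch).
dist.init_parallel_env()

x = paddle.ones([2, 2])
# Synchronous collectives now take sync_op=True instead of the older
# use_calc_stream=True keyword, matching the calls updated in this patch.
dist.all_reduce(x, sync_op=True)
dist.broadcast(x, src=0, sync_op=True)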