diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py index 22cba6be0..5a388f867 100644 --- a/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py @@ -7,15 +7,15 @@ import time import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from deepspeed.accelerator import get_accelerator from dschat.utils.utils import print_rank_0 def print_all_ranks(tag, value, rank): world_size = torch.distributed.get_world_size() - all_tensor = torch.zeros(world_size, dtype=torch.float32).to( - get_accelerator().current_device_name()) + all_tensor = torch.zeros(world_size, + dtype=torch.float32, + device=value.device) all_tensor[rank] = value torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM) print_rank_0(f'{tag} {all_tensor}', rank) @@ -55,7 +55,8 @@ def __init__(self, rlhf_engine, args): self.end_of_conversation_token_id = self.tokenizer( args.end_of_conversation_token)['input_ids'][-1] self.z3_enabled = args.actor_zero_stage == 3 - self.compute_fp32_loss = self.args.compute_fp32_loss + self.calculate_fp32_loss = (self.args.dtype + == "bf16") and self.args.bf16_to_fp32_loss # In case the generated experience is not valid (too short), we use the last valid # generated experience. Alternatively, we can skip the step (on all workers). @@ -70,6 +71,8 @@ def __init__(self, rlhf_engine, args): self.gamma = 1.0 self.lam = 0.95 self.generate_time = 0.0 + self.first_generate = True + self.actor_model_hpu_graph_wrapped_fwd_fn = None def _generate_sequence(self, prompts, mask, step): @@ -84,13 +87,32 @@ def _generate_sequence(self, prompts, mask, step): kwargs = dict() with torch.no_grad(): - seq = self.actor_model.module.generate( - prompts, - attention_mask=mask, - max_length=max_min_length, - pad_token_id=self.tokenizer.pad_token_id, - synced_gpus=self.z3_enabled, - **kwargs) + if is_hpu(): + import habana_frameworks.torch.hpu as thpu + if is_hpu() and self.args.enable_hpu_graphs: + orig_actor_model_fwd_fn = self.actor_model.module.forward + if self.first_generate: + self.actor_model.module.forward = thpu.wrap_in_hpu_graph_func( + self.actor_model.module.forward) + self.first_generate = False + else: + self.actor_model.module.forward = self.actor_model_hpu_graph_wrapped_fwd_fn + seq = self.actor_model.module.generate( + prompts, + attention_mask=mask, + max_length=max_min_length, + min_length=max_min_length, + lazy_mode=True) + self.actor_model_hpu_graph_wrapped_fwd_fn = self.actor_model.module.forward + self.actor_model.module.forward = orig_actor_model_fwd_fn + else: + seq = self.actor_model.module.generate( + prompts, + attention_mask=mask, + max_length=max_min_length, + pad_token_id=self.tokenizer.pad_token_id, + synced_gpus=self.z3_enabled, + **kwargs) # Filter out seq with no answers (or very short). 
This happens when users directly use the pre-training ckpt without supervised finetuning # NOTE: this will causes each GPU has different number of examples @@ -149,19 +171,25 @@ def generate_experience(self, prompts, mask, step): pad_token_id = self.tokenizer.pad_token_id attention_mask = seq.not_equal(pad_token_id).long() + + hpu_mark_step() with torch.no_grad(): output = self.actor_model(seq, attention_mask=attention_mask) + hpu_mark_step() output_ref = self.ref_model(seq, attention_mask=attention_mask) + hpu_mark_step() reward_score = self.reward_model.forward_value( seq, attention_mask, prompt_length=self.prompt_length)['chosen_end_scores'].detach( ) + hpu_mark_step() values = self.critic_model.forward_value( seq, attention_mask, return_value_only=True).detach()[:, :-1] + hpu_mark_step() logits = output.logits logits_ref = output_ref.logits - if self.compute_fp32_loss: + if self.calculate_fp32_loss: logits = logits.to(torch.float) logits_ref = logits_ref.to(torch.float) @@ -221,25 +249,34 @@ def train_rlhf(self, inputs): advantages, returns = self.get_advantages_and_returns( old_values, old_rewards, start) + hpu_mark_step() ### process the new outputs batch = {'input_ids': seq, "attention_mask": attention_mask} actor_prob = self.actor_model(**batch, use_cache=False).logits + hpu_mark_step() actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:]) + hpu_mark_step() actor_loss = self.actor_loss_fn(actor_log_prob[:, start:], log_probs[:, start:], advantages, action_mask[:, start:]) + hpu_mark_step() self.actor_model.backward(actor_loss) + hpu_mark_step() if not self.args.align_overflow: self.actor_model.step() + hpu_mark_step() value = self.critic_model.forward_value(**batch, return_value_only=True, use_cache=False)[:, :-1] + hpu_mark_step() critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:], returns, action_mask[:, start:]) + hpu_mark_step() self.critic_model.backward(critic_loss) + hpu_mark_step() if self.args.align_overflow: actor_overflow = self.actor_model.optimizer.check_overflow( @@ -263,8 +300,10 @@ def train_rlhf(self, inputs): "OVERFLOW: actor and critic overflow, skipping both actor and critic steps", rank) self.actor_model.step() + hpu_mark_step() self.critic_model.step() + hpu_mark_step() return actor_loss, critic_loss @@ -296,7 +335,7 @@ def critic_loss_fn(self, values, old_values, returns, mask): old_values - self.cliprange_value, old_values + self.cliprange_value, ) - if self.compute_fp32_loss: + if self.calculate_fp32_loss: values = values.float() values_clipped = values_clipped.float() vf_loss1 = (values - returns)**2 diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py index 0e67efcf9..52e13a446 100755 --- a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py @@ -7,6 +7,7 @@ import deepspeed from deepspeed.ops.adam import FusedAdam from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.accelerator import get_accelerator from transformers import AutoModelForCausalLM, get_scheduler from dschat.utils.ds_utils import get_train_ds_config, get_eval_ds_config @@ -104,8 +105,23 @@ def _init_actor(self, actor_model_name_or_path): actor_model = make_model_gradient_checkpointing_compatible( actor_model) + # TODO SW-146776: remove this WA once SW-141762 is resolved + if is_hpu(): + import habana_frameworks.torch.core as htcore + actor_model.to(dtype=torch.bfloat16, + device=get_accelerator().device()) + # 
Optimizer - AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam + if self.args.offload: + AdamOptimizer = DeepSpeedCPUAdam + elif self.args.no_fused_kernels or is_hpu(): + AdamOptimizer = torch.optim.AdamW + else: + AdamOptimizer = FusedAdam + print_rank_0( + f'Using {AdamOptimizer.__name__} optimizer for actor model', + self.args.global_rank) + optim_params = get_optimizer_grouped_parameters( actor_model, self.args.actor_weight_decay, self.args.actor_lora_learning_rate) @@ -234,8 +250,23 @@ def _init_critic(self, critic_model_name_or_path): critic_model = make_model_gradient_checkpointing_compatible( critic_model) + # TODO SW-146776: remove this WA once SW-141762 is resolved + if is_hpu(): + critic_model.to(dtype=torch.bfloat16, + device=get_accelerator().device()) + # Optimizer - AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam + # TODO SW-147425: change the file to use HPEX optimizer instead of AdamW on hpu + if self.args.offload: + AdamOptimizer = DeepSpeedCPUAdam + elif self.args.no_fused_kernels or is_hpu(): + AdamOptimizer = torch.optim.AdamW + else: + AdamOptimizer = FusedAdam + print_rank_0( + f'Using {AdamOptimizer.__name__} optimizer for critic model', + self.args.global_rank) + optim_params = get_optimizer_grouped_parameters( critic_model, self.args.critic_weight_decay, self.args.critic_lora_learning_rate) diff --git a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py index 7e3e6776b..628b04826 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py @@ -298,8 +298,10 @@ def create_prompt_dataset(local_rank, eval_fname = f"{output_path}/evaldata_{fname}.pt" cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname) - buf_create_cache = torch.ByteTensor([not cache_found]).to( - get_accelerator().current_device_name()) + device = torch.device(get_accelerator().device_name( + torch.distributed.get_rank())) + buf_create_cache = get_accelerator().ByteTensor([not cache_found], + device=device) torch.distributed.all_reduce(buf_create_cache) if local_rank <= 0 and (buf_create_cache.item() != 0 or reload): diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 050819a22..6d13ac26d 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -82,6 +82,70 @@ def causal_lm_forward( model.forward = causal_lm_forward +def configure_dropout(model_config, dropout): + if dropout is not None: + for key in ('dropout', 'attention_dropout', 'hidden_dropout', + 'activation_dropout'): + if hasattr(model_config, key): + print(f"Setting model_config.{key} to {dropout}") + setattr(model_config, key, dropout) + + +def causal_lm_model_to_fp32_loss(model): + """ Convert CausalLM model to calculate loss in fp32 """ + + def causal_lm_forward( + input_ids=None, + past_key_values=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **deprecated_arguments, + ): + output = model.__original_forward__( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + labels=None, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + return_dict = isinstance(output, dict) + lm_logits = output.logits if return_dict else output[0] + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].float().contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length)) + + if not return_dict: + # re-pack output with fp32 loss + return ((loss, ) + output) if loss is not None else output + + output.loss = loss + return output + + model.__original_forward__ = model.forward + model.forward = causal_lm_forward + + def create_hf_model(model_class, model_name_or_path, tokenizer, @@ -122,7 +186,8 @@ def create_critic_model(model_name_or_path, rlhf_training=False, dropout=None, zero_stage=0, - compute_fp32_loss=False): + loss_to_fp32=False, + optimized_reward_loss_calc=False): # OPT model family always put a padding token at the beginning of the sequence, # we did not see this in other models but not sure if it is a general rule @@ -139,7 +204,8 @@ def create_critic_model(model_name_or_path, critic_model, tokenizer, num_padding_at_beginning=num_padding_at_beginning, - compute_fp32_loss=compute_fp32_loss) + loss_to_fp32=loss_to_fp32, + opt_loss_calc=optimized_reward_loss_calc) if rlhf_training: # load critic model from checkpoint diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py index 60d063b18..38a79ab2c 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py @@ -14,10 +14,12 @@ def __init__(self, base_model, tokenizer, num_padding_at_beginning=0, - compute_fp32_loss=False): + loss_to_fp32=False, + opt_loss_calc=False): super().__init__() self.config = base_model.config self.num_padding_at_beginning = num_padding_at_beginning + self.optimized_loss_calc = opt_loss_calc if hasattr(self.config, "word_embed_proj_dim"): # `OPT` models use word_embed_proj_dim as final output # https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L497 @@ -31,7 +33,8 @@ def __init__(self, self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) self.rwtransformer = base_model self.PAD_ID = tokenizer.pad_token_id - self.compute_fp32_loss = compute_fp32_loss + self.loss_to_fp32 = loss_to_fp32 + self.fallback_mask = None def gradient_checkpointing_enable(self): self.rwtransformer.gradient_checkpointing_enable() @@ -84,36 +87,111 @@ def forward(self, rejected_id = rejected_ids[i] chosen_reward = chosen_rewards[i] rejected_reward = rejected_rewards[i] - - c_inds = (chosen_id == self.PAD_ID).nonzero() - c_ind = c_inds[self.num_padding_at_beginning].item() if len( - c_inds - ) > self.num_padding_at_beginning else seq_len # OPT model pads the first token, so we need to use the second padding token as the end of the sequence - check_divergence = (chosen_id != rejected_id).nonzero() - - if len(check_divergence) == 0: - end_ind = rejected_reward.size(-1) - divergence_ind = end_ind - 1 - r_ind = c_ind + if self.fallback_mask is None: + 
fallback_mask = torch.zeros([seq_len]).bool()
+                    fallback_mask[seq_len - 1] = 1
+                    self.fallback_mask = fallback_mask.to(rejected_reward.device)
+                if self.optimized_loss_calc:
+                    # the code below is a performance-optimized implementation of the else-statement flow
+                    # by using masks we achieve non-dynamic shapes and an inferable flow
+                    def get_last_before_padding(paddings, num_begin_padding):
+                        # this function returns a mask with a single 1 at the last non-padded element
+                        # for example, paddings = 11000011 (see below) with num_begin_padding = 2
+                        # 1. remove the begin-padding indication: 00001100
+                        shifted = torch.roll(paddings, -num_begin_padding, 0)
+                        # 2. shift back so the 1's start at the last non-padded element: 00000110
+                        shifted = torch.roll(shifted, num_begin_padding - 1, 0)
+                        # not_paddings indicates where we don't have padding: 00111100
+                        not_paddings = torch.logical_not(paddings)
+                        # now we get the last non-padded position: 00000100
+                        mask = not_paddings * shifted
+                        return mask
+
+                    # the target is to create a mask that has 1's from the first token where
+                    # chosen and rejected ids differ until the last non-padded element (the longer of the two)
+                    # in case chosen and rejected are identical, the mask contains only the last element of the sequence
+
+                    # a mask for each sequence with 1's at the padded elements
+                    chosen_padding_mask = (chosen_id == self.PAD_ID).bool()
+                    rejected_padding_mask = (rejected_id == self.PAD_ID).bool()
+
+                    # united_unpadding_mask is the union of the unpadded elements:
+                    # it holds 1's wherever at least one of the two inputs has a non-padded token
+                    united_unpadding_mask = torch.logical_not(
+                        torch.logical_and(chosen_padding_mask,
+                                          rejected_padding_mask))
+
+                    # get a mask of all the tokens that differ
+                    divergence_mask = (chosen_id != rejected_id)
+                    divergence_mask = divergence_mask.cumsum(0).bool()
+
+                    # loss_mask indicates the elements to take into account after the logsigmoid calculation,
+                    # from the first divergence until the last non-padded token
+                    loss_mask = torch.logical_and(divergence_mask,
+                                                  united_unpadding_mask)
+                    loss_mask = torch.where(divergence_mask.sum().bool(),
+                                            loss_mask, self.fallback_mask)
+
+                    # compute logsigmoid over the whole input and mask out the irrelevant elements
+                    if self.loss_to_fp32:
+                        chosen_reward = chosen_reward.float()
+                        rejected_reward = rejected_reward.float()
+                    logsigmoid = torch.nn.functional.logsigmoid(
+                        chosen_reward.float() -
+                        rejected_reward.float()) * loss_mask
+                    # average over the number of elements included in the loss
+                    num_elements_in_loss = loss_mask.sum().float()
+                    loss += -(logsigmoid.sum() / num_elements_in_loss)
+
+                    # record the scores at c_ind / r_ind in chosen_mean_scores / rejected_mean_scores
+                    c_ind_mask = get_last_before_padding(
+                        chosen_padding_mask, self.num_padding_at_beginning)
+                    c_ind_mask = torch.where(
+                        chosen_padding_mask.sum() > self.num_padding_at_beginning,
+                        c_ind_mask, self.fallback_mask)
+                    chosen_mean_score = (c_ind_mask.float() *
+                                         chosen_reward.float()).sum()
+                    chosen_mean_scores.append(chosen_mean_score)
+
+                    r_ind_mask = get_last_before_padding(
+                        rejected_padding_mask, self.num_padding_at_beginning)
+                    r_ind_mask = torch.where(
+                        rejected_padding_mask.sum() >
+                        self.num_padding_at_beginning, r_ind_mask,
+                        self.fallback_mask)
+                    rejected_mean_score = (r_ind_mask.float() *
+                                           rejected_reward.float()).sum()
+                    rejected_mean_scores.append(rejected_mean_score)
             else:
-                # Check if there is any padding otherwise take length of sequence
-                r_inds = (rejected_id ==
self.PAD_ID).nonzero() - r_ind = r_inds[self.num_padding_at_beginning].item( - ) if len(r_inds) > self.num_padding_at_beginning else seq_len - end_ind = max(c_ind, r_ind) - divergence_ind = check_divergence[0] - assert divergence_ind > 0 - c_truncated_reward = chosen_reward[divergence_ind:end_ind] - r_truncated_reward = rejected_reward[divergence_ind:end_ind] - chosen_mean_scores.append( - chosen_reward[c_ind - 1]) #use the end score for reference - rejected_mean_scores.append(rejected_reward[r_ind - 1]) - - if self.compute_fp32_loss: - c_truncated_reward = c_truncated_reward.float() - r_truncated_reward = r_truncated_reward.float() - loss += -torch.nn.functional.logsigmoid(c_truncated_reward - - r_truncated_reward).mean() + c_inds = (chosen_id == self.PAD_ID).nonzero() + c_ind = c_inds[self.num_padding_at_beginning].item() if len( + c_inds + ) > self.num_padding_at_beginning else seq_len # OPT model pads the first token, so we need to use the second padding token as the end of the sequence + check_divergence = (chosen_id != rejected_id).nonzero() + if len(check_divergence) == 0: + end_ind = rejected_reward.size(-1) + divergence_ind = end_ind - 1 + r_ind = c_ind + else: + # Check if there is any padding otherwise take length of sequence + r_inds = (rejected_id == self.PAD_ID).nonzero() + r_ind = r_inds[self.num_padding_at_beginning].item( + ) if len( + r_inds) > self.num_padding_at_beginning else seq_len + end_ind = max(c_ind, r_ind) + divergence_ind = check_divergence[0] + assert divergence_ind > 0 + c_truncated_reward = chosen_reward[divergence_ind:end_ind] + r_truncated_reward = rejected_reward[divergence_ind:end_ind] + if self.loss_to_fp32: + c_truncated_reward = c_truncated_reward.float() + r_truncated_reward = r_truncated_reward.float() + loss += -torch.nn.functional.logsigmoid( + c_truncated_reward - r_truncated_reward).mean() + + chosen_mean_scores.append( + chosen_reward[c_ind - 1]) #use the end score for reference + rejected_mean_scores.append(rejected_reward[r_ind - 1]) loss = loss / bs chosen_mean_scores = torch.stack(chosen_mean_scores) diff --git a/applications/DeepSpeed-Chat/dschat/utils/utils.py b/applications/DeepSpeed-Chat/dschat/utils/utils.py index e4dc7d036..7ef7e0df9 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/utils.py @@ -6,6 +6,7 @@ import torch import random import numpy as np +from datetime import datetime from transformers import set_seed, AutoTokenizer import json import deepspeed @@ -308,3 +309,29 @@ def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0): if global_rank == 0: torch.save(output_state_dict, output_model_file) del output_state_dict + + +def print_loss(epoch, step, steps_per_print, gas, loss, loss_sum, rank): + loss_ = loss.detach() + loss_sum = torch.zeros_like(loss_) if loss_sum is None else loss_sum + loss_sum += loss_ + if step > 0 and step % (steps_per_print * gas) == 0: + opt_step = step / gas + avg_loss = loss_sum / gas + print_rank_0( + f"[{datetime.now()}] epoch: {epoch} | step: {opt_step} | avg_loss: {avg_loss}", + rank) + if step > 0 and step % gas == 0: + loss_sum.zero_() + + return loss_sum + + +def is_hpu(): + return get_accelerator().device_name() == "hpu" + + +def hpu_mark_step(): + if is_hpu(): + import habana_frameworks.torch.core as htcore + htcore.mark_step() diff --git a/applications/DeepSpeed-Chat/inference/chatbot.py b/applications/DeepSpeed-Chat/inference/chatbot.py index 5a4e36895..8ffb717a6 100644 --- 
a/applications/DeepSpeed-Chat/inference/chatbot.py +++ b/applications/DeepSpeed-Chat/inference/chatbot.py @@ -10,7 +10,8 @@ import os import json from transformers import pipeline, set_seed -from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM +from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer, AutoModelForCausalLM +from deepspeed.accelerator import get_accelerator def parse_args(): @@ -54,7 +55,7 @@ def get_generator(path): generator = pipeline("text-generation", model=model, tokenizer=tokenizer, - device="cuda:0") + device=get_accelerator().device_name(0)) return generator diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index aa505a25d..d183accee 100755 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -5,7 +5,8 @@ # DeepSpeed Team import argparse import math - +import sys +import time import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -22,7 +23,7 @@ from deepspeed import get_accelerator from dschat.utils.data.data_utils import create_prompt_dataset -from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer +from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer, TensorAccumulator, update_optim_step_mean_loss, print_optim_step_mean_loss from dschat.utils.ds_utils import get_train_ds_config from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible from dschat.utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss @@ -174,12 +175,14 @@ def parse_args(): help= "Initial LoRA learning rate (after the potential warmup period) to use." ) - ## low precision + ## bf16 parser.add_argument( - '--compute_fp32_loss', - action='store_true', - help='Relevant for low precision dtypes (fp16, bf16, etc.). ' - 'If specified, loss is calculated in fp32.') + '--no_bf16_to_fp32_loss', + action='store_false', + dest='bf16_to_fp32_loss', + help='Relevant only with bf16 dtype. ' + 'If specified, loss is calculated in bf16. Otherwise, calculated in fp32.' 
+ ) ## Tensorboard logging parser.add_argument('--enable_tensorboard', action='store_true', @@ -199,9 +202,15 @@ def parse_args(): help="Specify the format of the `eot_token`", ) ## Print loss - parser.add_argument('--print_loss', + parser.add_argument( + '--print_loss', + action='store_true', + help='Prints loss at deepspeed config steps_per_print interval.') + ## Debug + parser.add_argument('--no_fused_kernels', action='store_true', - help='Prints loss at each step.') + help='Do not use cuda fused kernels.') + ## DeepSpeed parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() @@ -209,14 +218,18 @@ def parse_args(): def main(): + if is_hpu(): + import habana_frameworks.torch.core as htcore + args = parse_args() if args.local_rank == -1: device = torch.device(get_accelerator().device_name()) else: - get_accelerator().set_device(args.local_rank) - device = torch.device(get_accelerator().device_name(), args.local_rank) - # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + if not is_hpu(): + get_accelerator().set_device(args.local_rank) + device = torch.device(get_accelerator().device_name(args.local_rank)) + # Initializes the distributed backend which will take care of synchronizing nodes/GPUs # torch.distributed.init_process_group(backend='nccl') deepspeed.init_distributed() @@ -251,7 +264,7 @@ def main(): ds_config, dropout=args.dropout) - if args.compute_fp32_loss: + if (args.dtype == "bf16") and args.bf16_to_fp32_loss: print_rank_0( f"Using model {model.__class__.__name__} with loss in fp32", args.global_rank) @@ -264,6 +277,9 @@ def main(): model = only_optimize_lora_parameters(model) model = make_model_gradient_checkpointing_compatible(model) + if is_hpu(): # TODO SW-146602: remove this WA when SW-141762 is resolved + model.to(dtype=torch.bfloat16, device=get_accelerator().device_name()) + # Prepare the data train_phase = 1 train_dataset, eval_dataset = create_prompt_dataset( @@ -298,8 +314,11 @@ def evaluation(model, eval_dataloader): losses = 0 for step, batch in enumerate(eval_dataloader): batch = to_device(batch, device) + hpu_mark_step() + with torch.no_grad(): outputs = model(**batch) + hpu_mark_step() loss = outputs.loss losses += loss.float() @@ -318,13 +337,20 @@ def evaluation(model, eval_dataloader): optimizer_grouped_parameters = get_optimizer_grouped_parameters( model, args.weight_decay, args.lora_learning_rate) - AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam + if args.offload: + AdamOptimizer = DeepSpeedCPUAdam + elif args.no_fused_kernels or is_hpu(): + AdamOptimizer = torch.optim.AdamW + else: + AdamOptimizer = FusedAdam + print_rank_0(f'Using {AdamOptimizer.__name__} optimizer', args.global_rank) + optimizer = AdamOptimizer(optimizer_grouped_parameters, lr=args.learning_rate, betas=(0.9, 0.95)) - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps) + gas = args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gas) lr_scheduler = get_scheduler( name=args.lr_scheduler_type, optimizer=optimizer, @@ -355,23 +381,27 @@ def evaluation(model, eval_dataloader): print_rank_0( f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}", args.global_rank) + loss_sum = None model.train() - import time for step, batch in enumerate(train_dataloader): start = time.time() batch = to_device(batch, device) outputs = model(**batch, use_cache=False) loss = outputs.loss - if 
args.print_loss: - print( - f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}" - ) model.backward(loss) + hpu_mark_step() model.step() + hpu_mark_step() end = time.time() if torch.distributed.get_rank() == 0: - print_throughput(model.model, args, end - start, - args.global_rank) + hf_model = model.model if hasattr(model, + 'model') else model.module + print_throughput(hf_model, args, end - start, args.global_rank) + if args.print_loss: + steps_per_print = ds_config['steps_per_print'] + loss_sum = print_loss(epoch, step, steps_per_print, + args.gradient_accumulation_steps, loss, + loss_sum, args.global_rank) # Evaluate perplexity on the validation set. print_rank_0( diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index 8cdf5644d..1bc72e636 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -21,7 +21,7 @@ from dschat.utils.model.model_utils import create_critic_model from dschat.utils.data.data_utils import create_prompt_dataset, DataCollatorReward -from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer +from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer, TensorAccumulator, update_optim_step_mean_loss, print_optim_step_mean_loss from dschat.utils.ds_utils import get_train_ds_config from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible @@ -173,7 +173,14 @@ def parse_args(): help= "Initial LoRA learning rate (after the potential warmup period) to use." ) - + ## bf16 + parser.add_argument( + '--no_bf16_to_fp32_loss', + action='store_false', + dest='bf16_to_fp32_loss', + help='Relevant only with bf16 dtype. ' + 'If specified, loss is calculated in bf16. Otherwise, calculated in fp32.' + ) # Evaluation parser.add_argument("--eval_interval", type=int, @@ -183,13 +190,6 @@ def parse_args(): type=int, default=100, help="Maximum evaluation iterations") - ## low precision - parser.add_argument( - '--compute_fp32_loss', - action='store_true', - help='Relevant for low precision dtypes (fp16, bf16, etc.). 
' - 'If specified, loss is calculated in fp32.') - ## Tensorboard logging parser.add_argument('--enable_tensorboard', action='store_true', @@ -202,6 +202,24 @@ def parse_args(): "--add_eot_token", action='store_true', help="Add <|endoftext|> as additional special token to tokenizer") + + ## Print loss + parser.add_argument( + '--print_loss', + action='store_true', + help='Prints loss at deepspeed config steps_per_print interval.') + ## Debug + parser.add_argument('--no_fused_kernels', + action='store_true', + help='Do not use cuda fused kernels.') + ## UPH + parser.add_argument( + "--optimized_reward_loss_calc", + action='store_true', + help= + "Whether to use an optimized approach for RM loss calculation, or legacy flow" + ) + ## DeepSpeed parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() @@ -209,14 +227,18 @@ def parse_args(): def main(): + if is_hpu(): + import habana_frameworks.torch.core as htcore + args = parse_args() if args.local_rank == -1: device = torch.device(get_accelerator().device_name()) else: - get_accelerator().set_device(args.local_rank) - device = torch.device(get_accelerator().device_name(), args.local_rank) - # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + if is_hpu(): + get_accelerator().set_device(args.local_rank) + device = torch.device(get_accelerator().device_name(args.local_rank)) + # Initializes the distributed backend which will take care of synchronizing nodes/GPUs # torch.distributed.init_process_group(backend='nccl') deepspeed.init_distributed() @@ -244,13 +266,17 @@ def main(): tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=True, add_special_tokens=additional_special_tokens) - rm_model = create_critic_model(args.model_name_or_path, - tokenizer, - ds_config, - args.num_padding_at_beginning, - dropout=args.dropout, - zero_stage=args.zero_stage, - compute_fp32_loss=args.compute_fp32_loss) + + loss_to_fp32 = (args.dtype == "bf16") and args.bf16_to_fp32_loss + rm_model = create_critic_model( + args.model_name_or_path, + tokenizer, + ds_config, + args.num_padding_at_beginning, + dropout=args.dropout, + zero_stage=args.zero_stage, + loss_to_fp32=loss_to_fp32, + optimized_reward_loss_calc=args.optimized_reward_loss_calc) # Model bigscience/bloom-560m has large variance at ln_f.weight parameter # This makes bf16 finetuning hard. 
@@ -270,7 +296,7 @@ def main(): torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) force_optimize_params.extend( - ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']) + ['rwtranrsformer.ln_f.weight', 'rwtranrsformer.ln_f.bias']) if args.lora_dim > 0: rm_model = convert_linear_layer_to_lora(rm_model, @@ -282,6 +308,10 @@ def main(): force_optimize_params) rm_model = make_model_gradient_checkpointing_compatible(rm_model) + # TODO SW-146776: remove this WA once SW-141762 is resolved + if is_hpu(): + rm_model.to(dtype=torch.bfloat16, device=device) + train_phase = 2 train_dataset, eval_dataset = create_prompt_dataset( args.local_rank, args.data_path, args.data_split, @@ -324,6 +354,7 @@ def evaluation_reward(model, dataloader, eval_iters): rejected_scores += _outputs["rejected_mean_scores"].mean().float() if (_step + 1) == eval_iters: break + model.train() _acc = correct_predictions / total_predictions chosen_scores = chosen_scores / (_step + 1) rejected_scores = rejected_scores / (_step + 1) @@ -339,13 +370,21 @@ def evaluation_reward(model, dataloader, eval_iters): optimizer_grouped_parameters = get_optimizer_grouped_parameters( rm_model, args.weight_decay, args.lora_learning_rate) - AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam + # TODO SW-146129: change the file to use HPEX optimizer instead of AdamW on hpu + if args.offload: + AdamOptimizer = DeepSpeedCPUAdam + elif args.no_fused_kernels or is_hpu(): + AdamOptimizer = torch.optim.AdamW + else: + AdamOptimizer = FusedAdam + print_rank_0(f'Using {AdamOptimizer.__name__} optimizer', args.global_rank) + optimizer = AdamOptimizer(optimizer_grouped_parameters, lr=args.learning_rate, betas=(0.9, 0.95)) - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps) + gas = args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gas) lr_scheduler = get_scheduler( name=args.lr_scheduler_type, @@ -385,12 +424,20 @@ def evaluation_reward(model, dataloader, eval_iters): args.global_rank) rm_model.train() mean_loss = 0 + loss_sum = None for step, batch in enumerate(train_dataloader): batch = to_device(batch, device) outputs = rm_model(**batch, use_cache=False) loss = outputs["loss"] rm_model.backward(loss) + hpu_mark_step() rm_model.step() + hpu_mark_step() + if args.print_loss: + steps_per_print = ds_config['steps_per_print'] + loss_sum = print_loss(epoch, step, steps_per_print, + args.gradient_accumulation_steps, loss, + loss_sum, args.global_rank) mean_loss += loss.item() total_micro_steps += 1 gas_boundary = (total_micro_steps % @@ -406,10 +453,9 @@ def evaluation_reward(model, dataloader, eval_iters): f"Iter {total_steps}: c_scores: {reward_score}, r_scores: {reject_score}, " f"diff: {reward_score - reject_score}, acc: {acc}", args.global_rank) - rm_model.train() print_rank_0( - f"Epoch {epoch+1}/{args.num_train_epochs} with loss {mean_loss/(step+1)}", + f"Epoch {epoch+1}/{args.num_train_epochs} with loss {loss_sum.get_mean()}", args.global_rank) # Evaluate reward_loss on the validation set. 
print_rank_0( diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py index 23f9a66af..05e08ad29 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py @@ -44,12 +44,13 @@ def load_stuff(model_name_or_path, num_padding_at_beginning, fast_tokenizer=True, add_special_tokens=additional_special_tokens) tokenizer.pad_token = tokenizer.eos_token - model = create_critic_model(model_name_or_path, - tokenizer, - None, - num_padding_at_beginning, - dropout=0.) - + model = create_critic_model( + model_name_or_path, + tokenizer, + ds_config=None, + num_padding_at_beginning=num_padding_at_beginning, + rlhf_training=False, + dropout=0.) return model, tokenizer diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index 1378dc4e6..2e9b7ee35 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -324,12 +324,13 @@ def parse_args(): '--enable_mixed_precision_lora', action='store_true', help='Enable Mixed Precision ZeRO++ for training and generation.') - ## low precision + ## bf16 parser.add_argument( - '--compute_fp32_loss', - action='store_true', - help='Relevant for low precision dtypes (fp16, bf16, etc.). ' - 'If specified, loss is calculated in fp32.' + '--no_bf16_to_fp32_loss', + action='store_false', + dest='bf16_to_fp32_loss', + help='Relevant only with bf16 dtype. ' + 'If specified, loss is calculated in bf16. Otherwise, calculated in fp32. ' 'This applies for both actor and critic models.') ## Tensorboard logging parser.add_argument('--enable_tensorboard', @@ -371,7 +372,17 @@ def parse_args(): help= "Training non-overflow step at which to terminate training during testing." ) + parser.add_argument('--no_fused_kernels', + action='store_true', + help='Do not use cuda fused kernels.') + + ## HPU + parser.add_argument("--enable_hpu_graphs", + default=False, + action="store_true", + help="Enable HPU graphs.") + ## DeepSpeed parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() @@ -393,6 +404,9 @@ def parse_args(): "The combination of [actor_zero_stage==2, critic_zero_stage==2, enable_hybrid_engine=True, offload=True, lora=False] is currently unsupported due to training instability!" 
) + if is_hpu(): + assert not args.enable_mixed_precision_lora, "HPU does not support --enable_mixed_precision_lora" + return args @@ -449,11 +463,17 @@ def main(): if args.local_rank == -1: device = torch.device(get_accelerator().device_name()) else: - get_accelerator().set_device(args.local_rank) - device = torch.device(get_accelerator().device_name(), args.local_rank) - # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + if not is_hpu(): + get_accelerator().set_device(args.local_rank) + device = torch.device(get_accelerator().device_name(args.local_rank)) + # Initializes the distributed backend which will take care of synchronizing nodes/GPUs deepspeed.init_distributed() + if is_hpu(): + from transformers.generation import GenerationMixin + from optimum.habana.transformers.generation import GaudiGenerationMixin + GenerationMixin.generate = GaudiGenerationMixin.generate + args.global_rank = torch.distributed.get_rank() unsupervised_training_enabled = args.unsupervised_dataset_name and args.unsupervised_dataset_config_name
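
The densest change in this patch is the `--optimized_reward_loss_calc` branch of `RewardModel.forward`, which replaces the `.nonzero()`/slicing flow with fixed-shape masks so the pairwise reward loss can be computed without dynamic shapes. Below is a minimal standalone sketch of that masking trick, not part of the patch itself: the toy token ids, rewards, and `PAD_ID = 0` are invented for illustration, and the identical-sequence fallback handled by `self.fallback_mask` in the patch is omitted.

# Standalone illustration (toy values) of the static-shape masking used in the
# optimized_reward_loss_calc branch of RewardModel.forward above.
import torch

PAD_ID = 0                      # hypothetical pad id for this toy example
num_padding_at_beginning = 0    # OPT models would use 1

def get_last_before_padding(paddings, num_begin_padding):
    # returns a mask with a single 1 at the last non-padded position,
    # e.g. paddings = 11000011 with num_begin_padding = 2  ->  00000100
    shifted = torch.roll(paddings, -num_begin_padding, 0)
    shifted = torch.roll(shifted, num_begin_padding - 1, 0)
    return torch.logical_not(paddings) * shifted

chosen_id = torch.tensor([5, 6, 7, 8, PAD_ID, PAD_ID])
rejected_id = torch.tensor([5, 6, 9, PAD_ID, PAD_ID, PAD_ID])
chosen_reward = torch.tensor([0.1, 0.2, 0.9, 1.1, 0.0, 0.0])
rejected_reward = torch.tensor([0.1, 0.2, 0.3, 0.0, 0.0, 0.0])

chosen_padding_mask = chosen_id == PAD_ID
rejected_padding_mask = rejected_id == PAD_ID

# 1's wherever at least one of the two sequences still has a real token
united_unpadding_mask = torch.logical_not(chosen_padding_mask
                                          & rejected_padding_mask)
# 1's from the first position where the two sequences diverge onward
divergence_mask = (chosen_id != rejected_id).cumsum(0).bool()
loss_mask = divergence_mask & united_unpadding_mask   # positions 2 and 3 here

# pairwise loss averaged over the masked positions (fixed shapes throughout)
logsig = torch.nn.functional.logsigmoid(chosen_reward - rejected_reward)
loss = -(logsig * loss_mask).sum() / loss_mask.sum()

# end-of-sequence score, the mask-based analogue of chosen_reward[c_ind - 1]
c_end_mask = get_last_before_padding(chosen_padding_mask,
                                     num_padding_at_beginning)
chosen_end_score = (c_end_mask.float() * chosen_reward).sum()
print(loss.item(), chosen_end_score.item())

Because every tensor here keeps the full sequence length, no data-dependent slicing or nonzero() is needed, which is what the patch's "non dynamic shapes" comment refers to and what makes the flow friendly to HPU graph compilation.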