diff --git a/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py b/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py
index ef7bcae..ba3bc37 100644
--- a/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py
+++ b/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py
@@ -1,5 +1,5 @@
 import torch
-import json
+import os
 import argparse
 
@@ -140,7 +140,7 @@ def build_llama_state_dict(llama_dir, llama_file, parallel_size):
     for parallel_idx, parallel_state in enumerate(split_parameter(llama_state, parallel_size)):
         state['model'] = parallel_state
         dump_file = "model-model_part-{}.pt".format(parallel_idx)
-        torch.save(state, llama_dir + 'megatron_{}/'.format(parallel_size) + dump_file)
+        torch.save(state, os.path.join(llama_dir + 'megatron_{}/'.format(parallel_size)) + dump_file)
         print("dump new model to {}{}".format(llama_dir, dump_file))
 
 def main():
diff --git a/alpaca_lora/src/llama_model.py b/alpaca_lora/src/llama_model.py
index e2bfda5..34762fd 100644
--- a/alpaca_lora/src/llama_model.py
+++ b/alpaca_lora/src/llama_model.py
@@ -201,7 +201,7 @@ def get_normalized_probs(
         else:
             return utils.softmax(logits, dim=-1)
 
-    def forward(self, src_tokens, src_lengths, src_pos, tgt_pos, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, src_pos, tgt_pos, tgt_tokens):
 
         src_x, src_padding, src_attn, src_hiddens = self.decoder(
             prev_output_tokens=src_tokens,
@@ -215,7 +215,7 @@ def forward(self, src_tokens, src_lengths, src_pos, tgt_pos, prev_output_tokens)
            incremental_state[layer_idx]['key'] = layer_hidden_states
 
        tgt_x, tgt_padding, tgt_attn, tgt_hiddens = self.decoder(
-            prev_output_tokens=prev_output_tokens,
+            prev_output_tokens=tgt_tokens,
            incremental_state=incremental_state,
            src_pos=src_pos,
            tgt_pos=tgt_pos,
@@ -410,7 +410,7 @@ def forward(
     ):
 
         if incremental_state is not None and trunc_flg:
-            prev_output_tokens = prev_output_tokens[:, -1:]
+            prev_output_tokens = prev_output_tokens[:, -1:]
 
         bsz, target_len = prev_output_tokens.size()
         x = self.embed_tokens(prev_output_tokens)
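
Note on the checkpoint-path change in process_llama_megatron_ckpt.py: os.path.join() called with a single argument returns that argument unchanged, so the new torch.save line is functionally the same as the old string concatenation and still relies on llama_dir ending with a separator. A minimal sketch of a fully joined form is below; the helper name megatron_dump_path is hypothetical and not part of the repository, it only illustrates how the same "megatron_<parallel_size>/model-model_part-<idx>.pt" path could be built regardless of whether llama_dir has a trailing slash.

    import os

    def megatron_dump_path(llama_dir, parallel_size, parallel_idx):
        # e.g. <llama_dir>/megatron_2/model-model_part-0.pt
        dump_file = "model-model_part-{}.pt".format(parallel_idx)
        return os.path.join(llama_dir, "megatron_{}".format(parallel_size), dump_file)

The llama_model.py hunks only rename the last forward() parameter from prev_output_tokens to tgt_tokens, so the target-side decoder call reads unambiguously while the decoder's own prev_output_tokens keyword is kept as-is.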