Skip to content

Commit

Permalink
load model using device.index
Browse files Browse the repository at this point in the history
  • Loading branch information
anthdr committed Jan 17, 2025
1 parent 484b745 commit 0f3ea55
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion eole/bin/tools/LM_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def run(cls, args):
if len(config.gpu_ranks) > 1:
logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.")

vocabs, model, model_opt = BaseModel.load_test_model(config, 0)
vocabs, model, model_opt = BaseModel.load_test_model(config, device.index)
pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
padding_idx = vocabs["tgt"].tokens_to_ids[pad_token]
criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")
Expand Down

0 comments on commit 0f3ea55

Please sign in to comment.