Trainer: delegate default generation values to generation_config (#…
gante authored Sep 5, 2023
1 parent aea7614 commit 9a70d6e
Showing 2 changed files with 32 additions and 21 deletions.
examples/pytorch/question-answering/trainer_seq2seq_qa.py (7 additions, 6 deletions)
@@ -46,12 +46,13 @@ def evaluate(
         **gen_kwargs,
     ) -> Dict[str, float]:
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if gen_kwargs.get("max_length") is None and self.args.generation_max_length is not None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs

         eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
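For context, the hunk above replaces unconditional defaulting, which always wrote max_length and num_beams into gen_kwargs, with a fallback that fires only when the caller did not pass the option and the training args actually set it. A minimal standalone sketch of that resolution logic (the helper name is hypothetical, not part of the commit):

# Hypothetical helper mirroring the new fallback: a training-args value is used
# only when the option was not passed explicitly; otherwise the key is left
# unset so generate() can resolve it from model.generation_config.
from typing import Any, Dict, Optional


def resolve_legacy_gen_kwargs(
    gen_kwargs: Dict[str, Any],
    generation_max_length: Optional[int],
    generation_num_beams: Optional[int],
) -> Dict[str, Any]:
    gen_kwargs = gen_kwargs.copy()
    if gen_kwargs.get("max_length") is None and generation_max_length is not None:
        gen_kwargs["max_length"] = generation_max_length
    if gen_kwargs.get("num_beams") is None and generation_num_beams is not None:
        gen_kwargs["num_beams"] = generation_num_beams
    return gen_kwargs


# Nothing passed and nothing set in the training args: the dict stays empty,
# so generation defaults fall through to the model's generation_config.
assert resolve_legacy_gen_kwargs({}, None, None) == {}
# An explicit call-site value wins over the training-args value.
assert resolve_legacy_gen_kwargs({"num_beams": 4}, None, 2) == {"num_beams": 4}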
src/transformers/trainer_seq2seq.py (25 additions, 15 deletions)
@@ -149,11 +149,17 @@ def evaluate(
"""

gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:

# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
# training args
if (
gen_kwargs.get("max_length") is None
and gen_kwargs.get("max_new_tokens") is None
and self.args.generation_max_length is not None
):
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
gen_kwargs["num_beams"] = self.args.generation_num_beams
self._gen_kwargs = gen_kwargs

return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
@@ -206,11 +212,17 @@ def predict(
"""

gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:

# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
# training args
if (
gen_kwargs.get("max_length") is None
and gen_kwargs.get("max_new_tokens") is None
and self.args.generation_max_length is not None
):
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
gen_kwargs["num_beams"] = self.args.generation_num_beams
self._gen_kwargs = gen_kwargs

return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
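The net effect of the evaluate() and predict() hunks above: Seq2SeqTrainer only injects max_length and num_beams when the training args set them, so model.generation_config becomes the single source of generation defaults. A hedged usage sketch (t5-small is just an example checkpoint, and this assumes a transformers version that includes this commit):

# Generation defaults for evaluation/prediction now live on the model's
# GenerationConfig unless explicitly overridden.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Set defaults once on the model's generation config...
model.generation_config.max_new_tokens = 32
model.generation_config.num_beams = 4

inputs = tokenizer("translate English to German: How are you?", return_tensors="pt")
# ...and generate() (the call Seq2SeqTrainer ultimately makes) picks them up
# when no explicit generation kwargs are supplied.
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# A non-None kwarg passed at the call site still takes priority.
outputs = model.generate(**inputs, num_beams=1)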
@@ -256,16 +268,14 @@ def prediction_step(

         # XXX: adapt synced_gpus for fairscale as well
         # Priority (handled in generate):
-        # gen_kwargs > model.generation_config > default GenerationConfig()
-
+        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
         if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
             gen_kwargs = self._gen_kwargs.copy()
+            if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
+                gen_kwargs.pop("num_beams")
+            if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
+                gen_kwargs.pop("max_length")

-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
-            gen_kwargs["max_length"] = self.model.config.max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
-        )
         default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
         gen_kwargs["synced_gpus"] = (
             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
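The pruning added in this hunk matters because evaluate() and predict() can stash None-valued entries in self._gen_kwargs; previously those Nones were backfilled from model.config. A self-contained sketch of the new behavior:

# None-valued entries inherited from evaluate()/predict() are dropped, so
# generate() resolves them from model.generation_config instead of a
# hard-coded model.config value.
gen_kwargs = {"num_beams": None, "max_length": None, "synced_gpus": False}
if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
    gen_kwargs.pop("num_beams")
if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
    gen_kwargs.pop("max_length")
assert gen_kwargs == {"synced_gpus": False}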
