refactor examples to accommodate Lightning-AI/pytorch-lightning#18105
speediedan committed May 6, 2024
1 parent 41ba761 commit 29d355e
Showing 3 changed files with 23 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/fts_examples/stable/fts_superglue.py
@@ -117,15 +117,23 @@ def __init__(
         super().__init__()
         task_name = task_name if task_name in TASK_NUM_LABELS.keys() else DEFAULT_TASK
         self.text_fields = self.TASK_TEXT_FIELD_MAP[task_name]
+        self.init_hparams = {
+            "model_name_or_path": model_name_or_path,
+            "task_name": task_name,
+            "max_seq_length": max_seq_length,
+            "train_batch_size": train_batch_size,
+            "eval_batch_size": eval_batch_size,
+            "dataloader_kwargs": dataloader_kwargs,
+            "tokenizers_parallelism": tokenizers_parallelism,
+        }
+        self.save_hyperparameters(self.init_hparams)
         self.dataloader_kwargs = {
             "num_workers": dataloader_kwargs.get("num_workers", 0),
             "pin_memory": dataloader_kwargs.get("pin_memory", False),
         }
-        self.save_hyperparameters()
         os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.hparams.tokenizers_parallelism else "false"
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.hparams.model_name_or_path, use_fast=True, local_files_only=False
-        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.model_name_or_path, use_fast=True,
+                                                       local_files_only=False)

     def prepare_data(self):
         """Load the SuperGLUE dataset."""
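Both example modules make the same change: instead of the zero-argument `save_hyperparameters()` call, which captures `__init__` arguments via frame inspection, they now build an explicit dict and pass it in. A minimal sketch of the pattern, assuming a current `lightning` install (`ToyModule` and its arguments are illustrative, not from this repo):

    import lightning.pytorch as pl

    class ToyModule(pl.LightningModule):
        def __init__(self, model_name_or_path: str = "bert-base-uncased", train_batch_size: int = 16):
            super().__init__()
            # Collect the values to persist explicitly rather than relying on
            # automatic inspection of the __init__ frame.
            init_hparams = {
                "model_name_or_path": model_name_or_path,
                "train_batch_size": train_batch_size,
            }
            self.save_hyperparameters(init_hparams)

    module = ToyModule()
    assert module.hparams.train_batch_size == 16  # saved values surface via self.hparams

Passing the dict explicitly also makes it straightforward to omit arguments that should not be checkpointed.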
11 changes: 10 additions & 1 deletion src/fts_examples/stable/ipynb_src/fts_superglue_nb.py
@@ -250,11 +250,20 @@ def __init__(
         super().__init__()
         task_name = task_name if task_name in TASK_NUM_LABELS.keys() else DEFAULT_TASK
         self.text_fields = self.TASK_TEXT_FIELD_MAP[task_name]
+        self.init_hparams = {
+            "model_name_or_path": model_name_or_path,
+            "task_name": task_name,
+            "max_seq_length": max_seq_length,
+            "train_batch_size": train_batch_size,
+            "eval_batch_size": eval_batch_size,
+            "dataloader_kwargs": dataloader_kwargs,
+            "tokenizers_parallelism": tokenizers_parallelism,
+        }
+        self.save_hyperparameters(self.init_hparams)
         self.dataloader_kwargs = {
             "num_workers": dataloader_kwargs.get("num_workers", 0),
             "pin_memory": dataloader_kwargs.get("pin_memory", False),
         }
-        self.save_hyperparameters()
         os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.hparams.tokenizers_parallelism else "false"
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.hparams.model_name_or_path, use_fast=True, local_files_only=False
1 change: 1 addition & 0 deletions src/fts_examples/stable/test_examples.py
@@ -36,6 +36,7 @@
     "does not have many workers",
     "is smaller than the logging interval",
     "sentencepiece tokenizer that you are converting",
+    "`resume_download` is deprecated", # required because of upstream usage as of 2.2.2
     "distutils Version classes are deprecated", # still required as of PyTorch/Lightning 2.2
     "Please use torch.utils._pytree.register_pytree_node", # temp allow deprecated behavior of transformers
     "We are importing from `pydantic", # temp pydantic import migration warning
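The new entry extends a list of expected warning fragments. A minimal sketch of how such an allowlist can be applied, assuming warnings are recorded while an example runs (`EXPECTED_WARNS` and `unexpected_warns` are illustrative names, not necessarily the repo's actual helpers):

    import warnings

    EXPECTED_WARNS = [
        "does not have many workers",
        "`resume_download` is deprecated",
    ]

    def unexpected_warns(recorded, expected):
        # A warning is unexpected if no allowlisted fragment occurs in its message.
        return [w for w in recorded if not any(exp in str(w.message) for exp in expected)]

    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        warnings.warn("`resume_download` is deprecated, use `force_download` instead")
        warnings.warn("an unrelated, genuinely unexpected warning")

    assert len(unexpected_warns(recorded, EXPECTED_WARNS)) == 1

Keeping the allowlist in one place lets a test fail only on warnings that are genuinely new, while tracked upstream deprecations (like the `resume_download` notice) are tolerated until resolved.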
