diff --git a/aitextgen/train.py b/aitextgen/train.py
index f7f6fc4..23baa21 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -1,14 +1,17 @@
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks.progress import ProgressBarBase
-from tqdm.auto import tqdm
+import os
+import shutil
+import subprocess
 import sys
+
 import torch
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
-import os
-import shutil
-import subprocess
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.progress import ProgressBarBase
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 
 
 class ATGTransformer(pl.LightningModule):
@@ -18,12 +21,12 @@ class ATGTransformer(pl.LightningModule):
 
     def __init__(self, model, dataset, hparams, tokenizer):
         super(ATGTransformer, self).__init__()
-        self.model, self.dataset, self.hparams, self.tokenizer = (
+        self.model, self.dataset, self.tokenizer = (
             model,
             dataset,
-            hparams,
             tokenizer,
         )
+        self.save_hyperparameters(hparams)
 
     def forward(self, inputs):
         return self.model(**inputs, return_dict=False)
@@ -112,6 +115,10 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
+
+    @property
+    def save_every_check(self):
+        return self.save_every > 0 and self.steps % self.save_every == 0
 
     def enabled(self):
         self.enabled = True
@@ -172,10 +179,19 @@ def on_batch_end(self, trainer, pl_module):
             desc += f" — GPU Mem: {gpu_memory} MB"
         self.main_progress_bar.update(self.progress_bar_refresh_rate)
         self.main_progress_bar.set_description(desc)
-
+
+        if _TPU_AVAILABLE and self.save_every_check:
+            did_unfreeze = False
+            if self.enabled:
+                self.unfreeze_layers(pl_module)
+                did_unfreeze = True
+            self.save_pytorch_model(trainer, pl_module, tpu=True)
+            if did_unfreeze:
+                self.freeze_layers(pl_module)
+
         if self.enabled:
             did_unfreeze = False
-            if self.save_every > 0 and self.steps % self.save_every == 0:
+            if not _TPU_AVAILABLE and self.save_every_check:
                 self.unfreeze_layers(pl_module)
                 self.save_pytorch_model(trainer, pl_module)
                 did_unfreeze = True
@@ -228,13 +244,19 @@ def generate_sample_text(self, trainer, pl_module):
         self.main_progress_bar.write("=" * 10)
 
-    def save_pytorch_model(self, trainer, pl_module):
-        self.main_progress_bar.write(
-            f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
-        )
-        pl_module.model.save_pretrained(self.output_dir)
+    def save_pytorch_model(self, trainer, pl_module, tpu=False):
+
+        if self.enabled:
+            self.main_progress_bar.write(
+                f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
+            )
+        if tpu:
+            import torch_xla.core.xla_model as xm
+            pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
+        else:
+            pl_module.model.save_pretrained(self.output_dir)
 
-        if self.save_gdrive:
+        if self.enabled and self.save_gdrive:
             for pt_file in ["pytorch_model.bin", "config.json"]:
                 shutil.copyfile(
                     os.path.join(self.output_dir, pt_file),
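
Note on the tpu=True branch above: HuggingFace's save_pretrained() accepts a save_function argument (defaulting to torch.save), which is how torch_xla's xm.save gets swapped in; xm.save performs a rendezvous and writes from the master ordinal only, so multi-core TPU runs produce a single valid checkpoint. A minimal standalone sketch of that dispatch, assuming transformers is installed and torch_xla is present only on TPU hosts (save_model is an illustrative helper, not part of this diff):

    from transformers import GPT2LMHeadModel

    def save_model(model, output_dir, tpu=False):
        # Illustrative sketch mirroring save_pytorch_model's save dispatch.
        if tpu:
            # xm.save moves XLA tensors to CPU and saves from the master
            # ordinal only, avoiding duplicate writes across TPU cores.
            import torch_xla.core.xla_model as xm
            model.save_pretrained(output_dir, save_function=xm.save)
        else:
            model.save_pretrained(output_dir)

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    save_model(model, "trained_model")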