diff --git a/aitextgen/train.py b/aitextgen/train.py
index ce54a4a..a73fc6c 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -11,7 +11,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress import ProgressBarBase
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.accelerators import TPUAccelerator


 class ATGTransformer(pl.LightningModule):
@@ -115,7 +115,7 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
-
+
     @property
     def save_every_check(self):
         return self.save_every > 0 and self.steps % self.save_every == 0
@@ -141,15 +141,15 @@ def on_train_start(self, trainer, pl_module):
     def on_train_end(self, trainer, pl_module):
         self.main_progress_bar.close()
         self.unfreeze_layers(pl_module)
-
+
     def get_metrics(self, trainer, pl_module):
         # don't show the version number
         items = super().get_metrics(trainer, pl_module)
         items.pop("v_num", None)
         return items

-    def on_batch_end(self, trainer, pl_module):
-        super().on_batch_end(trainer, pl_module)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

         # clean up the GPU cache used for the benchmark
         # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
@@ -186,8 +186,8 @@ def on_batch_end(self, trainer, pl_module):
                 desc += f" — GPU Mem: {gpu_memory} MB"
             self.main_progress_bar.update(self.progress_bar_refresh_rate)
             self.main_progress_bar.set_description(desc)
-
-        if _TPU_AVAILABLE and self.save_every_check:
+
+        if TPUAccelerator.is_available() and self.save_every_check:
             did_unfreeze = False
             if self.enabled:
                 self.unfreeze_layers(pl_module)
@@ -195,10 +195,10 @@ def on_batch_end(self, trainer, pl_module):
                 self.save_pytorch_model(trainer, pl_module, tpu=True)
             if did_unfreeze:
                 self.freeze_layers(pl_module)
-
+
         if self.enabled:
             did_unfreeze = False
-            if not _TPU_AVAILABLE and self.save_every_check:
+            if not TPUAccelerator.is_available() and self.save_every_check:
                 self.unfreeze_layers(pl_module)
                 self.save_pytorch_model(trainer, pl_module)
                 did_unfreeze = True
@@ -243,13 +243,14 @@ def generate_sample_text(self, trainer, pl_module):
         self.main_progress_bar.write("=" * 10)

     def save_pytorch_model(self, trainer, pl_module, tpu=False):
-
+
         if self.enabled:
             self.main_progress_bar.write(
                 f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
             )
         if tpu:
             import torch_xla.core.xla_model as xm
+
             pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
         else:
             pl_module.model.save_pretrained(self.output_dir)
diff --git a/requirements.txt b/requirements.txt
index f3efead..a9e23ef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 transformers>=4.5.1
 fire>=0.3.0
-pytorch-lightning>=1.3.1
-torch>=1.6.0
\ No newline at end of file
+pytorch-lightning>=1.8.0
+torch>=1.6.0
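
For reference, a minimal sketch of the pytorch-lightning >= 1.8 API surface this patch migrates to. The callback name CheckpointEveryNSteps, the save_every default, and the output path are illustrative only (not part of aitextgen); it assumes pl_module.model is a Hugging Face PreTrainedModel, as it is in aitextgen's ATGTransformer:

import pytorch_lightning as pl
from pytorch_lightning.accelerators import TPUAccelerator


class CheckpointEveryNSteps(pl.Callback):  # illustrative name, not from aitextgen
    """Sketch of the two PL >= 1.8 APIs used by the patch above."""

    def __init__(self, save_every: int = 1000, output_dir: str = "trained_model"):
        self.save_every = save_every
        self.output_dir = output_dir  # assumed output path, for illustration
        self.steps = 0

    # PL 1.8 passes outputs, batch, and batch_idx to this hook; the old
    # on_batch_end(trainer, pl_module) callback hook no longer exists.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.steps += 1
        if self.save_every > 0 and self.steps % self.save_every == 0:
            # TPUAccelerator.is_available() is the public replacement for the
            # private _TPU_AVAILABLE flag removed from pytorch_lightning.utilities.
            if TPUAccelerator.is_available():
                import torch_xla.core.xla_model as xm

                pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
            else:
                pl_module.model.save_pretrained(self.output_dir)

The same pattern is why the requirements.txt floor moves to pytorch-lightning>=1.8.0: both TPUAccelerator.is_available() and the five-argument on_train_batch_end signature are only guaranteed on the 1.8 line, while the private _TPU_AVAILABLE flag was dropped earlier.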