From fa52f1f693ddea0090f6562773e383f4653b5e48 Mon Sep 17 00:00:00 2001
From: Bill Mill
Date: Tue, 20 Dec 2022 15:54:37 -0500
Subject: [PATCH 1/3] fix missing _TPU_AVAILABLE variable

pytorch_lightning removed this variable in commit
7ef87464ddd740f8af8388bd95130066c65874da:
https://github.com/Lightning-AI/lightning/commit/7ef8746

Replace _TPU_AVAILABLE with TPUAccelerator.is_available().
---
 aitextgen/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/aitextgen/train.py b/aitextgen/train.py
index ce54a4a..803b708 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -11,7 +11,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress import ProgressBarBase
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.accelerators import TPUAccelerator
 
 
 class ATGTransformer(pl.LightningModule):
@@ -187,7 +187,7 @@ def on_batch_end(self, trainer, pl_module):
             self.main_progress_bar.update(self.progress_bar_refresh_rate)
             self.main_progress_bar.set_description(desc)
 
-        if _TPU_AVAILABLE and self.save_every_check:
+        if TPUAccelerator.is_available() and self.save_every_check:
             did_unfreeze = False
             if self.enabled:
                 self.unfreeze_layers(pl_module)
@@ -198,7 +198,7 @@ def on_batch_end(self, trainer, pl_module):
 
         if self.enabled:
             did_unfreeze = False
-            if not _TPU_AVAILABLE and self.save_every_check:
+            if not TPUAccelerator.is_available() and self.save_every_check:
                 self.unfreeze_layers(pl_module)
                 self.save_pytorch_model(trainer, pl_module)
                 did_unfreeze = True

From 007cca37a98abbcda2cabbd60f86e65bc5fe9e53 Mon Sep 17 00:00:00 2001
From: Bill Mill
Date: Wed, 21 Dec 2022 16:04:22 -0500
Subject: [PATCH 2/3] fix on_batch_end for updated lightning

---
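Note (not part of the commit): lightning 1.8 removed the Callback hook
on_batch_end, deprecated since 1.6; on_train_batch_end is the replacement,
and it also receives the step outputs, the batch, and the batch index,
which is why the override and its super() call change together. A minimal
sketch of the new signature, assuming a plain pl.Callback subclass; the
class name and save path are illustrative, not from this repo:

    import pytorch_lightning as pl

    class SaveEveryN(pl.Callback):
        # hypothetical callback, shown only for the post-1.8 hook signature
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # runs after each training batch; `outputs` is whatever
            # training_step() returned for this batch
            if (batch_idx + 1) % 500 == 0:
                # saves the wrapped Hugging Face model, the same call
                # aitextgen's save_pytorch_model() makes
                pl_module.model.save_pretrained("./trained_model")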
 aitextgen/train.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/aitextgen/train.py b/aitextgen/train.py
index 803b708..a73fc6c 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -115,7 +115,7 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
-
+
     @property
     def save_every_check(self):
         return self.save_every > 0 and self.steps % self.save_every == 0
@@ -141,15 +141,15 @@ def on_train_start(self, trainer, pl_module):
     def on_train_end(self, trainer, pl_module):
         self.main_progress_bar.close()
         self.unfreeze_layers(pl_module)
-
+
     def get_metrics(self, trainer, pl_module):
         # don't show the version number
         items = super().get_metrics(trainer, pl_module)
         items.pop("v_num", None)
         return items
 
-    def on_batch_end(self, trainer, pl_module):
-        super().on_batch_end(trainer, pl_module)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
 
         # clean up the GPU cache used for the benchmark
         # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
@@ -186,7 +186,7 @@ def on_batch_end(self, trainer, pl_module):
                 desc += f" — GPU Mem: {gpu_memory} MB"
             self.main_progress_bar.update(self.progress_bar_refresh_rate)
             self.main_progress_bar.set_description(desc)
-
+
         if TPUAccelerator.is_available() and self.save_every_check:
             did_unfreeze = False
             if self.enabled:
@@ -195,7 +195,7 @@ def on_batch_end(self, trainer, pl_module):
             self.save_pytorch_model(trainer, pl_module, tpu=True)
             if did_unfreeze:
                 self.freeze_layers(pl_module)
-
+
         if self.enabled:
             did_unfreeze = False
             if not TPUAccelerator.is_available() and self.save_every_check:
@@ -243,13 +243,14 @@ def generate_sample_text(self, trainer, pl_module):
         self.main_progress_bar.write("=" * 10)
 
     def save_pytorch_model(self, trainer, pl_module, tpu=False):
-
+
         if self.enabled:
             self.main_progress_bar.write(
                 f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
             )
         if tpu:
             import torch_xla.core.xla_model as xm
+
             pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
         else:
             pl_module.model.save_pretrained(self.output_dir)

From 20063f766c3840b69ecf503fcd0fc8140178531d Mon Sep 17 00:00:00 2001
From: Bill Mill
Date: Wed, 21 Dec 2022 16:04:56 -0500
Subject: [PATCH 3/3] update pytorch-lightning requirement

---
 requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index f3efead..a9e23ef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 transformers>=4.5.1
 fire>=0.3.0
-pytorch-lightning>=1.3.1
-torch>=1.6.0
\ No newline at end of file
+pytorch-lightning>=1.8.0
+torch>=1.6.0
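Note on the series (not part of any commit): with pytorch-lightning pinned
to >=1.8.0, the TPUAccelerator.is_available() call introduced in patch 1 is
always importable. If older lightning releases ever needed to be supported
again, a guarded import could bridge both APIs. A sketch under that
assumption only; the helper name tpu_available is invented here:

    try:
        # pytorch-lightning < 1.8 still exposes the private flag
        from pytorch_lightning.utilities import _TPU_AVAILABLE

        def tpu_available() -> bool:
            return bool(_TPU_AVAILABLE)

    except ImportError:
        # pytorch-lightning >= 1.8 removed it; ask the accelerator instead
        from pytorch_lightning.accelerators import TPUAccelerator

        def tpu_available() -> bool:
            return TPUAccelerator.is_available()

The callback code would then call tpu_available() in place of either
spelling, keeping the version check in one place.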