Commit 7bfbddb
Merge pull request #202 from llimllib/fix_tpu_available
update pytorch-lightning requirement to >= 1.8.0
minimaxir authored Mar 16, 2023
2 parents 7424601 + 20063f7 commit 7bfbddb
Showing 2 changed files with 13 additions and 12 deletions.
21 changes: 11 additions & 10 deletions aitextgen/train.py
@@ -11,7 +11,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress import ProgressBarBase
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.accelerators import TPUAccelerator
 
 
 class ATGTransformer(pl.LightningModule):
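
Background on the import swap: pytorch-lightning 1.8 dropped the private _TPU_AVAILABLE constant from pytorch_lightning.utilities, and the public replacement is the accelerator's own availability check. A minimal sketch of the new-style check, assuming pytorch-lightning >= 1.8 is installed (and torch_xla for a real TPU):

    from pytorch_lightning.accelerators import TPUAccelerator

    # Static check; True only when a TPU runtime is actually reachable
    if TPUAccelerator.is_available():
        print("TPU detected")
    else:
        print("No TPU found; training falls back to CPU/GPU")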
@@ -115,7 +115,7 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
-
+
     @property
     def save_every_check(self):
         return self.save_every > 0 and self.steps % self.save_every == 0
@@ -141,15 +141,15 @@ def on_train_start(self, trainer, pl_module):
     def on_train_end(self, trainer, pl_module):
         self.main_progress_bar.close()
         self.unfreeze_layers(pl_module)
-
+
     def get_metrics(self, trainer, pl_module):
         # don't show the version number
         items = super().get_metrics(trainer, pl_module)
         items.pop("v_num", None)
         return items
-
-    def on_batch_end(self, trainer, pl_module):
-        super().on_batch_end(trainer, pl_module)
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
 
         # clean up the GPU cache used for the benchmark
         # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
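
Background on the hook rename: Callback.on_batch_end was deprecated in pytorch-lightning 1.6 and removed in 1.8; per-batch work now goes through on_train_batch_end, which also receives the step outputs, the batch, and the batch index. A hedged sketch of a standalone callback using the new hook (BatchEndLogger is illustrative, not part of this diff):

    import pytorch_lightning as pl

    class BatchEndLogger(pl.Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # outputs holds whatever training_step returned for this batch
            if batch_idx % 100 == 0:
                print(f"finished batch {batch_idx}")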
@@ -186,19 +186,19 @@ def on_batch_end(self, trainer, pl_module):
                 desc += f" — GPU Mem: {gpu_memory} MB"
             self.main_progress_bar.update(self.progress_bar_refresh_rate)
             self.main_progress_bar.set_description(desc)
-
-        if _TPU_AVAILABLE and self.save_every_check:
+
+        if TPUAccelerator.is_available() and self.save_every_check:
             did_unfreeze = False
             if self.enabled:
                 self.unfreeze_layers(pl_module)
                 did_unfreeze = True
             self.save_pytorch_model(trainer, pl_module, tpu=True)
             if did_unfreeze:
                 self.freeze_layers(pl_module)
-
+
         if self.enabled:
             did_unfreeze = False
-            if not _TPU_AVAILABLE and self.save_every_check:
+            if not TPUAccelerator.is_available() and self.save_every_check:
                 self.unfreeze_layers(pl_module)
                 self.save_pytorch_model(trainer, pl_module)
                 did_unfreeze = True
@@ -243,13 +243,14 @@ def generate_sample_text(self, trainer, pl_module):
         self.main_progress_bar.write("=" * 10)
 
     def save_pytorch_model(self, trainer, pl_module, tpu=False):
+
         if self.enabled:
             self.main_progress_bar.write(
                 f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
             )
         if tpu:
             import torch_xla.core.xla_model as xm
 
             pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
         else:
             pl_module.model.save_pretrained(self.output_dir)
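
On the tpu=True path the weights live on an XLA device, so the checkpoint is written through torch_xla's xm.save, which moves tensors to CPU and saves from a single process; transformers' save_pretrained accepts a save_function argument (default torch.save) for exactly this purpose. A minimal sketch of the same pattern outside the callback, where model and the output path are illustrative:

    import torch_xla.core.xla_model as xm

    # xm.save stands in for torch.save so the checkpoint is TPU-safe
    model.save_pretrained("trained_model", save_function=xm.save)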
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,4 +1,4 @@
 transformers>=4.5.1
 fire>=0.3.0
-pytorch-lightning>=1.3.1
-torch>=1.6.0
+pytorch-lightning>=1.8.0
+torch>=1.6.0
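
A quick sanity check (not part of the commit) that an environment satisfies the new floor:

    import pytorch_lightning as pl

    # this codebase now assumes 1.8.0 or newer
    print(pl.__version__)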
