Adding support for Pytorch 2.x and Pytorch-Lightning 2.0.7 #151

Open · wants to merge 1 commit into base: master
13 changes: 5 additions & 8 deletions deepethogram/base.py
@@ -9,6 +9,7 @@
 import numpy as np
 from omegaconf import DictConfig, OmegaConf
 import pytorch_lightning as pl
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
 try:
     from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
         TuneReportCheckpointCallback
@@ -275,8 +276,7 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
     # learning rate schedule.
 
     if cfg.compute.batch_size == 'auto' or cfg.train.lr == 'auto':
-        trainer = pl.Trainer(gpus=[cfg.compute.gpu_id],
-                             precision=16 if cfg.compute.fp16 else 32,
+        trainer = pl.Trainer(precision=16 if cfg.compute.fp16 else 32,
                              limit_train_batches=1.0,
                              limit_val_batches=1.0,
                              limit_test_batches=1.0,
@@ -378,13 +378,13 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
     else:
         tensorboard_logger = pl.loggers.tensorboard.TensorBoardLogger(os.getcwd())
     refresh_rate = 1
+    callback_list.append(TQDMProgressBar(refresh_rate=refresh_rate))
 
     # tuning messes with the callbacks
     try:
         # will be deprecated in the future; pytorch lightning updated their kwargs for this function
         # don't like how they keep updating the api without proper deprecation warnings, etc.
-        trainer = pl.Trainer(gpus=[cfg.compute.gpu_id],
-                             precision=16 if cfg.compute.fp16 else 32,
+        trainer = pl.Trainer(precision=16 if cfg.compute.fp16 else 32,
                              limit_train_batches=steps_per_epoch['train'],
                              limit_val_batches=steps_per_epoch['val'],
                              limit_test_batches=steps_per_epoch['test'],
@@ -393,13 +393,11 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
                              num_sanity_val_steps=0,
                              callbacks=callback_list,
                              reload_dataloaders_every_epoch=True,
-                             progress_bar_refresh_rate=refresh_rate,
                              profiler=profiler,
                              log_every_n_steps=1)
 
     except TypeError:
-        trainer = pl.Trainer(gpus=[cfg.compute.gpu_id],
-                             precision=16 if cfg.compute.fp16 else 32,
+        trainer = pl.Trainer(precision=16 if cfg.compute.fp16 else 32,
                              limit_train_batches=steps_per_epoch['train'],
                              limit_val_batches=steps_per_epoch['val'],
                              limit_test_batches=steps_per_epoch['test'],
@@ -408,7 +406,6 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s
                              num_sanity_val_steps=0,
                              callbacks=callback_list,
                              reload_dataloaders_every_n_epochs=1,
-                             progress_bar_refresh_rate=refresh_rate,
                              profiler=profiler,
                              log_every_n_steps=1)
     torch.cuda.empty_cache()
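Note on the base.py changes: both removed kwargs (gpus and progress_bar_refresh_rate) no longer exist on pl.Trainer in Lightning 2.x; the refresh rate now travels through the TQDMProgressBar callback added above. A minimal sketch of the equivalent explicit 2.x construction (illustrative only, not code from this PR; the accelerator/devices values are assumptions):

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar

# The refresh rate moves from the removed progress_bar_refresh_rate
# kwarg into the TQDMProgressBar callback.
callbacks = [TQDMProgressBar(refresh_rate=1)]

trainer = pl.Trainer(
    accelerator='gpu',                    # replaces the removed gpus=[gpu_id] kwarg (assumed here)
    devices=[0],                          # e.g. cfg.compute.gpu_id
    precision=16,                         # '16-mixed' is the preferred 2.x spelling
    reload_dataloaders_every_n_epochs=1,  # 2.x name for reload_dataloaders_every_epoch=True
    callbacks=callbacks,
    log_every_n_steps=1,
)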
16 changes: 8 additions & 8 deletions deepethogram/callbacks.py
@@ -20,10 +20,10 @@ def __init__(self):
     def on_init_end(self, trainer):
         log.info('on init start')
 
-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         log.debug('on train batch start')
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         log.debug('on train batch end')
 
     def on_train_epoch_start(self, trainer, pl_module):
@@ -94,16 +94,16 @@ def end_batch(self, split, batch, pl_module, eps: float = 1e-7):
 
         pl_module.metrics.buffer.append(split, {'fps': fps})
 
-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         self.start_timer('train')
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self.end_batch('train', batch, pl_module)
 
-    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
         self.start_timer('val')
 
-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self.end_batch('val', batch, pl_module)
 
     def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
@@ -204,10 +204,10 @@ def on_validation_epoch_end(self, trainer, pl_module):
     def on_test_epoch_end(self, trainer, pl_module):
         self.reset_cnt(pl_module, 'test')
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         pl_module.viz_cnt['train'] += 1
 
-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         pl_module.viz_cnt['val'] += 1
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
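Note on the callbacks.py changes: these signatures track the Lightning 2.x Callback API, where the train-batch hooks dropped the dataloader_idx argument and the validation/test-batch hooks made it an optional keyword defaulting to 0. A small sketch of a 2.x-compatible callback (hypothetical example, not from this repo):

import pytorch_lightning as pl

class BatchLoggingCallback(pl.Callback):
    # Train-batch hooks no longer receive dataloader_idx in Lightning 2.x.
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f'train batch {batch_idx} starting')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print(f'train batch {batch_idx} finished')

    # For validation/test batch hooks, dataloader_idx became optional (default 0).
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        print(f'val batch {batch_idx} (loader {dataloader_idx}) finished')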
2 changes: 1 addition & 1 deletion setup.py
@@ -18,5 +18,5 @@
     install_requires=[
         'chardet<4.0', 'h5py', 'kornia>=0.5', 'matplotlib', 'numpy', 'omegaconf>=2',
         'opencv-python-headless', 'opencv-transforms', 'pandas<1.4', 'PySide2', 'scikit-learn<1.1',
-        'scipy<1.8', 'tqdm', 'vidio', 'pytorch_lightning>=1.5.10'
+        'scipy<1.8', 'tqdm', 'vidio', 'pytorch_lightning>=2.0.7'
     ])
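With the requirement bumped to pytorch_lightning>=2.0.7, a quick post-install sanity check might look like this (illustrative; the version thresholds are simply the ones this PR targets):

import torch
import pytorch_lightning as pl

# Expect torch 2.x and pytorch_lightning >= 2.0.7 on this branch.
print(torch.__version__)
print(pl.__version__)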