Training files
FeiiYin committed Aug 24, 2022
1 parent b4cf181 commit bad7f12
Showing 4 changed files with 153 additions and 1 deletion.
2 changes: 1 addition & 1 deletion configs/video_warper_trainer.yaml
@@ -38,7 +38,7 @@ trainer:


model:
-  type: models.styleheat.mirror_warper::MirrorWarper
+  type: models.styleheat.warper::VideoWarper
  mode: train_video_warper
  optimized_param: all
  from_scratch_param: all
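The new type value follows a module::Class naming convention: the text before the double colon is an importable Python module path and the text after it is a class defined in that module (the new data/__init__.py below resolves dataset types the same way). A minimal sketch of that resolution, not part of this commit, demonstrated on a standard-library class so it is self-contained:

import importlib


def resolve_type(type_string):
    # 'package.module::ClassName' -> the class object it names
    module_path, class_name = type_string.split('::')
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# In the repository the string would be 'models.styleheat.warper::VideoWarper';
# a standard-library class is used here so the snippet runs anywhere.
print(resolve_type('collections::OrderedDict'))  # <class 'collections.OrderedDict'>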
63 changes: 63 additions & 0 deletions data/__init__.py
@@ -0,0 +1,63 @@
import importlib

import torch.utils.data
from utils.distributed import master_only_print as print


def find_dataset_using_name(dataset_name):
    # Dataset types are given as 'module.path::ClassName'; import the module
    # and look up the class by its exact name.
    dataset_filename = dataset_name
    module, target = dataset_name.split('::')
    datasetlib = importlib.import_module(module)
    dataset = None
    for name, cls in datasetlib.__dict__.items():
        if name == target:
            dataset = cls

    if dataset is None:
        raise ValueError("In %s.py, there should be a class "
                         "whose name matches %s." %
                         (dataset_filename, target))
    return dataset


def get_option_setter(dataset_name):
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataloader(opt, is_inference):
    dataset = find_dataset_using_name(opt.type)
    instance = dataset(opt, is_inference)
    phase = 'val' if is_inference else 'training'
    batch_size = opt.val.batch_size if is_inference else opt.train.batch_size
    print("%s dataset [%s] of size %d was created" %
          (phase, opt.type, len(instance)))
    dataloader = torch.utils.data.DataLoader(
        instance,
        batch_size=batch_size,
        sampler=data_sampler(instance, shuffle=not is_inference, distributed=opt.train.distributed),
        drop_last=not is_inference,
        num_workers=getattr(opt, 'num_workers', 0),
    )

    return dataloader


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    else:
        return torch.utils.data.SequentialSampler(dataset)


def get_dataloader(opt, is_inference=False):
    dataset = create_dataloader(opt, is_inference=is_inference)
    return dataset


def get_train_val_dataloader(opt):
    val_dataset = create_dataloader(opt, is_inference=True)
    train_dataset = create_dataloader(opt, is_inference=False)
    return val_dataset, train_dataset
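The helpers above read only a few fields from the option object: opt.type, opt.train.batch_size, opt.val.batch_size, opt.train.distributed (used for the sampler in both phases) and, optionally, opt.num_workers. A rough sketch of that structure, built with SimpleNamespace purely for illustration; in the repository these values come from the YAML config, and the dataset name shown is hypothetical:

from types import SimpleNamespace

# Hypothetical stand-in for the opt.data block of the YAML config.
data_opt = SimpleNamespace(
    type='data.video_dataset::VideoDataset',  # hypothetical module::Class name
    train=SimpleNamespace(batch_size=4, distributed=False),
    val=SimpleNamespace(batch_size=2, distributed=False),
    num_workers=4,
)

# With the repository on the Python path, the loaders would then be created as
# train.py does below:
#   import data as Dataset
#   val_loader, train_loader = Dataset.get_train_val_dataloader(data_opt)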
Binary file removed third_part/Deep3DFaceRecon_pytorch/temp/2.png
89 changes: 89 additions & 0 deletions train.py
@@ -0,0 +1,89 @@
import os
import argparse
import data as Dataset

from configs.config import Config
from utils.logging import init_logging, make_logging_dir
from utils.trainer import get_model_optimizer_and_scheduler_with_pretrain, set_random_seed, get_trainer, get_model_optimizer_and_scheduler
from utils.distributed import init_dist
from utils.distributed import master_only_print as print


def parse_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', required=True)
    parser.add_argument('--name', required=True)
    parser.add_argument('--checkpoints_dir', default='result', help='Dir for saving logs and models.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--which_iter', type=int, default=None)
    parser.add_argument('--no_resume', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    return args


def main():
    # get training options
    args = parse_args()
    set_random_seed(args.seed)

    opt = Config(args.config, args, is_train=True)

    if not args.single_gpu:
        opt.local_rank = args.local_rank
        init_dist(opt.local_rank)
        opt.device = opt.local_rank
        print('Distributed DataParallel Training.')
    else:
        print('Single GPU Training.')
        opt.device = 'cuda'
        opt.local_rank = 0
        opt.distributed = False
        opt.data.train.distributed = False
        opt.data.val.distributed = False

    # create a visualizer
    date_uid, logdir = init_logging(opt)
    opt.logdir = logdir
    make_logging_dir(logdir, date_uid)
    os.system(f'cp {args.config} {opt.logdir}')
    # create a dataset
    val_dataset, train_dataset = Dataset.get_train_val_dataloader(opt.data)

    # create a model
    net_G, net_G_ema, opt_G, sch_G = get_model_optimizer_and_scheduler_with_pretrain(opt)

    trainer = get_trainer(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset)
    current_epoch, current_iteration = trainer.load_checkpoint(opt, args.which_iter)

    # training flag
    if args.debug:
        trainer.test_everything(train_dataset, val_dataset, current_epoch, current_iteration)
        exit()

    # Start training.
    for epoch in range(current_epoch, opt.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if not args.single_gpu:
            train_dataset.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)
        for it, data in enumerate(train_dataset):
            data = trainer.start_of_iteration(data, current_iteration)
            trainer.optimize_parameters(data)
            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)

            if current_iteration >= opt.max_iter:
                print('Done with training!!!')
                break
        current_epoch += 1
        trainer.end_of_epoch(data, val_dataset, current_epoch, current_iteration)
        trainer.test(val_dataset, output_dir=os.path.join(logdir, 'evaluation'), test_limit=10)


if __name__ == '__main__':
    main()
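For reference, a typical way to launch this script; the run name video_warper is an arbitrary example, and the multi-GPU form assumes torch.distributed.launch is used so that --local_rank is passed to each process automatically:

python train.py --single_gpu --config configs/video_warper_trainer.yaml --name video_warper

python -m torch.distributed.launch --nproc_per_node=4 train.py --config configs/video_warper_trainer.yaml --name video_warper

Logs and checkpoints are written under the --checkpoints_dir location ('result' by default).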

