diff --git a/big_vision/configs/bit_i1k.py b/big_vision/configs/bit_i1k.py index 485678b..0590eaa 100644 --- a/big_vision/configs/bit_i1k.py +++ b/big_vision/configs/bit_i1k.py @@ -31,20 +31,23 @@ def get_config(runlocal=False): """Config for training on ImageNet-1k.""" config = mlc.ConfigDict() - config.dataset = 'imagenet2012' - config.train_split = 'train[:99%]' - config.cache_raw = not runlocal # Needs up to 120GB of RAM! - config.shuffle_buffer_size = 250_000 if not runlocal else 10_000 # Per host. + config.seed = 0 + config.total_epochs = 90 config.num_classes = 1000 config.loss = 'softmax_xent' - config.seed = 0 - config.batch_size = 4096 if not runlocal else 32 - config.total_epochs = 90 + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 if not runlocal else 32 + config.input.cache_raw = not runlocal # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 if not runlocal else 10_000 # Per host. pp_common = '|onehot(1000, key="{lbl}", key_result="labels")' pp_common += '|value_range(-1, 1)|keep("image", "labels")' - config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label') + config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label') pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.log_training_steps = 50 @@ -62,30 +65,29 @@ def get_config(runlocal=False): config.grad_clip_norm = 1.0 # linear scaling rule. Don't forget to sweep if sweeping batch_size. - config.wd = (1e-4 / 256) * config.batch_size - config.lr = (0.1 / 256) * config.batch_size + config.wd = (1e-4 / 256) * config.input.batch_size + config.lr = (0.1 / 256) * config.input.batch_size config.schedule = dict(decay_type='cosine', warmup_steps=1000) # Eval section - eval_common = dict( - type='classification', - dataset='imagenet2012', - pp_fn=pp_eval.format(lbl='label'), - loss_name=config.loss, - log_steps=1000, # Very fast O(seconds) so it's fine to run it often. - ) + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. 
+ cache_final=not runlocal, + ) config.evals = {} - config.evals.train = {**eval_common, 'split': 'train[:2%]'} - config.evals.minival = {**eval_common, 'split': 'train[99%:]'} - config.evals.val = {**eval_common, 'split': 'validation'} - config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'} - - config.evals.real = dict(**eval_common) - config.evals.real.dataset = 'imagenet2012_real' - config.evals.real.split = 'validation' + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') config.evals.real.pp_fn = pp_eval.format(lbl='real_label') - # config.fewshot = get_fewshot_lsr() - # config.fewshot.log_steps = 1000 + # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal) + # config.evals.fewshot.log_steps = 1000 return config \ No newline at end of file diff --git a/big_vision/configs/bit_i21k.py b/big_vision/configs/bit_i21k.py index 02f005f..efce655 100644 --- a/big_vision/configs/bit_i21k.py +++ b/big_vision/configs/bit_i21k.py @@ -28,23 +28,25 @@ def get_config(): """Config for training on imagenet-21k.""" config = mlc.ConfigDict() - config.dataset = 'imagenet21k' - config.train_split = 'full[51200:]' + config.seed = 0 + config.total_epochs = 90 config.num_classes = 21843 config.init_head_bias = -10.0 config.loss = 'sigmoid_xent' - config.trial = 0 - config.batch_size = 4096 - config.total_epochs = 90 + config.input = dict() + config.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}') pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"') - config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k - pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common_i21k - pp_eval_i1k = 'decode|resize_small(256)|central_crop(224)' + pp_common_i1k - config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k + pp_eval = 'decode|resize_small(256)|central_crop(224)' config.log_training_steps = 50 config.ckpt_steps = 1000 @@ -58,22 +60,23 @@ def get_config(): config.grad_clip_norm = 1.0 # linear scaling rule. Don't forget to sweep if sweeping batch_size. - config.lr = (0.03 / 256) * config.batch_size - config.wd = (3e-5 / 256) * config.batch_size + config.lr = (0.03 / 256) * config.input.batch_size + config.wd = (3e-5 / 256) * config.input.batch_size config.schedule = dict(decay_type='cosine', warmup_steps=5000) - # Eval section - eval_common = dict( - type='classification', - dataset=config.dataset, - pp_fn=pp_eval, - loss_name=config.loss, - log_steps=1000, # Very fast O(seconds) so it's fine to run it often. - ) + # Evaluations on i21k itself. + def eval_i21k(split): + return dict( + type='classification', + data={**config.input.data, 'split': split}, + pp_fn=pp_eval + pp_common_i21k, + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. 
+ ) config.evals = {} - config.evals.test = {**eval_common, 'split': 'full[:25_600]'} - config.evals.val = {**eval_common, 'split': 'full[25_600:51_200]'} - config.evals.train = {**eval_common, 'split': 'full[51_200:76_800]'} + config.evals.test = eval_i21k('full[:25_600]') + config.evals.val = eval_i21k('full[25_600:51_200]') + config.evals.train = eval_i21k('full[51_200:76_800]') # Few-shot evaluators config.evals.fewshot = get_fewshot_lsr() diff --git a/big_vision/configs/common_fewshot.py b/big_vision/configs/common_fewshot.py index d90a95a..176c2a2 100644 --- a/big_vision/configs/common_fewshot.py +++ b/big_vision/configs/common_fewshot.py @@ -39,7 +39,7 @@ def get_fewshot_lsr(target_resolution=224, resize_resolution=256, } config.pp_train = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")' config.pp_eval = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")' - config.shots = [1, 5, 10, 25] + config.shots = (1, 5, 10, 25) config.l2_reg = 2.0 ** 10 config.num_seeds = 3 config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)] diff --git a/big_vision/configs/load_and_eval.py b/big_vision/configs/load_and_eval.py index fa219a5..44fba68 100644 --- a/big_vision/configs/load_and_eval.py +++ b/big_vision/configs/load_and_eval.py @@ -36,6 +36,7 @@ import big_vision.configs.common as bvcc from big_vision.configs.common_fewshot import get_fewshot_lsr +from big_vision.configs.proj.image_text import lit_eval import ml_collections as mlc diff --git a/big_vision/configs/mlp_mixer_i1k.py b/big_vision/configs/mlp_mixer_i1k.py index aa1938f..5bf8e58 100644 --- a/big_vision/configs/mlp_mixer_i1k.py +++ b/big_vision/configs/mlp_mixer_i1k.py @@ -31,14 +31,22 @@ def get_config(mode=None): """Config for training Mixer on i1k.""" config = mlc.ConfigDict() - config.dataset = 'imagenet2012' - config.train_split = 'train[:99%]' - config.cache_raw = True # Needs up to 120GB of RAM! + config.seed = 0 + config.total_epochs = 300 config.num_classes = 1000 - config.init_head_bias = -6.9 config.loss = 'sigmoid_xent' + config.init_head_bias = -6.9 + + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 - config.pp_train = ( + config.input.pp = ( 'decode_jpeg_and_inception_crop(224)' '|flip_lr' '|randaug(2,15)' @@ -46,7 +54,7 @@ def get_config(mode=None): '|onehot(1000, key="label", key_result="labels")' '|keep("image", "labels")' ) - ppv = ( + pp_eval = ( 'decode' '|resize_small(256)|central_crop(224)' '|value_range(-1, 1)' @@ -54,14 +62,8 @@ def get_config(mode=None): '|keep("image", "labels")' ) - config.batch_size = 4096 - config.total_epochs = 300 - - config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. - config.log_training_steps = 50 config.ckpt_steps = 1000 - config.ckpt_timeout = 1 config.prefetch_to_device = 2 @@ -86,30 +88,29 @@ def get_config(mode=None): ) # Eval section - eval_common = dict( - type='classification', - dataset='imagenet2012', - pp_fn=ppv.format(lbl='label'), - loss_name=config.loss, - log_steps=2500, # Very fast O(seconds) so it's fine to run it often. 
- ) + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. + cache_final=mode != 'gpu8', + ) config.evals = {} - config.evals.train = {**eval_common, 'split': 'train[:2%]'} - config.evals.minival = {**eval_common, 'split': 'train[99%:]'} - config.evals.val = {**eval_common, 'split': 'validation'} - config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'} - - config.evals.real = dict(**eval_common) - config.evals.real.dataset = 'imagenet2012_real' - config.evals.real.split = 'validation' - config.evals.real.pp_fn = ppv.format(lbl='real_label') + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') config.fewshot = get_fewshot_lsr() if mode == 'gpu8': config.total_epochs = 60 - config.batch_size = 512 - config.cache_raw = False + config.input.batch_size = 512 + config.input.cache_raw = False if mode == 'regression_test': config.total_epochs = 60 diff --git a/big_vision/configs/proj/distill/bigsweep_flowers_pet.py b/big_vision/configs/proj/distill/bigsweep_flowers_pet.py index 4a96bc8..97a7456 100644 --- a/big_vision/configs/proj/distill/bigsweep_flowers_pet.py +++ b/big_vision/configs/proj/distill/bigsweep_flowers_pet.py @@ -50,18 +50,21 @@ def get_config(arg=None): arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)') config = mlc.ConfigDict() - config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data] - config.cache_raw = True + config.input = {} + config.input.data = dict( + name=dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data], + split=dict(flowers='train', pet='train[:90%]')[arg.data], + ) + config.input.batch_size = 512 + config.input.cache_raw = True + config.input.shuffle_buffer_size = 50_000 config.prefetch_to_device = 4 - config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data] - config.num_classes = NCLS[arg.data] - config.batch_size = 512 - config.num_epochs = { + config.num_classes = NCLS[arg.data] + config.total_epochs = { 'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000}, 'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000}, }[arg.data][arg.variant] - config.shuffle_buffer_size = 50_000 config.log_training_steps = 100 config.ckpt_steps = 2500 @@ -81,7 +84,7 @@ def get_config(arg=None): f'|onehot({config.num_classes}, key="label", key_result="labels")' '|keep("image", "labels")' ) - config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common + config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common config.mixup = dict(p=1.0, n=2) @@ -118,18 +121,19 @@ def get_config(arg=None): val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]' test_split = 'test' if not arg.runlocal else 'test[:16]' - base = dict( - type='classification', - pred='student_fwd', - dataset=config.dataset, - pp_fn=ppv, - loss_name='softmax_xent', - log_steps=500, - ) + def get_eval(split): + return dict( + type='classification', + pred='student_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv, + 
loss_name='softmax_xent', + log_steps=500, + ) config.evals = {} - config.evals.student_train = {**base, 'split': minitrain_split} - config.evals.student_val = {**base, 'split': val_split} - config.evals.student_test = {**base, 'split': test_split} + config.evals.student_train = get_eval(minitrain_split) + config.evals.student_val = get_eval(val_split) + config.evals.student_test = get_eval(test_split) # Teacher is fixed, so rare evals. teacher = dict(log_steps=100_000, pred='prof_m_fwd') @@ -138,22 +142,23 @@ def get_config(arg=None): config.evals.teacher_test = {**config.evals.student_test, **teacher} # Could in principle also look at agreement on other datasets! - dist = dict( - type='proj.distill.distance', - pred='student_prof_m_fwd', - dataset=config.dataset, - pp_fn=ppv + '|keep("image")', - log_steps=1000, - distances=({'kind': 'kl'}, {'kind': 'euclidean'}, - {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), - ) - config.evals.dist_train = {**dist, 'split': minitrain_split} - config.evals.dist_val = {**dist, 'split': val_split} - config.evals.dist_test = {**dist, 'split': test_split} + def get_dist(split): + return dict( + type='proj.distill.distance', + pred='student_prof_m_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv + '|keep("image")', + log_steps=1000, + distances=({'kind': 'kl'}, {'kind': 'euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + config.evals.dist_train = get_dist(minitrain_split) + config.evals.dist_val = get_dist(val_split) + config.evals.dist_test = get_dist(test_split) # Make a few things much smaller for quick local debugging testruns. if arg.runlocal: - config.shuffle_buffer_size = 10 - config.batch_size = 8 + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 return config \ No newline at end of file diff --git a/big_vision/configs/proj/distill/bigsweep_food_sun.py b/big_vision/configs/proj/distill/bigsweep_food_sun.py index 26b05ff..93fafeb 100644 --- a/big_vision/configs/proj/distill/bigsweep_food_sun.py +++ b/big_vision/configs/proj/distill/bigsweep_food_sun.py @@ -40,7 +40,7 @@ import ml_collections as mlc H, L = 160, 128 -CLS = dict(food=101, sun=397) +NCLS = dict(food=101, sun=397) def get_config(arg=None): @@ -48,15 +48,18 @@ def get_config(arg=None): arg = bvcc.parse_arg(arg, runlocal=False, data='food', variant='medium', crop='inception_crop(128)') config = mlc.ConfigDict() - config.dataset = dict(food='food101', sun='sun397')[arg.data] - config.cache_raw = True + config.input = {} + config.input.data = dict( + name=dict(food='food101', sun='sun397')[arg.data], + split=dict(food='train[:90%]', sun='train')[arg.data], + ) + config.input.batch_size = 512 + config.input.cache_raw = True + config.input.shuffle_buffer_size = 50_000 config.prefetch_to_device = 4 - config.train_split = dict(food='train[:90%]', sun='train')[arg.data] - config.num_classes = CLS[arg.data] - config.batch_size = 512 - config.num_epochs = {'fast': 100, 'medium': 1000, 'long': 3000}[arg.variant] - config.shuffle_buffer_size = 50_000 + config.num_classes = NCLS[arg.data] + config.total_epochs = {'fast': 100, 'medium': 1000, 'long': 3000}[arg.variant] config.log_training_steps = 50 config.ckpt_steps = 2500 @@ -76,7 +79,7 @@ def get_config(arg=None): f'|onehot({config.num_classes}, key="label", key_result="labels")' '|keep("image", "labels")' ) - config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common + config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common ppv = 
'decode|resize_small(160)|central_crop(128)' + pp_common config.mixup = dict(p=1.0, n=2) @@ -113,18 +116,19 @@ def get_config(arg=None): val_split = 'validation' if not arg.runlocal else 'validation[:16]' test_split = 'test' if not arg.runlocal else 'test[:16]' - base = dict( - type='classification', - pred='student_fwd', - dataset=config.dataset, - pp_fn=ppv, - loss_name='softmax_xent', - log_steps=500, - ) + def get_eval(split): + return dict( + type='classification', + pred='student_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv, + loss_name='softmax_xent', + log_steps=500, + ) config.evals = {} - config.evals.student_train = {**base, 'split': minitrain_split} - config.evals.student_val = {**base, 'split': val_split} - config.evals.student_test = {**base, 'split': test_split} + config.evals.student_train = get_eval(minitrain_split) + config.evals.student_val = get_eval(val_split) + config.evals.student_test = get_eval(test_split) # Teacher is fixed, so rare evals. teacher = dict(log_steps=100_000, pred='prof_m_fwd') @@ -133,33 +137,35 @@ def get_config(arg=None): config.evals.teacher_test = {**config.evals.student_test, **teacher} # Could in principle also look at agreement on other datasets! - dist = dict( - type='proj.distill.distance', - pred='student_prof_m_fwd', - dataset=config.dataset, - pp_fn=ppv + '|keep("image")', - log_steps=1000, - distances=({'kind': 'kl'}, {'kind': 'euclidean'}, - {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), - ) - config.evals.dist_train = {**dist, 'split': minitrain_split} - config.evals.dist_val = {**dist, 'split': val_split} - config.evals.dist_test = {**dist, 'split': test_split} + def get_dist(split): + return dict( + type='proj.distill.distance', + pred='student_prof_m_fwd', + data=dict(name=config.input.data.name, split=split), + pp_fn=ppv + '|keep("image")', + log_steps=1000, + distances=({'kind': 'kl'}, {'kind': 'euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + config.evals.dist_train = get_dist(minitrain_split) + config.evals.dist_val = get_dist(val_split) + config.evals.dist_test = get_dist(test_split) # Make a few things much smaller for quick local debugging testruns. if arg.runlocal: - config.shuffle_buffer_size = 10 - config.batch_size = 8 + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 return config def get_hyper(hyper): """Hyper sweep.""" + # TODO: update, similar to flowers_pet sweep. # By default, not running the MASSIVE sweep, just the recommended setting # across durations. However, code for sweep is left for reference/convenience. return hyper.zipit([ - hyper.sweep('config.num_epochs', [100, 1_000]), + hyper.sweep('config.total_epochs', [100, 1_000]), hyper.sweep('config.mixup.p', [0.0, 1.0]), hyper.sweep('config.weight_decay', [1e-3, 1e-5]), ]) @@ -173,7 +179,7 @@ def fix(**kw): def setting(p, l, m, crop, pp_end=None, **extra): pp_end = pp_end or ( f'|value_range(-1, 1, key="image")' - f'|onehot({CLS}, key="label", key_result="labels")' + f'|onehot({NCLS}, key="label", key_result="labels")' f'|keep("image", "labels")' ) return hyper.product([ @@ -186,7 +192,7 @@ def setting(p, l, m, crop, pp_end=None, **extra): # Mixup, Layers and Mag in randaug. 
plm = [(0.0, 0, 0), (0.1, 0, 0), (0.5, 0, 0), (1.0, 0, 0)] return hyper.product([ - hyper.sweep('config.num_epochs', [100, 1000, 3000]), + hyper.sweep('config.total_epochs', [100, 1000, 3000]), hyper.sweep('config.lr.base', [0.001, 0.003, 0.01]), hyper.sweep('config.distance_kw.t', [1.0, 2.0, 5.0, 10.0]), hyper.sweep('config.weight_decay', [1e-5, 3e-5, 1e-4, 3e-4, 1e-3]), @@ -198,7 +204,7 @@ def setting(p, l, m, crop, pp_end=None, **extra): pp_end=( f'|value_range(-1, 1, key="student")' f'|value_range(-1, 1, key="teacher")' - f'|onehot({CLS}, key="label", key_result="labels")' + f'|onehot({NCLS}, key="label", key_result="labels")' f'|keep("student", "teacher", "labels")')) for p, l, m in plm] + [setting(p=p, l=l, m=m, crop=f'inception_crop({L})') for diff --git a/big_vision/configs/proj/distill/bit_i1k.py b/big_vision/configs/proj/distill/bit_i1k.py index 444eb73..c435edc 100644 --- a/big_vision/configs/proj/distill/bit_i1k.py +++ b/big_vision/configs/proj/distill/bit_i1k.py @@ -25,7 +25,7 @@ big_vision.trainers.proj.distill.distill \ --config big_vision/configs/proj/distill/bit_i1k.py \ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ - --config.num_epochs 1200 + --config.total_epochs 1200 """ import big_vision.configs.common as bvcc @@ -39,13 +39,13 @@ def get_config(arg=None): arg = bvcc.parse_arg(arg, runlocal=False) config = mlc.ConfigDict() - config.dataset = 'imagenet2012' - config.train_split = 'train[:98%]' - config.num_classes = 1000 + config.input = {} + config.input.data = dict(name='imagenet2012', split='train[:98%]') + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 - config.batch_size = 4096 - config.num_epochs = 1200 # A good middle-ground - config.shuffle_buffer_size = 250_000 + config.num_classes = 1000 + config.total_epochs = 1200 # A good middle-ground config.log_training_steps = 50 config.ckpt_steps = 1000 @@ -67,11 +67,11 @@ def get_config(arg=None): '|onehot(1000, key="{lbl}", key_result="labels")' '|keep("image", "labels")' ) - config.pp_train = ( + config.input.pp = ( 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label') ) - ppv = 'decode|{crop}' + pp_common + ppv = 'decode|resize_small(256)|central_crop(224)' + pp_common config.mixup = dict(p=1.0, n=2) @@ -95,26 +95,23 @@ def get_config(arg=None): real_split = 'validation' if not arg.runlocal else 'validation[:16]' v2_split = 'test' if not arg.runlocal else 'test[:16]' - base = dict( - type='classification', - pred='student_fwd', - dataset='imagenet2012', - pp_fn=ppv.format(lbl='label', crop='resize_small(256)|central_crop(224)'), - loss_name='softmax_xent', - log_steps=1000, - ) + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + pred='student_fwd', + data=dict(name=dataset, split=split), + pp_fn=ppv.format(lbl='label'), + loss_name='softmax_xent', + log_steps=1000, + ) config.evals = {} - config.evals.student_train = {**base, 'split': minitrain_split} - config.evals.student_minival = {**base, 'split': minival_split} - config.evals.student_val = {**base, 'split': val_split} - config.evals.student_v2 = {**base, 'dataset': 'imagenet_v2', 'split': v2_split} - - config.evals.student_real = dict(**base) - config.evals.student_real.dataset = 'imagenet2012_real' - config.evals.student_real.split = real_split - config.evals.student_real.pp_fn = ppv.format( - lbl='real_label', crop='resize_small(256)|central_crop(224)') + config.evals.student_train = get_eval(minitrain_split) + config.evals.student_minival = 
get_eval(minival_split) + config.evals.student_val = get_eval(val_split) + config.evals.student_v2 = get_eval(v2_split, dataset='imagenet_v2') + config.evals.student_real = get_eval(real_split, dataset='imagenet2012_real') + config.evals.student_real.pp_fn = ppv.format(lbl='real_label') config.evals.student_fewshot = get_fewshot_lsr(runlocal=arg.runlocal) config.evals.student_fewshot.pred = 'student_fwd' @@ -133,23 +130,38 @@ def get_config(arg=None): config.evals.teacher_fewshot.prefix = 'z_teacher/' # Could in principle also look at agreement on other datasets! - dist = dict( - type='proj.distill.distance', - pred='student_prof_m_fwd', - dataset='imagenet2012', - pp_fn=ppv.format(lbl='label', crop='resize_small(256)|central_crop(224)') + '|keep("image")', - log_steps=1000, - distances=({'kind': 'kl'}, {'kind': 'euclidean'}, - {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), - ) - config.evals.dist_train = {**dist, 'split': minitrain_split} - config.evals.dist_minival = {**dist, 'split': minival_split} - config.evals.dist_val = {**dist, 'split': val_split} - config.evals.dist_v2 = {**dist, 'split': v2_split} + def get_dist(split, dataset='imagenet2012'): + return dict( + type='proj.distill.distance', + pred='student_prof_m_fwd', + data=dict(name=dataset, split=split), + pp_fn=ppv.format(lbl='label') + '|keep("image")', + log_steps=1000, + distances=({'kind': 'kl'}, {'kind': 'euclidean'}, + {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}), + ) + config.evals.dist_train = get_dist(minitrain_split) + config.evals.dist_minival = get_dist(minival_split) + config.evals.dist_val = get_dist(val_split) + config.evals.dist_v2 = get_dist(v2_split, dataset='imagenet_v2') + + # NOTE: CKA evaluator does not work with batch padding, so the size of the + # split must be a multiple of the batch size. + def get_cka(split): + return dict( + type='proj.distill.cka', + pred='student_prof_m_fwd', + data=dict(name='imagenet2012', split=split), + pp_fn=ppv.format(lbl='label') + '|keep("image")', + log_steps=1000, + ) + config.evals.cka_train = get_cka('train[:24576]' if not arg.runlocal else 'train[:16]') + config.evals.cka_minival = get_cka('train[-24576:]' if not arg.runlocal else 'train[:16]') + config.evals.cka_val = get_cka('validation[:49152]' if not arg.runlocal else 'validation[:16]') # Make a few things much smaller for quick local debugging testruns. 
if arg.runlocal: - config.shuffle_buffer_size = 10 - config.batch_size = 8 + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 return config \ No newline at end of file diff --git a/big_vision/configs/proj/image_text/README.md b/big_vision/configs/proj/image_text/README.md index b823efb..ec1381d 100644 --- a/big_vision/configs/proj/image_text/README.md +++ b/big_vision/configs/proj/image_text/README.md @@ -29,15 +29,18 @@ Colabs: | :--- | :---: | :---: | :---: | :---: | :--- | | mixed_L16L | [link](https://storage.googleapis.com/vit_models/lit/LiT-L16L.npz) | 75.7 | 48.5 | 31.2 | `txt=bert_large,img=L/16` | | mixed_B16B | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B.npz) | 72.1 | 49.4 | 31.1 | `txt=bert_base,img=B/16,img_head` | +| mixed_B16B_2 | [link](https://storage.googleapis.com/vit_models/lit/LiT-B16B_2.npz) | 73.9 | 51.5 | 31.8 | `txt=bert_base,img=B/16` | | coco_B16B | [link](https://storage.googleapis.com/vit_models/lit/big_vision/coco_B16B/checkpoint.npz) | 20.7 | 47.2 | 32.1 | `txt=bert_base,img=B/16` | -The first two rows are the best available models trained on open source data, +The first three rows are the best available models trained on open source data, originally published in the [`google-research/vision_transformer`] repository. These models were re-evaluated with this codebase using the following commands: ```bash big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16,img_head,init=gs://vit_models/lit/LiT-B16B.npz +big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_base,img=B/16,init=gs://vit_models/lit/LiT-B16B_2.npz + big_vision.tools.eval_only --config big_vision/configs/proj/image_text/lit_coco.py:txt=bert_large,img=L/16,init=gs://vit_models/lit/LiT-L16L.npz ``` @@ -52,4 +55,11 @@ can be used to verify correctness of the codebase [`CC12M`]: https://arxiv.org/abs/2102.08981 [`YFCC100M`]: https://arxiv.org/abs/1503.01817 [`tfds`]: https://www.tensorflow.org/datasets/api_docs/python/tfds -[`coco_captions`]: https://www.tensorflow.org/datasets/catalog/coco_captions \ No newline at end of file +[`coco_captions`]: https://www.tensorflow.org/datasets/catalog/coco_captions + + +### Changelog + +- 2022-08-18: Added LiT-B16B_2 model that was trained for 60k steps + (LiT_B16B: 30k) without linear head on the image side (LiT_B16B: 768) and has + better performance. diff --git a/big_vision/configs/proj/image_text/common.py b/big_vision/configs/proj/image_text/common.py index b1b86a4..a2a55c4 100644 --- a/big_vision/configs/proj/image_text/common.py +++ b/big_vision/configs/proj/image_text/common.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Common config code for LiT models.""" +"""Snippets and constants used a lot in image-text configs.""" + +import ml_collections + # pylint: disable=line-too-long inits = { @@ -27,3 +30,32 @@ 'L/16': ('L/16', 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz'), } # pylint: enable=line-too-long + + +def get_coco( + *, + pp_img='resize(224)|value_range(-1, 1)', + pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)', + prefix='z/retr/coco_', + log_steps): + """Returns config for mscoco retrieval zero-shot. + + Args: + pp_img: Pre-processing string for "image" feature. + pp_txt: Pre-processing string for texts (expected to tokenize "texts" to + "labels"). 
+ prefix: Prefix to use for metrics. + log_steps: How often the evaluators should be run. + + Returns: + `ConfigDict` that can be used as a retrieval evaluator configuration. + """ + return ml_collections.ConfigDict({ + 'type': 'proj.image_text.retrieval', + 'log_steps': log_steps, + 'pp_txt': pp_txt, + 'pp_img': pp_img, + 'prefix': prefix, + 'dataset': 'coco_captions', + 'txt_name': ('captions', 'text'), + }) diff --git a/big_vision/configs/proj/image_text/common_retrieval.py b/big_vision/configs/proj/image_text/common_retrieval.py deleted file mode 100644 index 236ef7c..0000000 --- a/big_vision/configs/proj/image_text/common_retrieval.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2022 Big Vision Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Common retrieval configuration.""" - -import ml_collections - - -def get_coco( - *, - pp_img='resize(224)|value_range(-1, 1)', - pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)', - prefix='z/retr/coco_', - log_steps): - """Returns config for mscoco retrieval zero-shot. - - Args: - pp_img: Pre-processing string for "image" feature. - pp_txt: Pre-processing string for texts (expected to tokenize "texts" to - "labels"). - prefix: Prefix to use for metrics. - log_steps: How often the evaluators should be run. - - Returns: - `ConfigDict` that can be used as a retrieval evaluator configuration. 
- """ - return ml_collections.ConfigDict({ - 'type': 'proj.image_text.retrieval', - 'log_steps': log_steps, - 'pp_txt': pp_txt, - 'pp_img': pp_img, - 'prefix': prefix, - 'dataset': 'coco_captions', - 'txt_name': ('captions', 'text'), - }) diff --git a/big_vision/configs/proj/image_text/lit_coco.py b/big_vision/configs/proj/image_text/lit_coco.py index 5c363d9..e7233dc 100644 --- a/big_vision/configs/proj/image_text/lit_coco.py +++ b/big_vision/configs/proj/image_text/lit_coco.py @@ -44,8 +44,7 @@ """ import big_vision.configs.common as bvcc -from big_vision.configs.proj.image_text import common as cl -from big_vision.configs.proj.image_text import common_retrieval +from big_vision.configs.proj.image_text import common from ml_collections import ConfigDict @@ -54,38 +53,36 @@ def get_config(arg=None): arg = bvcc.parse_arg( arg, res=224, runlocal=False, token_len=16, txt='bert_base', img='B/16', init='', img_head=False) - img_name, img_init = cl.inits[arg.img] - txt_name, txt_init = cl.inits[arg.txt] + img_name, img_init = common.inits[arg.img] + txt_name, txt_init = common.inits[arg.txt] config = ConfigDict() - config.batch_size = 4096*1 if not arg.runlocal else 32 - # TODO update config to use YFCC100M, CC12M from tfds - config.dataset = 'coco_captions' - config.train_split = 'train' + config.input = {} + config.input.data = dict(name='coco_captions', split='train') + config.input.batch_size = 4096 if not arg.runlocal else 32 + config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50 + config.total_steps = 5_000 if not arg.runlocal else 1 config.init_shapes = [(1, arg.res, arg.res, 3), (1, arg.token_len,)] config.init_types = ['float32', 'int32'] if arg.init: - vocab_path = '.'.join(arg.init.split('.')[:-1]) + '.txt' + vocab_path = arg.init.rsplit('.', 1)[0] + '.txt' else: vocab_path = f'{txt_init}/vocab.txt' tokenizer = lambda inkey: ( f'bert_tokenize(inkey="{inkey}", max_len={arg.token_len}, ' f'vocab_path="{vocab_path}")') - config.pp_train = pp_eval = ( + config.input.pp = pp_eval = ( f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)' f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")' ) config.pp_modules = [ 'ops_general', 'ops_image', 'ops_text', 'proj.flaxformer.bert_ops'] - config.pp_img = f'resize({arg.res})|value_range(-1,1)|keep("image")' - config.pp_txt = tokenizer('label') + '|keep("labels")' - config.shuffle_buffer_size = 250_000 if not arg.runlocal else 50 config.log_training_steps = 50 - config.checkpoint_steps = 1000 + config.ckpt_steps = 1000 # Model section config.model_name = 'proj.image_text.two_towers' @@ -117,15 +114,12 @@ def get_config(arg=None): config.optax_name = 'scale_by_adam' else: config.optax_name = 'big_vision.scale_by_adafactor' - # Gather representations across TPU cores for larger batch size for loss. - # Generally helps: ((internal link) - config.loss_use_global_batch = True config.lr = 0.001 config.wd = 0.01 warmup_steps = max(int(0.03 * config.total_steps), 100) config.schedule = [ - ('img/.*', None), + ('img/.*', None), # Freezes image tower. 
('.*', dict(decay_type='cosine', warmup_steps=warmup_steps)), ] @@ -138,28 +132,25 @@ def get_config(arg=None): eval_common = dict( type='proj.image_text.contrastive', use_global_batch=config.loss_use_global_batch, - log_steps=500, + log_steps=500 if not arg.runlocal else 5, ) config.evals = {} sub = '[:4]' if arg.runlocal else '' config.evals.val = { **eval_common, - 'split': f'val{sub}', - 'dataset': config.dataset, + 'data': dict(name=config.input.data.name, split=f'val{sub}'), 'pp_fn': pp_eval, } config.evals.coco = { **eval_common, - 'dataset': 'coco_captions', - 'split': f'val{sub}', + 'data': dict(name='coco_captions', split=f'val{sub}'), 'pp_fn': ( f'decode|resize({arg.res})|value_range(-1,1)' f'|flatten|{tokenizer("captions/text")}|keep("image", "labels")'), } config.evals.imagenet = { **eval_common, - 'dataset': 'imagenet2012', - 'split': f'validation{sub}', + 'data': dict(name='imagenet2012', split=f'validation{sub}'), 'pp_fn': ( f'decode|resize({arg.res})|value_range(-1,1)' '|clip_i1k_label_names' @@ -172,7 +163,7 @@ def get_config(arg=None): config.evals.disclf.type = 'proj.image_text.discriminative_classifier' config.evals.disclf.prefix = 'z/0shot/' config.evals.disclf.log_steps = eval_common['log_steps'] - config.evals.retrieval_coco = common_retrieval.get_coco( + config.evals.retrieval_coco = common.get_coco( pp_img=f'resize({arg.res})|value_range(-1, 1)', pp_txt=tokenizer('texts'), log_steps=config.evals.disclf.log_steps, diff --git a/big_vision/configs/proj/uvim/README.md b/big_vision/configs/proj/uvim/README.md index d4f363f..8a6bfa8 100644 --- a/big_vision/configs/proj/uvim/README.md +++ b/big_vision/configs/proj/uvim/README.md @@ -19,14 +19,14 @@ different tasks: panoptic segmentation, colorization and depth prediction. | Depth | UViM Stage I model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.155 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageI_params.npz) | | Depth | UViM Stage II model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.463 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageII_params.npz) | -All of this models can be interactively explored in our [colabs](/big_vision/configs/proj/uvim). +All of this models can be interactively explored in our [colabs](configs/proj/uvim). ## Running on a single-host TPU machine Below we provide instructions on how to run UViM training (stage I and stage II) using a single TPU host with 8 TPU accelerators. These instructions can be easily adapted to a GPU host and multi-host TPU setup, see the main -`big_vision` [README file](/README.md). +`big_vision` [README file](README.md). We assume that the user has already created and `ssh`-ed to the TPU host machine. 
The next step is to clone `big_vision` repository: diff --git a/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py b/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py index fca7b55..aa68eb7 100644 --- a/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py +++ b/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py @@ -46,7 +46,8 @@ def get_config(arg=''): arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False) config = ConfigDict() - config.pp_train = ( + config.input = {} + config.input.pp = ( f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' f'inception_box|crop_box(key="image")|crop_box(key="labels")|' @@ -68,19 +69,17 @@ def get_config(arg=''): f'keep("image","image_ctx","image/id")' # image/id used for rng seeds. ) - config.dataset = 'coco/2017_panoptic' - config.train_split = 'train[4096:]' + config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]') + config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 - config.batch_size = 512 config.total_epochs = 200 config.log_training_steps = 50 - config.shuffle_buffer_size = 50_000 config.ckpt_steps = 1000 config.keep_ckpt_steps = 5000 - config.ckpt_timeout = 1 config.prefetch_to_device = 2 - config.trial = 0 + config.seed = 0 # Optimizer section config.optax_name = 'big_vision.scale_by_adafactor' @@ -133,8 +132,7 @@ def get_config(arg=''): config.evals.val = ConfigDict() config.evals.val.type = 'proj.uvim.compute_mean' config.evals.val.pred = 'validation' - config.evals.val.dataset = config.dataset - config.evals.val.split = 'train[:4096]' + config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]') config.evals.val.pp_fn = pp_eval config.evals.val.log_steps = 1000 @@ -157,10 +155,10 @@ def get_config(arg=''): # config.evals.save_pred.outfile = 'inference.npz' if arg.singlehost: - config.batch_size = 32 + config.input.batch_size = 32 config.num_epochs = 50 elif arg.runlocal: - config.batch_size = 4 - config.shuffle_buffer_size = 10 - config.evals.val.split = 'train[:16]' + config.input.batch_size = 4 + config.input.shuffle_buffer_size = 10 + config.evals.val.data.split = 'train[:16]' return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py b/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py index 1f7e760..0245db5 100644 --- a/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py +++ b/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py @@ -13,7 +13,7 @@ # limitations under the License. # pylint: disable=line-too-long -r"""A config for training a UViM stage II model for the panoptic task. +r"""A config for training a UViM stage II model for the colorization task. 
""" import big_vision.configs.common as bvcc @@ -39,7 +39,8 @@ def get_config(arg=''): arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False) config = ConfigDict() - config.pp_train = ( + config.input = {} + config.input.pp = ( f'decode_jpeg_and_inception_crop({RES})' f'|flip_lr' f'|copy(inkey="image", outkey="labels")' @@ -63,19 +64,17 @@ def get_config(arg=''): f'|strong_hash(inkey="tfds_id", outkey="image/id")' f'|keep("image","image_ctx","labels","image/id")') - config.dataset = 'imagenet2012' - config.train_split = 'train[4096:]' + config.input.data = dict(name='imagenet2012', split='train[4096:]') + config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 - config.batch_size = 512 config.total_epochs = 50 config.log_training_steps = 50 - config.shuffle_buffer_size = 50_000 config.ckpt_steps = 1000 config.keep_ckpt_steps = 5000 - config.ckpt_timeout = 1 config.prefetch_to_device = 2 - config.trial = 0 + config.seed = 0 # Optimizer section config.optax_name = 'big_vision.scale_by_adafactor' @@ -127,19 +126,18 @@ def get_config(arg=''): config.evals.val = ConfigDict() config.evals.val.type = 'proj.uvim.compute_mean' config.evals.val.pred = 'validation' - config.evals.val.dataset = config.dataset - config.evals.val.split = 'train[:4096]' + config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]') config.evals.val.pp_fn = pp_eval config.evals.val.log_steps = 1000 base = { - 'type': 'proj.uvim.colorization', + 'type': 'proj.uvim.psnr', 'pp_fn': pp_eval.replace('decode|', ''), - 'log_steps': 2500, + 'log_steps': 10_000, } - config.evals.colorization_train = dict(**base, split='train[4096:8192]') - config.evals.colorization_holdout = dict(**base, split='train[:4096]') - config.evals.colorization_val = dict(**base, split='validation') + config.evals.psnr_train = dict(**base, split='train[4096:8192]') + config.evals.psnr_holdout = dict(**base, split='train[:4096]') + config.evals.psnr_val = dict(**base, split='validation') config.evals.colorization_val_coltran_fid = { 'type': 'proj.uvim.coltran_fid', @@ -154,10 +152,10 @@ def get_config(arg=''): # config.evals.save_pred.outfile = 'inference.npz' if arg.singlehost: - config.batch_size = 32 + config.input.batch_size = 32 config.total_epochs = 20 elif arg.runlocal: - config.batch_size = 8 - config.val_split = 'validation[:256]' - config.shuffle_buffer_size = 10 + config.input.batch_size = 8 + config.input.shuffle_buffer_size = 10 + config.evals.val.data.split = 'validation[:256]' return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py b/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py index cae1e95..d2e0581 100644 --- a/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py +++ b/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py @@ -42,10 +42,11 @@ def get_config(arg='split=final'): """Config for training.""" - arg = bvcc.parse_arg(arg, split='final', runlocal=False, singlehost=True) + arg = bvcc.parse_arg(arg, split='final', runlocal=False, singlehost=False) config = ConfigDict() - config.pp_train = ( + config.input = {} + config.input.pp = ( f'decode|nyu_depth|' f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' f'inception_box|crop_box(key="image")|crop_box(key="labels")|' @@ -76,29 +77,25 @@ def get_config(arg='split=final'): f'keep("image","image_ctx","ground_truth")' ) - config.dataset = 'nyu_depth_v2' - config.train_split = 'train' + config.input.data = dict(name='nyu_depth_v2', split='train') + 
config.input.batch_size = 512 + config.input.shuffle_buffer_size = 50_000 - config.batch_size = 512 config.total_epochs = 50 config.log_training_steps = 50 - config.shuffle_buffer_size = 50_000 config.ckpt_steps = 1000 config.keep_ckpt_steps = 5000 - config.ckpt_timeout = 1 config.prefetch_to_device = 2 - config.trial = 0 config.seed = 0 # Optimizer section config.optax_name = 'big_vision.scale_by_adafactor' - config.optax = ConfigDict() + config.optax = dict(beta2_cap=0.95) config.optax.clipping_threshold = None - config.optax.beta2_cap = 0.95 - config.lr = 1e-3 - config.wd = 1e-6 + config.lr = 0.001 + config.wd = 0.000001 config.lr_mults = ( ('pos_embedding_encoder.*', 0.1), ('EmbedPatches.*', 0.1), @@ -146,14 +143,14 @@ def get_config(arg='split=final'): config.evals.val = ConfigDict() config.evals.val.type = 'proj.uvim.compute_mean' config.evals.val.pred = 'validation' - config.evals.val.dataset = config.dataset - config.evals.val.split = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'validation' config.evals.val.pp_fn = pp_eval config.evals.val.log_steps = 1000 base = { 'type': 'proj.uvim.nyu_depth', - 'dataset': config.dataset, + 'dataset': config.input.data.name, 'pp_fn': pp_predict, 'log_steps': 2000, 'min_depth': MIN_DEPTH, @@ -162,13 +159,12 @@ def get_config(arg='split=final'): config.evals.nyu_depth_val = dict(**base, split='validation') if arg.singlehost: - config.batch_size = 32 + config.input.batch_size = 32 config.total_epochs = 20 elif arg.runlocal: config.oracle.model_init = '/tmp/checkpoint.npz' config.model_init = {'encoder': '/tmp/enc_checkpoint.npz'} config.evals = {} - config.batch_size = 1 - config.val_split = 'validation[:16]' - config.shuffle_buffer_size = 10 + config.input.batch_size = 1 + config.input.shuffle_buffer_size = 10 return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py b/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py index 071561c..5a620ba 100644 --- a/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py +++ b/big_vision/configs/proj/uvim/vqvae_coco_panoptic.py @@ -36,14 +36,15 @@ def get_config(arg='res=512,patch_size=16'): config.task = 'proj.uvim.panoptic_task' - config.dataset = 'coco/2017_panoptic' - config.train_split = 'train[4096:]' + config.input = {} + config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]') + + config.input.batch_size = 1024 + config.input.shuffle_buffer_size = 25_000 - config.trial = 0 - config.batch_size = 1024 config.total_epochs = 1000 - config.pp_train = ( + config.input.pp = ( f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|' f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' f'inception_box|crop_box(key="image")|crop_box(key="labels")|' @@ -56,11 +57,9 @@ def get_config(arg='res=512,patch_size=16'): f'value_range(-1, 1)|make_canonical|keep("image","labels")' ) - config.shuffle_buffer_size = 25_000 - config.log_training_steps = 50 config.ckpt_steps = 1000 - config.keep_ckpt_steps = 20000 + config.keep_ckpt_steps = 20_000 # Model section config.model_name = 'proj.uvim.vit' @@ -104,8 +103,8 @@ def get_config(arg='res=512,patch_size=16'): config.evals.val = mlc.ConfigDict() config.evals.val.type = 'proj.uvim.compute_mean' config.evals.val.pred = 'validation' - config.evals.val.dataset = config.dataset - config.evals.val.split = 'train[:4096]' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'train[:4096]' config.evals.val.pp_fn = 
pp_eval config.evals.val.log_steps = 250 @@ -127,18 +126,18 @@ def get_config(arg='res=512,patch_size=16'): # config.evals.save_pred.split = 'validation[:1024]' # config.evals.save_pred.outfile = 'inference.npz' - config.trial = 0 + config.seed = 0 if arg.singlehost: - config.batch_size = 128 + config.input.batch_size = 128 config.num_epochs = 100 elif arg.runlocal: - config.batch_size = 16 - config.shuffle_buffer_size = 10 + config.input.batch_size = 16 + config.input.shuffle_buffer_size = 10 config.log_training_steps = 5 config.model.enc_depth = 1 config.model.dec_depth = 1 - config.evals.val.split = 'validation[:16]' + config.evals.val.data.split = 'validation[:16]' config.evals.val.log_steps = 20 return config \ No newline at end of file diff --git a/big_vision/configs/proj/uvim/vqvae_imagenet2012_colorization.py b/big_vision/configs/proj/uvim/vqvae_imagenet2012_colorization.py index 8a7f55b..d4ecb7a 100644 --- a/big_vision/configs/proj/uvim/vqvae_imagenet2012_colorization.py +++ b/big_vision/configs/proj/uvim/vqvae_imagenet2012_colorization.py @@ -22,21 +22,22 @@ def get_config(arg='res=512,patch_size=16'): - """Config for training label compression on COCO-panoptic.""" + """A config for training a UViM stage I model for the colorization task.""" arg = bvcc.parse_arg(arg, res=512, patch_size=16, - runlocal=False, singlehost=True) + runlocal=False, singlehost=False) config = mlc.ConfigDict() config.task = 'proj.uvim.colorization_task' - config.dataset = 'imagenet2012' - config.train_split = 'train[4096:]' + config.input = {} + config.input.data = dict(name='imagenet2012', split='train[4096:]') + + config.input.batch_size = 1024 + config.input.shuffle_buffer_size = 25_000 - config.trial = 0 - config.batch_size = 1024 config.total_epochs = 100 - config.pp_train = ( + config.input.pp = ( f'decode_jpeg_and_inception_crop({arg.res})' f'|flip_lr' f'|copy(inkey="image", outkey="labels")' @@ -54,11 +55,9 @@ def get_config(arg='res=512,patch_size=16'): f'|value_range(-1,1,key="labels")' f'|keep("image","labels")') - config.shuffle_buffer_size = 25_000 - config.log_training_steps = 50 config.ckpt_steps = 1000 - config.keep_ckpt_steps = 20000 + config.keep_ckpt_steps = 20_000 # Model section config.model_name = 'proj.uvim.vit' @@ -101,19 +100,19 @@ def get_config(arg='res=512,patch_size=16'): config.evals.val = mlc.ConfigDict() config.evals.val.type = 'proj.uvim.compute_mean' config.evals.val.pred = 'validation' - config.evals.val.dataset = config.dataset - config.evals.val.split = 'train[:4096]' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'train[:4096]' config.evals.val.pp_fn = pp_eval config.evals.val.log_steps = 250 base = { - 'type': 'proj.uvim.colorization', + 'type': 'proj.uvim.psnr', 'pp_fn': pp_eval.replace('decode|', ''), - 'log_steps': 2500, + 'log_steps': 10_000, } - config.evals.colorization_train = dict(**base, split='train[4096:8192]') - config.evals.colorization_holdout = dict(**base, split='train[:4096]') - config.evals.colorization_val = dict(**base, split='validation') + config.evals.psnr_train = dict(**base, split='train[4096:8192]') + config.evals.psnr_holdout = dict(**base, split='train[:4096]') + config.evals.psnr_val = dict(**base, split='validation') config.evals.colorization_val_coltran_fid = { 'type': 'proj.uvim.coltran_fid', @@ -127,25 +126,25 @@ def get_config(arg='res=512,patch_size=16'): # config.evals.save_pred.split = 'validation[:1024]' # config.evals.save_pred.outfile = 'inference.npz' - config.trial = 0 + config.seed = 0 if 
arg.singlehost: - config.batch_size = 128 + config.input.batch_size = 128 config.total_epochs = 20 elif arg.runlocal: - config.batch_size = 16 - config.shuffle_buffer_size = 10 + config.input.batch_size = 16 + config.input.shuffle_buffer_size = 10 config.log_training_steps = 5 config.model.enc_depth = 1 config.model.dec_depth = 1 - config.evals.val.split = 'validation[:16]' + config.evals.val.data.split = 'validation[:16]' config.evals.val.log_steps = 20 - config.evals.colorization_train.split = 'train[:256]' - config.evals.colorization_train.log_steps = 20 - config.evals.colorization_holdout.split = 'train[256:512]' - config.evals.colorization_holdout.log_steps = 20 - config.evals.colorization_val.split = 'train[:256]' - config.evals.colorization_val.log_steps = 20 + config.evals.psnr_train.split = 'train[:256]' + config.evals.psnr_train.log_steps = 20 + config.evals.psnr_holdout.split = 'train[256:512]' + config.evals.psnr_holdout.log_steps = 20 + config.evals.psnr_val.split = 'train[:256]' + config.evals.psnr_val.log_steps = 20 config.evals.colorization_val_coltran_fid.split = 'validation[:256]' config.evals.colorization_val_coltran_fid.log_steps = 20 diff --git a/big_vision/configs/proj/uvim/vqvae_nyu_depth.py b/big_vision/configs/proj/uvim/vqvae_nyu_depth.py index d1ba724..d5ae293 100644 --- a/big_vision/configs/proj/uvim/vqvae_nyu_depth.py +++ b/big_vision/configs/proj/uvim/vqvae_nyu_depth.py @@ -35,14 +35,15 @@ def get_config(arg='res=512,patch_size=16'): config.task = 'proj.uvim.depth_task' - config.dataset = 'nyu_depth_v2' - config.train_split = 'train' + config.input = {} + config.input.data = dict(name='nyu_depth_v2', split='train') + + config.input.batch_size = 1024 + config.input.shuffle_buffer_size = 25_000 - config.trial = 0 - config.batch_size = 1024 config.total_epochs = 200 - config.pp_train = ( + config.input.pp = ( f'decode|nyu_depth|' f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|' f'inception_box|crop_box(key="image")|crop_box(key="labels")|' @@ -64,11 +65,9 @@ def get_config(arg='res=512,patch_size=16'): f'keep("image","labels","ground_truth")' ) - config.shuffle_buffer_size = 25_000 - config.log_training_steps = 50 config.ckpt_steps = 1000 - config.keep_ckpt_steps = 20000 + config.keep_ckpt_steps = 20_000 # Model section config.min_depth = MIN_DEPTH @@ -113,14 +112,14 @@ def get_config(arg='res=512,patch_size=16'): config.evals.val = mlc.ConfigDict() config.evals.val.type = 'proj.uvim.compute_mean' config.evals.val.pred = 'validation' - config.evals.val.dataset = config.dataset - config.evals.val.split = 'validation' + config.evals.val.data = {**config.input.data} + config.evals.val.data.split = 'validation' config.evals.val.pp_fn = pp_eval config.evals.val.log_steps = 250 base = { 'type': 'proj.uvim.nyu_depth', - 'dataset': config.dataset, + 'dataset': config.input.data.name, 'pp_fn': pp_pred, 'log_steps': 2000, 'min_depth': MIN_DEPTH, } config.evals.nyu_depth_val = dict(**base, split='validation') + config.seed = 0 + if arg.singlehost: - config.batch_size = 128 + config.input.batch_size = 128 config.total_epochs = 50 elif arg.runlocal: - config.batch_size = 16 - config.shuffle_buffer_size = 10 + config.input.batch_size = 16 + config.input.shuffle_buffer_size = 10 config.log_training_steps = 5 config.model.enc_depth = 1 config.model.dec_depth = 1 - config.evals.val.split = 'validation[:16]' + config.evals.val.data.split = 'validation[:16]' config.evals.val.log_steps = 20 return 
config \ No newline at end of file diff --git a/big_vision/configs/transfer.py b/big_vision/configs/transfer.py index 0cfee84..273b7d2 100644 --- a/big_vision/configs/transfer.py +++ b/big_vision/configs/transfer.py @@ -51,6 +51,10 @@ def _set_model(config, model): config.model_init = 'i1k-s16-300ep' config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d', rep_size=True) + elif model == 'bit-m-r50x1': + config.model_name = 'bit_paper' + config.model_init = 'M' + config.model = dict(depth=50, width=1) else: raise ValueError(f'Unknown model: {model}, please define customized model.') @@ -82,8 +86,7 @@ def _set_task(config, dataset, train, val, test, n_cls, decay_type='cosine', ) - config.dataset = dataset - config.train_split = train + config.input.data = dict(name=dataset, split=train) pp_common = ( '|value_range(-1, 1)|' f'onehot({n_cls}, key="{lbl}", key_result="labels")|' @@ -101,21 +104,20 @@ def _set_task(config, dataset, train, val, test, n_cls, 'inception_crop, resmall_crop, resize_crop') if flip: pp_train += '|flip_lr' - config.pp_train = pp_train + pp_common + config.input.pp = pp_train + pp_common pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common config.num_classes = n_cls - eval_common = dict( - type='classification', - dataset=dataset, - loss_name='softmax_xent', - log_steps=100, - pp_fn=pp, - ) - config.evals = {} - config.evals.val = dict(**eval_common, split=val) - config.evals.test = dict(**eval_common, split=test) + def get_eval(split): + return dict( + type='classification', + data=dict(name=dataset, split=split), + loss_name='softmax_xent', + log_steps=100, + pp_fn=pp, + ) + config.evals = dict(val=get_eval(val), test=get_eval(test)) def _set_imagenet_variants(config, h_res=448, l_res=384): @@ -131,15 +133,13 @@ def _set_imagenet_variants(config, h_res=448, l_res=384): # NOTE: keep test == val for convenience in subsequent analysis. config.evals.real = dict(type='classification') - config.evals.real.dataset = 'imagenet2012_real' - config.evals.real.split = 'validation' + config.evals.real.data = dict(name='imagenet2012_real', split='validation') config.evals.real.pp_fn = pp.format(lbl='real_label') config.evals.real.loss_name = config.loss config.evals.real.log_steps = 100 config.evals.v2 = dict(type='classification') - config.evals.v2.dataset = 'imagenet_v2' - config.evals.v2.split = 'test' + config.evals.v2.data = dict(name='imagenet_v2', split='test') config.evals.v2.pp_fn = pp.format(lbl='label') config.evals.v2.loss_name = config.loss config.evals.v2.log_steps = 100 @@ -151,8 +151,9 @@ def get_config(arg=None): h_res=448, l_res=384, runlocal=False) config = mlc.ConfigDict() - config.batch_size = 512 if not arg.runlocal else 8 - config.shuffle_buffer_size = 50_000 if not arg.runlocal else 100 + config.input = {} + config.input.batch_size = 512 if not arg.runlocal else 8 + config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100 config.log_training_steps = 10 config.ckpt_steps = 1000 diff --git a/big_vision/configs/vit_i1k.py b/big_vision/configs/vit_i1k.py index 24a0ab0..3f0ccb5 100644 --- a/big_vision/configs/vit_i1k.py +++ b/big_vision/configs/vit_i1k.py @@ -62,6 +62,11 @@ def get_config(arg=None): arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='') config = mlc.ConfigDict() + config.seed = 0 + config.total_epochs = 300 + config.num_classes = 1000 + config.loss = 'softmax_xent' + # If this gives a KeyError, lookup Fig4 of the paper and add an entry. 
# Note, this here is a good average between 30ep and 300ep, sometimes you coud # find a slightly better setting for either of them. @@ -74,31 +79,31 @@ def get_config(arg=None): 'L/16': 'medium2', }[arg.variant] - config.dataset = 'imagenet2012' - config.train_split = 'train[:99%]' - config.cache_raw = not arg.runlocal # Needs up to 120GB of RAM! - config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. - config.num_classes = 1000 - config.loss = 'softmax_xent' - config.batch_size = 4096 - config.total_epochs = 300 + config.input = dict() + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 4096 + config.input.cache_raw = not arg.runlocal # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 pp_common = ( '|value_range(-1, 1)' '|onehot(1000, key="{lbl}", key_result="labels")' '|keep("image", "labels")' ) - config.pp_train = ( + config.input.pp = ( 'decode_jpeg_and_inception_crop(224)|flip_lr|' + RANDAUG_DEF[aug_setting] + pp_common.format(lbl='label') ) - pp = 'decode|resize_small(256)|central_crop(224)' + pp_common + pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common # Aggressive pre-fetching because our models here are small, so we not only # can afford it, but we also need it for the smallest models to not be # bottle-necked by the input pipeline. Play around with it for -L models tho. - config.prefetch_to_host = 8 + config.input.prefetch = 8 config.prefetch_to_device = 4 config.log_training_steps = 50 @@ -131,35 +136,35 @@ def get_config(arg=None): config.mixup = MIXUP_DEF[aug_setting] # Eval section - eval_common = dict( - type='classification', - dataset='imagenet2012', - pp_fn=pp.format(lbl='label'), - loss_name=config.loss, - log_steps=2500, # Very fast O(seconds) so it's fine to run it often. - ) + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. + cache_final=not arg.runlocal, + ) config.evals = {} - config.evals.train = {**eval_common, 'split': 'train[:2%]'} - config.evals.minival = {**eval_common, 'split': 'train[99%:]'} - config.evals.val = {**eval_common, 'split': 'validation'} - config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'} - - config.evals.real = {**eval_common} - config.evals.real.dataset = 'imagenet2012_real' - config.evals.real.split = 'validation' - config.evals.real.pp_fn = pp.format(lbl='real_label') + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') + config.evals.real.pp_fn = pp_eval.format(lbl='real_label') config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal) config.fewshot.log_steps = 10_000 # Make a few things much smaller for quick local debugging testruns. 
if arg.runlocal: - config.shuffle_buffer_size = 10 - config.batch_size = 8 - config.evals.train.split = 'train[:16]' - config.evals.minival.split = 'train[:16]' - config.evals.val.split = 'validation[:16]' - config.evals.real.split = 'validation[:16]' - config.evals.v2.split = 'test[:16]' + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + config.input.cache_raw = False + config.evals.train.data.split = 'train[:16]' + config.evals.minival.data.split = 'train[:16]' + config.evals.val.data.split = 'validation[:16]' + config.evals.v2.data.split = 'test[:16]' + config.evals.real.data.split = 'validation[:16]' return config \ No newline at end of file diff --git a/big_vision/configs/vit_i21k.py b/big_vision/configs/vit_i21k.py index fb12e4e..1b483a1 100644 --- a/big_vision/configs/vit_i21k.py +++ b/big_vision/configs/vit_i21k.py @@ -50,6 +50,12 @@ def get_config(arg=None): arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None) config = mlc.ConfigDict() + config.seed = 0 + config.total_epochs = 300 + config.num_classes = 21843 + config.init_head_bias = -10.0 + config.loss = 'sigmoid_xent' + # If this gives a KeyError, lookup Fig4 of the paper and add an entry. # Note, this here is a good average between 30ep and 300ep, sometimes you coud # find a slightly better setting for either of them. @@ -62,27 +68,24 @@ def get_config(arg=None): 'L/16': 'medium2', }[arg.variant] - config.dataset = 'imagenet21k' - config.train_split = 'full[51200:]' - config.num_classes = 21843 - config.init_head_bias = -10.0 - config.loss = 'sigmoid_xent' - - config.batch_size = 4096 - config.total_epochs = 300 + config.input = dict() + config.input.data = dict( + name='imagenet21k', + split='full[51200:]', + ) + config.input.batch_size = 4096 + config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")' pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}') pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"') - config.pp_train = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k - pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common_i21k - pp_eval_i1k = 'decode|resize_small(256)|central_crop(224)' + pp_common_i1k - config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k + pp_eval = 'decode|resize_small(256)|central_crop(224)' # Aggressive pre-fetching because our models here are small, so we not only # can afford it, but we also need it for the smallest models to not be # bottle-necked by the input pipeline. Play around with it for -L models tho. - config.prefetch_to_host = 8 + config.input.prefetch = 8 config.prefetch_to_device = 4 config.log_training_steps = 50 @@ -103,18 +106,19 @@ def get_config(arg=None): config.mixup = MIXUP_DEF[aug_setting] - # Eval section - eval_common = dict( - type='classification', - dataset=config.dataset, - pp_fn=pp_eval, - loss_name=config.loss, - log_steps=1000, # Very fast O(seconds) so it's fine to run it often. - ) + # Evaluations on i21k itself. + def eval_i21k(split): + return dict( + type='classification', + data={**config.input.data, 'split': split}, + pp_fn=pp_eval + pp_common_i21k, + loss_name=config.loss, + log_steps=1000, # Very fast O(seconds) so it's fine to run it often. 
+ ) config.evals = {} - config.evals.test = {**eval_common, 'split': 'full[:25_600]'} - config.evals.val = {**eval_common, 'split': 'full[25_600:51_200]'} - config.evals.train = {**eval_common, 'split': 'full[51_200:76_800]'} + config.evals.test = eval_i21k('full[:25_600]') + config.evals.val = eval_i21k('full[25_600:51_200]') + config.evals.train = eval_i21k('full[51_200:76_800]') # Few-shot evaluators config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal) @@ -122,10 +126,14 @@ def get_config(arg=None): # Make a few things much smaller for quick local debugging testruns. if arg.runlocal: - config.shuffle_buffer_size = 10 - config.batch_size = 8 - config.evals.test.split = 'full[:16]' - config.evals.train.split = 'full[:16]' - config.evals.val.split = 'full[:16]' + config.input.shuffle_buffer_size = 10 + config.input.batch_size = 8 + config.evals.test.data.split = 'full[:16]' + config.evals.train.data.split = 'full[:16]' + config.evals.val.data.split = 'full[:16]' + config.evals.i1k_val.data.split = 'validation[:16]' + config.evals.i1k_v2.data.split = 'test[:16]' + config.evals.i1k_a.data.split = 'test[:16]' + config.evals.i1k_r.data.split = 'test[:16]' return config \ No newline at end of file diff --git a/big_vision/configs/vit_s16_i1k.py b/big_vision/configs/vit_s16_i1k.py index 665ad45..af411c2 100644 --- a/big_vision/configs/vit_s16_i1k.py +++ b/big_vision/configs/vit_s16_i1k.py @@ -34,21 +34,26 @@ def get_config(): """Config for training.""" config = mlc.ConfigDict() - config.dataset = 'imagenet2012' - config.train_split = 'train[:99%]' - config.cache_raw = True # Requires up to 120GB of RAM! - config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. + config.seed = 0 + config.total_epochs = 90 config.num_classes = 1000 config.loss = 'softmax_xent' - config.batch_size = 1024 - config.total_epochs = 90 + + config.input = {} + config.input.data = dict( + name='imagenet2012', + split='train[:99%]', + ) + config.input.batch_size = 1024 + config.input.cache_raw = True # Needs up to 120GB of RAM! + config.input.shuffle_buffer_size = 250_000 pp_common = ( '|value_range(-1, 1)' '|onehot(1000, key="{lbl}", key_result="labels")' '|keep("image", "labels")' ) - config.pp_train = ( + config.input.pp = ( 'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' + pp_common.format(lbl='label') ) @@ -78,22 +83,20 @@ def get_config(): config.mixup = dict(p=0.2, fold_in=None) # Eval section - eval_common = dict( - type='classification', - dataset='imagenet2012', - pp_fn=pp_eval.format(lbl='label'), - loss_name=config.loss, - log_steps=2500, # Very fast O(seconds) so it's fine to run it often. - ) + def get_eval(split, dataset='imagenet2012'): + return dict( + type='classification', + data=dict(name=dataset, split=split), + pp_fn=pp_eval.format(lbl='label'), + loss_name=config.loss, + log_steps=2500, # Very fast O(seconds) so it's fine to run it often. 
+ ) config.evals = {} - config.evals.train = {**eval_common, 'split': 'train[:2%]'} - config.evals.minival = {**eval_common, 'split': 'train[99%:]'} - config.evals.val = {**eval_common, 'split': 'validation'} - config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'} - - config.evals.real = dict(**eval_common) - config.evals.real.dataset = 'imagenet2012_real' - config.evals.real.split = 'validation' + config.evals.train = get_eval('train[:2%]') + config.evals.minival = get_eval('train[99%:]') + config.evals.val = get_eval('validation') + config.evals.v2 = get_eval('test', dataset='imagenet_v2') + config.evals.real = get_eval('validation', dataset='imagenet2012_real') config.evals.real.pp_fn = pp_eval.format(lbl='real_label') return config diff --git a/big_vision/datasets/core.py b/big_vision/datasets/core.py new file mode 100644 index 0000000..f951a8e --- /dev/null +++ b/big_vision/datasets/core.py @@ -0,0 +1,77 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core data functions, dispatch calls to the requested dataset.""" +import importlib + + +# Note: intentionally not using ABC to avoid forcing implementation of every +# method, since one can imagine train-only datasets for example. +class DataSource: + """The API that any data source should implement.""" + + def get_tfdata(self, ordered): + """Creates this data object as a tf.data.Dataset. + + This will be called separately in each process, and it is up to the dataset + implementation to shard it accordingly if desired! + + Args: + ordered: if True, the dataset should use deterministic ordering, if False + it may have undefined ordering. Think of True == val, False == train. + + Returns: + A tf.data.Dataset object. + + Raises: + RuntimeError: if not implemented by the dataset, but called. + """ + raise RuntimeError(f"not implemented for {self.__class__.__name__}") + + @property + def total_examples(self): + """Returns number of examples in the dataset, regardless of sharding.""" + raise RuntimeError(f"not implemented for {self.__class__.__name__}") + + def num_examples_per_process(self, nprocess=None): + """Returns a list of the number of examples for each process. + + This is only needed for datasets that should go through make_for_inference. + + Args: + nprocess: the number of processes, use `jax.process_count()` if None. + + Returns: + Returns a list of the number of examples for each process. + + Ideally, this would always be `[total() / nprocess] * nprocess`, but in + reality we can almost never perfectly shard a dataset across arbitrary + number of processes. + + One alternative option that can work in some cases is to not even shard + the dataset and thus return `[num_examples()] * nprocess`. + + Raises: + RuntimeError: if not implemented by the dataset, but called.
+ """ + raise RuntimeError(f"not implemented for {self.__class__.__name__}") + + +def get(name, **kw): + if name.startswith("bv:"): + mod = importlib.import_module(f"big_vision.datasets.{name[3:]}") + return mod.DataSource(**kw) + else: + mod = importlib.import_module("big_vision.datasets.tfds") + return mod.DataSource(name, **kw) diff --git a/big_vision/datasets/tfds.py b/big_vision/datasets/tfds.py new file mode 100644 index 0000000..b203f2a --- /dev/null +++ b/big_vision/datasets/tfds.py @@ -0,0 +1,60 @@ +# Copyright 2022 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorFlow Datasets as data source for big_vision.""" +import functools + +import big_vision.datasets.core as ds_core +import jax +import tensorflow_datasets as tfds + + +class DataSource(ds_core.DataSource): + """Use TFDS as a data source.""" + + def __init__(self, name, split, data_dir=None, skip_decode=("image",)): + self.builder = _get_builder(name, data_dir) + self.split = split + # Each host is responsible for a fixed subset of data + process_splits = tfds.even_splits(split, jax.process_count()) + self.process_split = process_splits[jax.process_index()] + self.skip_decoders = { + f: tfds.decode.SkipDecoding() + for f in skip_decode + if f in self.builder.info.features + } + + def get_tfdata(self, ordered=False): + return self.builder.as_dataset( + split=self.process_split, + shuffle_files=not ordered, + read_config=tfds.ReadConfig( + skip_prefetch=True, # We prefetch after pipeline. + try_autocache=False, # We control this, esp. for few-shot.
+ add_tfds_id=True, + ), + decoders=self.skip_decoders) + + @property + def total_examples(self): + return self.builder.info.splits[self.split].num_examples + + def num_examples_per_process(self, nprocess=None): + splits = tfds.even_splits(self.split, nprocess or jax.process_count()) + return [self.builder.info.splits[s].num_examples for s in splits] + + +@functools.lru_cache(maxsize=None) +def _get_builder(dataset, data_dir): + return tfds.builder(dataset, data_dir=data_dir, try_gcs=True) diff --git a/big_vision/evaluators/classification.py b/big_vision/evaluators/classification.py index 7bf2370..b64a0f0 100644 --- a/big_vision/evaluators/classification.py +++ b/big_vision/evaluators/classification.py @@ -16,6 +16,7 @@ # pylint: disable=consider-using-from-import from functools import partial, lru_cache +import big_vision.datasets.core as ds_core import big_vision.input_pipeline as input_pipeline import big_vision.pp.builder as pp_builder import big_vision.utils as u @@ -54,12 +55,14 @@ def _eval_fn(params, batch, labels, mask): class Evaluator: """Classification evaluator.""" - def __init__(self, predict_fn, dataset, split, pp_fn, batch_size, loss_name, - data_dir=None, cache_final=True, cache_raw=False, prefetch=1, + def __init__(self, predict_fn, data, pp_fn, batch_size, loss_name, + cache_final=True, cache_raw=False, prefetch=1, label_key='labels'): + data = ds_core.get(**data) pp_fn = pp_builder.get_preprocess_fn(pp_fn) self.ds, self.steps = input_pipeline.make_for_inference( - dataset, split, pp_fn, batch_size, data_dir, + data.get_tfdata(ordered=True), pp_fn, batch_size, + num_ex_per_process=data.num_examples_per_process(), cache_final=cache_final, cache_raw=cache_raw) self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch) self.eval_fn = get_eval_fn(predict_fn, loss_name) diff --git a/big_vision/evaluators/common.py b/big_vision/evaluators/common.py index 40bc028..3ecc1ee 100644 --- a/big_vision/evaluators/common.py +++ b/big_vision/evaluators/common.py @@ -33,7 +33,8 @@ def from_config(config, predict_fns, write_note=lambda s: s): prefix = cfg.pop("prefix", f"{name}/") # Use same batch_size as eval by default, to reduce fragmentation. - cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("batch_size") # pylint: disable=line-too-long + # TODO: eventually remove all the deprecated names... + cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("input.batch_size") or config.get("batch_size") # pylint: disable=line-too-long module = importlib.import_module(f"big_vision.evaluators.{module}") evaluator = module.Evaluator(predict_fns[fn_key], **cfg) diff --git a/big_vision/evaluators/fewshot_lsr.py b/big_vision/evaluators/fewshot_lsr.py index b8edfb3..2c63f36 100644 --- a/big_vision/evaluators/fewshot_lsr.py +++ b/big_vision/evaluators/fewshot_lsr.py @@ -13,16 +13,17 @@ # limitations under the License. """Utils for few-shot evaluation.""" +# pylint: disable=consider-using-from-import import functools +import big_vision.datasets.core as ds_core import big_vision.input_pipeline as input_pipeline import big_vision.pp.builder as pp_builder import big_vision.utils as u import jax import jax.numpy as jnp import numpy as np -import tensorflow_datasets as tfds BIAS_CONSTANT = 100.0 @@ -145,17 +146,20 @@ def _get_dataset(self, dataset, train_split, test_split): try: return self._datasets[key] except KeyError: + # NOTE: only supporting TFDS data for now for bwd compat/lazyness. 
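For orientation, the snippet below sketches how the new data-source API composes end to end, mirroring the evaluator changes above; the dataset name, split, preprocessing string and batch size are illustrative placeholders rather than values taken from this diff.

import big_vision.datasets.core as ds_core
import big_vision.input_pipeline as input_pipeline
import big_vision.pp.builder as pp_builder

# Plain names dispatch to the TFDS-backed DataSource; a "bv:..." prefix would
# load a custom module from big_vision.datasets instead (see core.get above).
data = ds_core.get(name="imagenet2012", split="validation")

# Deterministic order for evaluation; each process only reads its own shard.
pp_fn = pp_builder.get_preprocess_fn(
    'decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
    '|keep("image")')
ds, steps = input_pipeline.make_for_inference(
    data.get_tfdata(ordered=True), pp_fn, batch_size=1024,
    # Padding is derived from the per-process example counts, so every host
    # runs the same number of batches.
    num_ex_per_process=data.num_examples_per_process())
data_iter = input_pipeline.start_input_pipeline(ds, 1)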
+ train_data = ds_core.get(name=dataset, split=train_split) train_ds, batches_tr = input_pipeline.make_for_inference( - dataset=dataset, - split=train_split, + train_data.get_tfdata(ordered=True), + num_ex_per_process=train_data.num_examples_per_process(), batch_size=self.batch_size, preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr)) + test_data = ds_core.get(name=dataset, split=test_split) test_ds, batches_te = input_pipeline.make_for_inference( - dataset=dataset, - split=test_split, + test_data.get_tfdata(ordered=True), + num_ex_per_process=test_data.num_examples_per_process(), batch_size=self.batch_size, preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te)) - num_classes = tfds.builder(dataset).info.features["label"].num_classes + num_classes = train_data.builder.info.features["label"].num_classes return self._datasets.setdefault( key, (train_ds, batches_tr, test_ds, batches_te, num_classes)) diff --git a/big_vision/evaluators/proj/distill/distance.py b/big_vision/evaluators/proj/distill/distance.py index 72e3072..bbbde16 100644 --- a/big_vision/evaluators/proj/distill/distance.py +++ b/big_vision/evaluators/proj/distill/distance.py @@ -15,7 +15,8 @@ """Evaluator for the classfication task.""" from functools import partial, lru_cache -from big_vision import input_pipeline +from big_vision import input_pipeline +import big_vision.datasets.core as ds_core import big_vision.pp.builder as pp_builder import big_vision.utils as u @@ -94,12 +95,14 @@ def _eval_fn(params, batch, mask): class Evaluator: """Distillation distance evaluator.""" - def __init__(self, student_teacher_fwd, dataset, split, pp_fn, distances, - what=('logits', 'logits'), cache_final=True, **data_kw): + def __init__(self, student_teacher_fwd, data, pp_fn, distances, + what=('logits', 'logits'), **data_kw): + data = ds_core.get(**data) pp_fn = pp_builder.get_preprocess_fn(pp_fn) prefetch = data_kw.pop('prefetch', 1) self.ds, self.steps = input_pipeline.make_for_inference( - dataset, split, preprocess_fn=pp_fn, cache_final=cache_final, **data_kw) + data.get_tfdata(ordered=True), pp_fn, + num_ex_per_process=data.num_examples_per_process(), **data_kw) self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch) dist_fns = tuple(get_dist_fn(**dist) for dist in distances) self.dist_names = [ @@ -121,7 +124,7 @@ def run(self, params): all_ds[i].append(np.array(val[0]).flatten()[batch_ms == 1]) for name, ds in zip(self.dist_names, all_ds): ds = np.concatenate(ds) + yield f'{name}/all', ds yield f'{name}/avg', np.mean(ds) - yield f'{name}/std', np.std(ds) yield f'{name}/min', np.min(ds) yield f'{name}/max', np.max(ds) diff --git a/big_vision/evaluators/proj/image_text/contrastive.py b/big_vision/evaluators/proj/image_text/contrastive.py index 09ba49e..305265e 100644 --- a/big_vision/evaluators/proj/image_text/contrastive.py +++ b/big_vision/evaluators/proj/image_text/contrastive.py @@ -16,6 +16,7 @@ import functools from big_vision import input_pipeline +import big_vision.datasets.core as ds_core import big_vision.pp.builder as pp_builder import big_vision.utils as u import jax @@ -59,13 +60,15 @@ def _eval_fn(params, images, labels, mask): class Evaluator: """Contrastive evaluator.""" - def __init__(self, predict_fn, dataset, split, pp_fn, batch_size, - use_global_batch, data_dir=None, cache_final=True, + def __init__(self, predict_fn, data, pp_fn, batch_size, + use_global_batch, cache_final=True, cache_raw=False, prefetch=1, label_key="labels"): + data = ds_core.get(**data) pp_fn = 
pp_builder.get_preprocess_fn(pp_fn) self.ds, self.steps = input_pipeline.make_for_inference( - dataset, split, pp_fn, batch_size, data_dir, cache_final=cache_final, - cache_raw=cache_raw) + data.get_tfdata(ordered=True), pp_fn, batch_size, + num_ex_per_process=data.num_examples_per_process(), + cache_final=cache_final, cache_raw=cache_raw) self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch) self.eval_fn = get_eval_fn(predict_fn, use_global_batch) self.label_key = label_key diff --git a/big_vision/evaluators/proj/uvim/coco_panoptic.py b/big_vision/evaluators/proj/uvim/coco_panoptic.py index edb2659..b20d445 100644 --- a/big_vision/evaluators/proj/uvim/coco_panoptic.py +++ b/big_vision/evaluators/proj/uvim/coco_panoptic.py @@ -61,6 +61,8 @@ def __init__(self, predict_fn, pp_fn, batch_size, + dataset='coco/2017_panoptic', + dataset_dir=None, split='validation', predict_kwargs=None): # Prepare to run predict on all processes and gather predictions on all @@ -85,7 +87,7 @@ def preprocess(example): } self.data = common.get_jax_process_dataset( - 'coco/2017_panoptic', split, + dataset, split, dataset_dir=dataset_dir, global_batch_size=batch_size, pp_fn=preprocess) @@ -93,7 +95,8 @@ def preprocess(example): if jax.process_index() == 0: self.result_dir = tempfile.TemporaryDirectory() (self.gt_folder, self.gt_json, self.categories_json, - self.remap, self.size_map) = _prepare_ground_truth(split) + self.remap, self.size_map) = _prepare_ground_truth( + dataset, split, dataset_dir) def _compute_png_predictions(self, params): """Computes predictions and converts then to png to optimize memory use.""" @@ -170,8 +173,80 @@ def run(self, params): yield f'{k}_{m}', res[k][m] -def _prepare_ground_truth(split): - """Prepare ground truth.""" +def _prepare_ground_truth(dataset, split, data_dir): + """Prepare ground truth from tf.data.Dataset.""" + if dataset == 'coco/2017_panoptic' and data_dir is None: + return _prepare_ground_truth_from_zipfiles(split) + else: + return _prepare_ground_truth_from_dataset(dataset, split, data_dir) + + +@functools.lru_cache(maxsize=None) +def _prepare_ground_truth_from_dataset(dataset, split, data_dir): + """Prepare ground truth from a tf.data.Dataset.""" + dataset = tfds.builder(dataset, data_dir=data_dir).as_dataset(split=split) + + categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE) + with gfile.GFile(categories_json, 'rb') as f: + categories = json.loads(f.read()) + + # Build map from tfds class ids to COCO class ids. 
+ remap = {0: 0} + with gfile.GFile(categories_json, 'r') as f: + remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(categories)}} + + gt_folder = tempfile.mkdtemp() + gfile.makedirs(gt_folder) + size_map = {} + annotations = [] + images = [] + for example in dataset: + image_id = int(example['image/id']) + panoptic_image = example['panoptic_image'] + ann_ids = example['panoptic_objects']['id'] + ann_labels = example['panoptic_objects']['label'] + ann_iscrowd = example['panoptic_objects']['is_crowd'] + ann_area = example['panoptic_objects']['area'] + + fname = f'{image_id:012d}.png' + with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f: + f.write(tf.io.encode_png(panoptic_image).numpy()) + + size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1]) + + segments_info = [] + for i in range(len(ann_ids)): + segments_info.append({ + 'id': int(ann_ids[i]), + 'category_id': remap[int(ann_labels[i] + 1)], + 'iscrowd': int(ann_iscrowd[i]), + 'area': int(ann_area[i]), + }) + + annotations.append({ + 'file_name': str(fname), + 'image_id': int(image_id), + 'segments_info': segments_info + }) + images.append({ + 'id': image_id, + 'file_name': f'{image_id:012d}.jpg', + }) + + # Write annotations.json needed for pq_compute. + gt_json = os.path.join(gt_folder, 'annotations.json') + with gfile.GFile(gt_json, 'wb') as f: + f.write(json.dumps({ + 'images': images, + 'annotations': annotations, + 'categories': categories, + })) + + return gt_folder, gt_json, categories_json, remap, size_map + + +def _prepare_ground_truth_from_zipfiles(split): + """Prepare ground truth from coco zip files.""" split_prefix = split.split('[')[0] if split_prefix not in ('train', 'validation'): raise ValueError(f'Split {split} not supported') diff --git a/big_vision/evaluators/proj/uvim/compute_mean.py b/big_vision/evaluators/proj/uvim/compute_mean.py index 9d9993a..86bceac 100644 --- a/big_vision/evaluators/proj/uvim/compute_mean.py +++ b/big_vision/evaluators/proj/uvim/compute_mean.py @@ -17,6 +17,7 @@ from typing import Mapping from big_vision import input_pipeline +from big_vision.datasets import core as ds_core from big_vision.pp import builder as pp_builder import jax @@ -49,10 +50,12 @@ class Evaluator: per-example metrics of shape [batch_size]. """ - def __init__(self, predict_fn, dataset, split, pp_fn, batch_size, - data_dir=None, cache_final=True, cache_raw=False, prefetch=1): + def __init__(self, predict_fn, data, pp_fn, batch_size, + cache_final=True, cache_raw=False, prefetch=1): + data = ds_core.get(**data) self.dataset, self.steps = input_pipeline.make_for_inference( - dataset=dataset, split=split, data_dir=data_dir, batch_size=batch_size, + data.get_tfdata(ordered=True), batch_size=batch_size, + num_ex_per_process=data.num_examples_per_process(), preprocess_fn=pp_builder.get_preprocess_fn(pp_fn), cache_final=cache_final, cache_raw=cache_raw) self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch) diff --git a/big_vision/evaluators/proj/uvim/colorization.py b/big_vision/evaluators/proj/uvim/psnr.py similarity index 88% rename from big_vision/evaluators/proj/uvim/colorization.py rename to big_vision/evaluators/proj/uvim/psnr.py index a376f4c..dc9404b 100644 --- a/big_vision/evaluators/proj/uvim/colorization.py +++ b/big_vision/evaluators/proj/uvim/psnr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Evaluation for image colorization.""" +"""Compute PSNR, currently used for colorization and superresolution.""" import functools @@ -25,17 +25,18 @@ class Evaluator: - """Colorization evaluator. + """PSNR evaluator. `predict_fn` accepts arbitrary dictionaries of parameters and data, where the data dictionary is produced by the `pp_fn` op. It is expected to output a - dict with `color` containing an RGB image with intensities in [-1,1]. + single-key dict containing an RGB image with intensities in [-1,1]. """ def __init__(self, predict_fn, pp_fn, batch_size, + dataset="imagenet2012", split="validation", predict_kwargs=None): @@ -44,7 +45,9 @@ def predict(params, batch): def _f(x): y = predict_fn(params, x, **(predict_kwargs or {})) # Assume image intensities are in [-1,1]. - return _psnr(y["color"], x["labels"], 2.) + # Evaluator expects a dict with a single item. + pred, = y.values() + return _psnr(pred, x["labels"], 2.) return jax.lax.all_gather({ "mask": batch["mask"], "psnr": _f(batch["input"]), @@ -61,7 +64,7 @@ def preprocess(example): } self.data = common.get_jax_process_dataset( - "imagenet2012", + dataset, split, global_batch_size=batch_size, add_tfds_id=True, @@ -88,7 +91,7 @@ def run(self, params): if jax.process_index(): # Host0 gets all preds and does eval. return - yield "resizedPSNR", np.mean(psnrs) + yield "PSNR", np.mean(psnrs) @functools.partial(jax.vmap, in_axes=[0, 0, None]) diff --git a/big_vision/input_pipeline.py b/big_vision/input_pipeline.py index 2eeed81..6f871f1 100644 --- a/big_vision/input_pipeline.py +++ b/big_vision/input_pipeline.py @@ -13,62 +13,20 @@ # limitations under the License. """ImageNet input pipeline.""" -import functools import math + import einops import flax.jax_utils as flax_utils import jax - import tensorflow as tf -import tensorflow_datasets as tfds - - -@functools.lru_cache(maxsize=None) -def get_builder(dataset, data_dir): - return tfds.builder(dataset, data_dir=data_dir, try_gcs=True) - - -def get_num_examples(dataset, split, data_dir=None): - builder = get_builder(dataset, data_dir) - return builder.info.splits[split].num_examples - - -def get_max_examples_per_host(dataset, split, data_dir=None): - """Returns the max number of examples accross all hosts.""" - splits = tfds.even_splits(split, jax.process_count()) - return max([get_num_examples(dataset, s, data_dir) for s in splits]) - - -def get_dataset_tfds(dataset="imagenet2012", split="train", - shuffle_files=True, data_dir=None, skip_decode=("image",)): - """Data provider.""" - builder = get_builder(dataset, data_dir) - split = tfds.even_splits(split, jax.process_count())[jax.process_index()] - skip_decoders = { - f: tfds.decode.SkipDecoding() - for f in skip_decode - if f in builder.info.features - } - # Each host is responsible for a fixed subset of data - return builder.as_dataset( - split=split, - shuffle_files=shuffle_files, - read_config=tfds.ReadConfig( - skip_prefetch=True, # We prefetch after pipeline. - try_autocache=False, # We control this, esp. for few-shot. 
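For reference, the renamed evaluator passes a dynamic range of 2.0 to _psnr because predictions and labels are assumed to lie in [-1, 1]. The body of _psnr is not part of this diff; the snippet below is only the standard per-image PSNR definition it presumably implements, with an illustrative function name.

import jax.numpy as jnp

def psnr(pred, target, dynamic_range=2.0):
  # PSNR = 10 * log10(range^2 / MSE); images in [-1, 1] span a range of 2.0.
  mse = jnp.mean(jnp.square(pred - target))
  return 10.0 * jnp.log10(dynamic_range ** 2 / mse)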
- add_tfds_id=True, - ), - decoders=skip_decoders), builder def make_for_train( - dataset, split, preprocess_fn, batch_size, - shuffle_buffer_size, cache_raw=False, data_dir=None, filter_fn=None, + data, preprocess_fn, batch_size, + shuffle_buffer_size, cache_raw=False, filter_fn=None, num_parallel_calls=100, prefetch=2): """Makes an input pipeline for training.""" - data, _ = get_dataset_tfds(dataset=dataset, split=split, - shuffle_files=True, data_dir=data_dir) data = _add_tpu_host_options(data) # Use data filtering at your own risk: the actual split sizes won't be known @@ -92,23 +50,23 @@ def make_for_train( # Andreas Steiner and also implemented by him in the clu library: # https://github.com/google/CommonLoopUtils/blob/84b777c42dfd3fb6685537138433bfeb5241a006/clu/deterministic_data.py#L304. def make_for_inference( - dataset, split, preprocess_fn, batch_size, data_dir=None, + data, preprocess_fn, batch_size, num_ex_per_process, cache_raw=False, cache_final=False): """Makes an input pipeline for inference.""" - data, _ = get_dataset_tfds(dataset=dataset, split=split, - shuffle_files=False, data_dir=data_dir) + data = _add_tpu_host_options(data) data = data.cache() if cache_raw else data data = data.map(_add_mask(preprocess_fn), num_parallel_calls=100) data = data.concatenate(_get_pad_data(data)) + + local_batch_size = batch_size // jax.process_count() # Since we do 'infinite' padding it is safe to drop the remainder. - data = data.batch(batch_size // jax.process_count(), drop_remainder=True) + data = data.batch(local_batch_size, drop_remainder=True) # We need to make sure that all hosts process all data and exactly the same # number of batches. Below we take max per-host num examples and use it on all # hosts to derive the number of batches. - n = get_max_examples_per_host(dataset, split, data_dir) - num_batches = math.ceil(n / (batch_size // jax.process_count())) + num_batches = math.ceil(max(num_ex_per_process) / local_batch_size) data = data.take(num_batches) # Note we cache data after a finite number of batches is taken. @@ -147,5 +105,3 @@ def start_input_pipeline(data, n_prefetch, shard=True): if shard and n_prefetch: # Only works for pmap. 
it = flax_utils.prefetch_to_device(it, n_prefetch) return it - - diff --git a/big_vision/models/proj/flaxformer/bert_test.py b/big_vision/models/proj/flaxformer/bert_test.py index d4ac6f3..a81fbe5 100644 --- a/big_vision/models/proj/flaxformer/bert_test.py +++ b/big_vision/models/proj/flaxformer/bert_test.py @@ -15,7 +15,6 @@ """Tests for bert.""" import tempfile -from unittest import mock from big_vision import input_pipeline from big_vision.models.proj.flaxformer import bert @@ -44,19 +43,15 @@ class BertTest(tf.test.TestCase): - @mock.patch("tensorflow_datasets.builder") - def test_load_apply(self, mock_builder): + def test_load_apply(self): inkey = "text" - ds = tf.data.Dataset.from_tensor_slices( - {inkey: tf.ragged.constant([["this is a test"]])}) - mock_builder.return_value.as_dataset.return_value = ds - mock_builder.return_value.info.splits.__getitem__.return_value.num_examples = 1 vocab_path = f"{tempfile.mkdtemp()}/vocab.txt" with open(vocab_path, "w") as f: f.write("\n".join(_BERT_VOCAB)) ds2, _ = input_pipeline.make_for_inference( - dataset="mocked", - split="test", + tf.data.Dataset.from_tensor_slices( + {inkey: tf.ragged.constant([["this is a test"]])}), + num_ex_per_process=[1], preprocess_fn=pp_builder.get_preprocess_fn( f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', " f"max_len={_TOKEN_LEN})" diff --git a/big_vision/models/vit.py b/big_vision/models/vit.py index 83859c0..e978302 100644 --- a/big_vision/models/vit.py +++ b/big_vision/models/vit.py @@ -246,10 +246,10 @@ def decode_variant(variant): return { # pylint:disable=line-too-long # Reference: Table 2 of https://arxiv.org/abs/2106.04560. - "width": {"Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "H": 1280, "g": 1408, "G": 1664}[v], - "depth": {"Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "H": 32, "g": 40, "G": 48}[v], - "mlp_dim": {"Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "H": 5120, "g": 6144, "G": 8192}[v], - "num_heads": {"Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "H": 16, "g": 16, "G": 16}[v], + "width": {"Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "H": 1280, "g": 1408, "G": 1664, "e": 1792}[v], + "depth": {"Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "H": 32, "g": 40, "G": 48, "e": 56}[v], + "mlp_dim": {"Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "H": 5120, "g": 6144, "G": 8192, "e": 15360}[v], + "num_heads": {"Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "H": 16, "g": 16, "G": 16, "e": 16}[v], # pylint:enable=line-too-long **patch } @@ -314,29 +314,9 @@ def fix_old_checkpoints(params): def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name because we had to CamelCase above. """Load init from checkpoint, both old model and this one. 
+Hi-res posemb.""" - del model_cfg - # Shortcut names for some canonical paper checkpoints: - init_file = { - # pylint: disable=line-too-long - # pylint: disable=line-too-long - # Recommended models from https://arxiv.org/abs/2106.10270 - # Many more models at https://github.com/google-research/vision_transformer - "howto-i21k-Ti/16": "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", - "howto-i21k-S/32": "gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz", - "howto-i21k-S/16": "gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", - "howto-i21k-B/32": "gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz", - "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", - "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz", - "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz", - - # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580 - "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz", - "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz", - "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz", - # pylint: disable=line-too-long - # pylint: enable=line-too-long - }.get(init_file, init_file) + + init_file = VANITY_NAMES.get(init_file, init_file) restored_params = utils.load_params(None, init_file) restored_params = fix_old_checkpoints(restored_params) @@ -351,3 +331,26 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in new=init_params["pos_embedding"]) return restored_params + + +# Shortcut names for some canonical paper checkpoints: +VANITY_NAMES = { + # pylint: disable=line-too-long + # pylint: disable=line-too-long + # Recommended models from https://arxiv.org/abs/2106.10270 + # Many more models at https://github.com/google-research/vision_transformer + "howto-i21k-Ti/16": "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", + "howto-i21k-S/32": "gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-S/16": "gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", + "howto-i21k-B/32": "gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz", + "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz", + + # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580 + "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz", + "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz", + "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz", + # pylint: disable=line-too-long + # pylint: enable=line-too-long +} diff --git a/big_vision/pp/proj/flaxformer/bert_ops_test.py b/big_vision/pp/proj/flaxformer/bert_ops_test.py index 2ad2c94..620b8be 100644 --- a/big_vision/pp/proj/flaxformer/bert_ops_test.py +++ b/big_vision/pp/proj/flaxformer/bert_ops_test.py @@ -15,7 +15,6 @@ """Tests for bert_ops.""" import tempfile -from unittest import mock from big_vision import input_pipeline import big_vision.pp.builder as pp_builder @@ -36,14 +35,10 @@ ] -def 
_create_ds( - mock_builder, pp_str, tensor_slices, num_examples, remove_tpu_dtypes): - ds = tf.data.Dataset.from_tensor_slices(tensor_slices) - mock_builder.return_value.as_dataset.return_value = ds - mock_builder.return_value.info.splits.__getitem__.return_value.num_examples = num_examples +def _create_ds(pp_str, tensor_slices, num_examples, remove_tpu_dtypes): return input_pipeline.make_for_inference( - dataset="mocked", - split="test", + tf.data.Dataset.from_tensor_slices(tensor_slices), + num_ex_per_process=[num_examples], preprocess_fn=pp_builder.get_preprocess_fn( pp_str, remove_tpu_dtypes=remove_tpu_dtypes), batch_size=num_examples, @@ -52,8 +47,7 @@ def _create_ds( class BertOpsTest(tf.test.TestCase): - @mock.patch("tensorflow_datasets.builder") - def test_tokenize(self, mock_builder): + def test_tokenize(self): inkey = "texts" vocab_path = f"{tempfile.mkdtemp()}/vocab.txt" with open(vocab_path, "w") as f: @@ -65,7 +59,7 @@ def test_tokenize(self, mock_builder): tensor_slices = { inkey: tf.ragged.constant([["one more"], ["more than one"], [""]]) } - ds = _create_ds(mock_builder, pp_str, tensor_slices, 3, True) + ds = _create_ds(pp_str, tensor_slices, 3, True) self.assertAllEqual( next(iter(ds))["labels"], [[5, 4, 2, 0, 0], [5, 2, 3, 4, 0], [5, 0, 0, 0, 0]], diff --git a/big_vision/tools/eval_only.py b/big_vision/tools/eval_only.py index 01e934f..71d13e8 100644 --- a/big_vision/tools/eval_only.py +++ b/big_vision/tools/eval_only.py @@ -84,7 +84,8 @@ def init(rng): dummy_inputs = [jnp.zeros(s, t) for s, t in zip(input_shapes, input_types)] return flax.core.unfreeze(model.init(rng, *dummy_inputs))["params"] - params_cpu = init(jax.random.PRNGKey(42)) + with u.log_timing(mw, "z/secs/init"): + params_cpu = init(jax.random.PRNGKey(42)) if jax.process_index() == 0: parameter_overview.log_parameter_overview(params_cpu, msg="init params") @@ -112,15 +113,17 @@ def predict_fn(params, *a, **kw): for (name, evaluator, _, prefix) in evaluators: write_note(f"{name} evaluation...") with u.profile(name): - for key, value in evaluator.run(params_repl): - mw.measure(f"{prefix}{key}", value) + with u.log_timing(mw, f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + u.sync() # sync barrier to get correct measurements mw.step_end() write_note("Done!") mw.close() # Make sure all hosts stay up until the end of main. - u.sync_all_hosts() + u.sync() if workdir and flags.FLAGS.cleanup and jax.process_index() == 0: gfile.rmtree(workdir) diff --git a/big_vision/train.py b/big_vision/train.py index d9fab58..7d85537 100644 --- a/big_vision/train.py +++ b/big_vision/train.py @@ -25,6 +25,7 @@ from absl import app from absl import flags from absl import logging +import big_vision.datasets.core as ds_core import big_vision.evaluators.common as eval_common import big_vision.input_pipeline as input_pipeline import big_vision.optax as bv_optax @@ -38,7 +39,8 @@ import numpy as np import optax import tensorflow as tf -import tensorflow.io.gfile as gfile + +from tensorflow.io import gfile # pylint: disable=logging-fstring-interpolation @@ -65,10 +67,8 @@ def main(argv): f"{jax.local_device_count()}/{jax.device_count()} devices and " f"writing to workdir {workdir}.\u001b[0m") - assert not config.get("grad_accum_steps"), "Grad-acc not supported anymore." - save_ckpt_path = None - if workdir and (config.get("ckpt_steps") or config.get("keep_ckpt_steps")): + if workdir: # Always create if requested, even if we may not write into it. 
gfile.makedirs(workdir) save_ckpt_path = os.path.join(workdir, "checkpoint.npz") @@ -76,7 +76,7 @@ def main(argv): pool = multiprocessing.pool.ThreadPool() # Here we register preprocessing ops from modules listed on `pp_modules`. - for m in config.get("pp_modules", ["ops_general", "ops_image"]): + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): importlib.import_module(f"big_vision.pp.{m}") # This seed makes the Jax part of things (like model init) deterministic. @@ -95,7 +95,9 @@ def write_note(note): if jax.process_index() == 0: info("%s", note) - batch_size = config.batch_size + write_note("Initializing...") + + batch_size = config.input.batch_size if batch_size % jax.device_count() != 0: raise ValueError(f"Batch size ({batch_size}) must " f"be divisible by device number ({jax.device_count()})") @@ -110,21 +112,20 @@ def write_note(note): chrono = u.Chrono() write_note("Initializing train dataset...") + train_data = ds_core.get(**config.input.data) train_ds = input_pipeline.make_for_train( - dataset=config.dataset, - split=config.train_split, - batch_size=config.batch_size, - preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train), - shuffle_buffer_size=config.get("shuffle_buffer_size"), - cache_raw=config.get("cache_raw", False), - data_dir=fillin(config.get("dataset_dir"))) - + data=train_data.get_tfdata(ordered=False), + batch_size=batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")), + shuffle_buffer_size=config.input.get("shuffle_buffer_size"), + cache_raw=config.input.get("cache_raw", False), + filter_fn=config.input.get("filter_fn"), + ) + + # Start prefetching already. n_prefetch = config.get("prefetch_to_device", 1) train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) - - ntrain_img = input_pipeline.get_num_examples( - config.dataset, config.train_split, - data_dir=fillin(config.get("dataset_dir"))) + ntrain_img = train_data.total_examples def get_steps(name, default=ValueError): # partial doesn't work well here. return u.steps(name, config, ntrain_img, batch_size, default) @@ -144,7 +145,7 @@ def get_steps(name, default=ValueError): # partial doesn't work well here. @partial(jax.jit, backend="cpu") def init(rng): shape = tuple(train_ds.element_spec["image"].shape[1:]) - bs = config.batch_size // jax.device_count() + bs = batch_size // jax.device_count() dummy_input = jnp.zeros((bs,) + shape, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input))["params"] @@ -156,7 +157,8 @@ def init(rng): return params rng, rng_init = jax.random.split(rng) - params_cpu = init(rng_init) + with u.log_timing(mw, "z/secs/init"): + params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_leaves(params_cpu)) @@ -186,7 +188,7 @@ def update_fn(params, opt, rng, images, labels): def loss_fn(params, images, labels): logits, _ = model.apply( - {"params": flax.core.freeze(params)}, images, + {"params": params}, images, train=True, rngs={"dropout": rng_model_local}) return getattr(u, config.get("loss", "sigmoid_xent"))( logits=logits, labels=labels) @@ -253,9 +255,9 @@ def predict_fn(params, image): params_repl = flax.jax_utils.replicate(params_cpu) opt_repl = flax.jax_utils.replicate(opt_cpu) - evaluators = eval_common.from_config( - config, {"predict": predict_fn}, - lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}")) + # Initializing evaluators later when they are first needed, so we can see + # issues with training faster. 
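Two bits of arithmetic in this trainer are easy to gloss over: the global batch must divide evenly across all devices, and *_epochs settings are resolved into step counts from the dataset size and global batch size (u.steps receives both, replacing the inline epoch-to-step computation removed elsewhere in this diff). A purely illustrative walk-through:

batch_size = 4096                      # config.input.batch_size (global)
num_hosts, devs_per_host = 8, 8        # illustrative topology
num_devices = num_hosts * devs_per_host
assert batch_size % num_devices == 0   # otherwise the trainer raises ValueError
per_host_batch = batch_size // num_hosts        # 512 examples per process
per_device_batch = batch_size // num_devices    # 64 examples per device

ntrain_img = 1_200_000                 # train_data.total_examples (illustrative)
total_epochs = 90
total_steps = int(total_epochs * ntrain_img / batch_size)  # 26_367 steps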
+ evaluators = None rng, rng_loop = jax.random.split(rng, 2) rngs_loop = flax.jax_utils.replicate(rng_loop) @@ -263,14 +265,16 @@ def predict_fn(params, image): write_note(f"First step compilations...\n{chrono.note}") error = None # For exiting with an error after cleanup. Avoids indentation. + # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): mw.step_start(step) with jax.profiler.StepTraceAnnotation("train_step", step_num=step): - params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( - params_repl, opt_repl, rngs_loop, batch["image"], batch["labels"]) + with u.log_timing(mw, "z/secs/update0", noop=step > first_step + 1): + params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( + params_repl, opt_repl, rngs_loop, batch["image"], batch["labels"]) # On the first host, let's always profile a handful of early steps. if jax.process_index() == 0: @@ -312,12 +316,18 @@ def predict_fn(params, image): u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) chrono.resume() + if evaluators is None: + evaluators = eval_common.from_config( + config, {"predict": predict_fn}, + lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}") + ) for (name, evaluator, log_steps, prefix) in evaluators: if u.itstime(step, log_steps, total_steps): chrono.pause(wait_for=params_repl) write_note(f"{name} evaluation...\n{chrono.note}") - for key, value in evaluator.run(params_repl): - mw.measure(f"{prefix}{key}", value) + with u.log_timing(mw, f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) chrono.resume() mw.step_end() @@ -337,7 +347,7 @@ def predict_fn(params, image): mw.close() # Make sure all hosts stay up until the end of main. - u.sync_all_hosts() + u.sync() # Before cleanup, as cleanup should only run for successful jobs. if error is not None: diff --git a/big_vision/trainers/proj/distill/distill.py b/big_vision/trainers/proj/distill/distill.py index 4f59abd..6fecb7e 100644 --- a/big_vision/trainers/proj/distill/distill.py +++ b/big_vision/trainers/proj/distill/distill.py @@ -34,9 +34,10 @@ from absl import app from absl import flags from absl import logging -from big_vision import input_pipeline +import big_vision.datasets.core as ds_core import big_vision.evaluators.common as eval_common import big_vision.evaluators.proj.distill.distance as dd +import big_vision.input_pipeline as input_pipeline import big_vision.optax as bv_optax import big_vision.pp.builder as pp_builder import big_vision.utils as u @@ -48,7 +49,7 @@ import numpy as np import optax import tensorflow as tf -import tensorflow.io.gfile as gfile +from tensorflow.io import gfile # pylint: disable=logging-fstring-interpolation @@ -82,10 +83,8 @@ def main(argv): f"{jax.local_device_count()}/{jax.device_count()} devices and " f"writing to workdir {workdir}.\u001b[0m") - assert not config.get("grad_accum_steps"), "Grad-acc not supported anymore." - save_ckpt_path = None - if workdir and config.get("ckpt_steps"): + if workdir: # Always create if requested, even if we may not write into it. gfile.makedirs(workdir) save_ckpt_path = os.path.join(workdir, "checkpoint.npz") @@ -93,7 +92,7 @@ def main(argv): pool = multiprocessing.pool.ThreadPool() # Here we register preprocessing ops from modules listed on `pp_modules`. 
- for m in config.get("pp_modules", ["ops_general", "ops_image"]): + for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]): importlib.import_module(f"big_vision.pp.{m}") # This seed makes the Jax part of things (like model init) deterministic. @@ -112,14 +111,9 @@ def write_note(note): if jax.process_index() == 0: info("%s", note) - # Verify settings to make sure no checkpoints are accidentally missed. - if config.get("keep_ckpt_steps"): - assert config.get("ckpt_steps"), "Specify `ckpt_steps`." - assert config.keep_ckpt_steps % config.ckpt_steps == 0, ( - f"`keep_ckpt_steps` ({config.ckpt_steps}) should be" - f"divisible by `ckpt_steps ({config.ckpt_steps}).`") + write_note("Initializing...") - batch_size = config.batch_size + batch_size = config.input.batch_size if batch_size % jax.device_count() != 0: raise ValueError(f"Batch size ({batch_size}) must " f"be divisible by device number ({jax.device_count()})") @@ -134,31 +128,27 @@ def write_note(note): chrono = u.Chrono() write_note("Initializing train dataset...") + train_data = ds_core.get(**config.input.data) train_ds = input_pipeline.make_for_train( - dataset=config.dataset, - split=config.train_split, - batch_size=config.batch_size, - preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train), - shuffle_buffer_size=config.get("shuffle_buffer_size"), - cache_raw=config.get("cache_raw", False), - data_dir=fillin(config.get("dataset_dir"))) - + data=train_data.get_tfdata(ordered=False), + batch_size=batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")), + shuffle_buffer_size=config.input.get("shuffle_buffer_size"), + cache_raw=config.input.get("cache_raw", False), + filter_fn=config.input.get("filter_fn"), + ) + + # Start prefetching already. n_prefetch = config.get("prefetch_to_device", 1) train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) + ntrain_img = train_data.total_examples - ntrain_img = input_pipeline.get_num_examples( - config.dataset, config.train_split, - data_dir=fillin(config.get("dataset_dir"))) - steps_per_epoch = ntrain_img / batch_size + def get_steps(name, default=ValueError): # partial doesn't work well here. + return u.steps(name, config, ntrain_img, batch_size, default) + total_steps = get_steps("total") - if config.get("total_epochs"): - total_steps = int(config.total_epochs * steps_per_epoch) - assert not config.get("total_steps"), "Set only one of total_(epochs|steps)" - else: - total_steps = config.total_steps - - info("Running for %d steps, that means %f epochs and %f steps per epoch", - total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) # Create student and teacher models def get_model_mod(name): # Used many times. @@ -182,14 +172,14 @@ def get_init(model): @partial(jax.jit, backend="cpu") def _init(rng): shape = tuple(train_ds.element_spec["image"].shape[1:]) - bs = config.batch_size // jax.device_count() + bs = batch_size // jax.device_count() dummy_input = jnp.zeros((bs,) + shape, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input))["params"] # Set bias in the head to a low value, such that loss is small initially. 
if "init_head_bias" in config: - params["head"]["bias"] = jnp.full_like( - params["head"]["bias"], config["init_head_bias"]) + params["head"]["bias"] = jnp.full_like(params["head"]["bias"], + config["init_head_bias"]) return params return _init @@ -197,7 +187,6 @@ def _init(rng): params_cpu = {name: get_init(models[name])(rngi) for name, rngi in zip(models, rng_inits)} - # Log all parameters to helping debugging. if jax.process_index() == 0: for name, params in params_cpu.items(): parameter_overview.log_parameter_overview(params, msg=f"{name} params") @@ -210,9 +199,7 @@ def _init(rng): # or similar for good (we explored but ditched), need to refactor this a bit. tx, sched_fns = bv_optax.make( config, params_cpu["student"], sched_kw=dict( - global_batch_size=batch_size, - total_steps=total_steps, - steps_per_epoch=steps_per_epoch)) + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) # We jit this, such that the arrays are created on the CPU, not device[0]. opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu["student"]) @@ -288,16 +275,6 @@ def update_fn(params, opt, rng, data): return params, opt, rng, l, measurements - # TODO: implement a distillation evaluator. - # @partial(jax.pmap, axis_name="batch") - # def evaluation_distill_fn(params, nparams, data): - # mask = data["mask"] - # _, extra_losses = loss_fn(params, nparams, data, reduce=False) - # losses = extra_losses["distill_loss"] - # loss = jax.lax.psum(losses * mask, axis_name="batch") - # n = jax.lax.psum(mask, axis_name="batch") - # return loss, n - # We always load the teachers first, because they NEED to be initialized # and since we don't ever modify them, we don't store them in checkpoints. for name in config.teachers: @@ -339,13 +316,17 @@ def update_fn(params, opt, rng, data): write_note("Kicking off misc stuff...") first_step = bv_optax.get_count(opt_cpu) - chrono.inform(first_step, total_steps, batch_size, steps_per_epoch) + chrono.inform(first_step, total_steps, batch_size, ntrain_img / batch_size) prof = None # Keeps track of start/stop of profiler state. write_note(f"Replicating...\n{chrono.note}") params_repl = flax.jax_utils.replicate(params_cpu) opt_repl = flax.jax_utils.replicate(opt_cpu) + # Initializing evaluators later when they are first needed, so we can see + # issues with training faster. + evaluators = None + # Define predict functions that the evaluators can use: # 1. One per model predict_fns = {} @@ -369,32 +350,28 @@ def fwd(params, image, n=name): # pylint: disable=function-redefined return student_ret, teacher_ret predict_fns[f"student_{name}_fwd"] = fwd - evaluators = eval_common.from_config( - config, predict_fns, - lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}")) - rng, rng_loop = jax.random.split(rng, 2) rngs_loop = flax.jax_utils.replicate(rng_loop) ckpt_writer = None write_note(f"First step compilations...\n{chrono.note}") error = None # For exiting with an error after cleanup. Avoids indentation. + # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. 
- for step, train_batch in zip( - range(first_step + 1, total_steps + 1), train_iter): + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): mw.step_start(step) with jax.profiler.StepTraceAnnotation("train_step", step_num=step): params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( - params_repl, opt_repl, rngs_loop, train_batch) + params_repl, opt_repl, rngs_loop, batch) # On the first host, let's always profile a handful of early steps. if jax.process_index() == 0: - prof = u.startstop_prof(prof, step, first_step, config.log_training_steps) + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) # Report training progress - if (u.itstime(step, config.log_training_steps, total_steps, host=0) + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) or chrono.warmup and jax.process_index() == 0): for i, sched_fn_cpu in enumerate(sched_fns_cpu): mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) @@ -404,12 +381,13 @@ def fwd(params, image, n=name): # pylint: disable=function-redefined chrono.tick(step, mw.measure, write_note) if not np.isfinite(l): error = (f"The loss became nan or inf somewhere within steps " - f"[{step - config.log_training_steps}, {step}]") + f"[{step - get_steps('log_training')}, {step}]") break # Checkpoint saving if (save_ckpt_path and - u.itstime(step, config.get("ckpt_steps"), total_steps, host=0)): + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): chrono.pause(wait_for=(params_repl["student"], opt_repl)) u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) # We need to transfer the weights over now or else we risk keeping them @@ -420,7 +398,7 @@ def fwd(params, image, n=name): # pylint: disable=function-redefined # Check whether we want to keep a copy of the current checkpoint. copy_step = None - if u.itstime(step, config.get("keep_ckpt_steps"), total_steps): + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): copy_step = step ckpt = {"params": params_cpu["student"], @@ -430,18 +408,11 @@ def fwd(params, image, n=name): # pylint: disable=function-redefined u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) chrono.resume() - # TODO: Evaluator to compute distillation loss/distance on val. - # if u.itstime(step, config.log_eval_steps, total_steps): - # chrono.pause() - # for val_name, (val_iter, val_steps) in val_ds.items(): - # loss, nseen = 0, 0 - # for _, batch in zip(range(val_steps), val_iter): - # batch_losses, batch_n = evaluation_distill_fn( - # opt_repl.target, nopt_repl, batch) - # loss += np.sum(np.array(batch_losses[0])) - # nseen += np.sum(np.array(batch_n[0])) - # mw.measure(f"{val_name}_distill_loss", loss / nseen) - + if evaluators is None: + evaluators = eval_common.from_config( + config, predict_fns, + lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}") + ) for (name, evaluator, log_steps, prefix) in evaluators: if u.itstime(step, log_steps, total_steps): chrono.pause(wait_for=params_repl) @@ -467,7 +438,7 @@ def fwd(params, image, n=name): # pylint: disable=function-redefined mw.close() # Make sure all hosts stay up until the end of main. - u.sync_all_hosts() + u.sync() # Before cleanup, as cleanup should only run for successful jobs. 
if error is not None: diff --git a/big_vision/trainers/proj/image_text/contrastive.py b/big_vision/trainers/proj/image_text/contrastive.py index 234baf1..0188f99 100644 --- a/big_vision/trainers/proj/image_text/contrastive.py +++ b/big_vision/trainers/proj/image_text/contrastive.py @@ -18,7 +18,7 @@ - LiT (https://arxiv.org/abs/2111.07991) - CLIP (https://arxiv.org/abs/2103.00020) """ - +# pylint: disable=consider-using-from-import from functools import partial import importlib import multiprocessing.pool @@ -27,8 +27,9 @@ from absl import app from absl import flags from absl import logging -from big_vision import input_pipeline +import big_vision.datasets.core as ds_core import big_vision.evaluators.common as eval_common +import big_vision.input_pipeline as input_pipeline import big_vision.optax as bv_optax import big_vision.pp.builder as pp_builder import big_vision.utils as u @@ -39,9 +40,12 @@ from ml_collections import config_flags import numpy as np import optax +import tensorflow as tf from tensorflow.io import gfile +# pylint: disable=logging-fstring-interpolation + config_flags.DEFINE_config_file( "config", None, "Training configuration.", lock_config=True) @@ -62,6 +66,7 @@ def all_gather(z): def main(argv): del argv + tf.config.experimental.set_visible_devices([], "GPU") config = flags.FLAGS.config workdir = flags.FLAGS.workdir @@ -70,10 +75,8 @@ def main(argv): f"{jax.local_device_count()}/{jax.device_count()} devices and " f"writing to workdir {workdir}.\u001b[0m") - assert not config.get("grad_accum_steps"), "Grad-acc not supported anymore." - save_ckpt_path = None - if workdir and config.get("ckpt_steps"): + if workdir: # Always create if requested, even if we may not write into it. gfile.makedirs(workdir) save_ckpt_path = os.path.join(workdir, "checkpoint.npz") @@ -101,61 +104,42 @@ def write_note(note): write_note("Initializing...") - # Verify settings to make sure no checkpoints are accidentally missed. - if config.get("keep_ckpt_steps"): - assert config.get("ckpt_steps"), "Specify `ckpt_steps`." - assert config.keep_ckpt_steps % config.ckpt_steps == 0, ( - f"`keep_ckpt_steps` ({config.ckpt_steps}) should be" - f"divisible by `ckpt_steps ({config.ckpt_steps}).`") - - batch_size = config.batch_size - batch_size_eval = config.get("batch_size_eval", batch_size) - if (batch_size % jax.device_count() != 0 or - batch_size_eval % jax.device_count() != 0): - raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must " + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) - local_batch_size = batch_size // jax.process_count() - info( - "Global batch size %d on %d hosts results in %d local batch size. " - "With %d dev per host (%d dev total), that's a %d per-device batch size.", - batch_size, jax.process_count(), local_batch_size, - jax.local_device_count(), jax.device_count(), - local_batch_size // jax.local_device_count()) + # First thing after above sanity checks, so we can log "start" ticks. 
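# Illustration (not part of this change): what the
# `tf.config.experimental.set_visible_devices([], "GPU")` call added above
# accomplishes. It hides every GPU from TensorFlow so the tf.data input
# pipeline cannot reserve accelerator memory that JAX needs, and it must run
# before any tf.data dataset is built.
import tensorflow as tf

tf.config.experimental.set_visible_devices([], "GPU")
assert not tf.config.get_visible_devices("GPU")  # TF now runs on CPU only.

ds = tf.data.Dataset.range(8).batch(4)  # input pipeline stays on the CPU
for batch in ds:
  print(batch.numpy())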
+ mw = u.BigVisionMetricWriter(xid, wid, workdir, config) + chrono = u.Chrono() write_note("Initializing training pipeline...") + train_data = ds_core.get(**config.input.data) train_ds = input_pipeline.make_for_train( - dataset=config.dataset, - split=config.train_split, - batch_size=config.batch_size, - preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train), - shuffle_buffer_size=config.get("shuffle_buffer_size"), - cache_raw=config.get("cache_raw", False), - data_dir=config.get("dataset_dir"), - filter_fn=config.get("filter_fn"), + data=train_data.get_tfdata(ordered=False), + batch_size=batch_size, + preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")), + shuffle_buffer_size=config.input.get("shuffle_buffer_size"), + cache_raw=config.input.get("cache_raw", False), + filter_fn=config.input.get("filter_fn"), ) # Start prefetching already. n_prefetch = config.get("prefetch_to_device", 1) train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) - # We always pad to local_batch_size_eval even when less would be enough in - # order to minimize memory fragmentation. + ntrain_img = train_data.total_examples - ntrain_img = input_pipeline.get_num_examples( - config.dataset, config.train_split, data_dir=config.get("dataset_dir")) - steps_per_epoch = ntrain_img / batch_size + def get_steps(name, default=ValueError): # `partial` doesn't work well here. + return u.steps(name, config, ntrain_img, batch_size, default) + total_steps = get_steps("total") - if config.get("total_epochs"): - total_steps = int(config.total_epochs * steps_per_epoch) - assert not config.get("total_steps"), "Set only one of total_(epochs|steps)" - else: - total_steps = config.total_steps - - info( - "Running for %d steps, that means %f epochs and %f steps per epoch", - total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) - mw = u.BigVisionMetricWriter(xid, wid, workdir, config) - chrono = u.Chrono() + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) write_note(f"Initializing {config.model_name} model...") model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") @@ -166,33 +150,33 @@ def write_note(note): # situations where we allocate them twice. 
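# Illustration (not part of this change): a rough approximation -- not the
# actual big_vision.utils.steps -- of what `get_steps("total")` above
# resolves. A duration can be given in the config as `<name>_steps`,
# `<name>_examples` or `<name>_epochs` and is turned into a number of
# optimizer steps using the dataset size and the global batch size; exact
# rounding and error handling may differ in the real helper.
def steps_sketch(name, config, data_size, batch_size, default=ValueError):
  if f"{name}_steps" in config:
    return config[f"{name}_steps"]
  if f"{name}_examples" in config:
    return config[f"{name}_examples"] // batch_size
  if f"{name}_epochs" in config:
    return int(config[f"{name}_epochs"] * data_size / batch_size)
  if default is ValueError:
    raise ValueError(f"Duration '{name}' not specified in config.")
  return default

# With ~1.28M ImageNet training examples and a global batch of 4096,
# total_epochs=90 resolves to roughly 28k steps.
print(steps_sketch("total", {"total_epochs": 90}, 1_281_167, 4096))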
@partial(jax.jit, backend="cpu") def init(rng): + bs = batch_size // jax.device_count() image_size = tuple(train_ds.element_spec["image"].shape[1:]) - no_image = jnp.zeros((local_batch_size,) + image_size, jnp.float32) + no_image = jnp.zeros((bs,) + image_size, jnp.float32) text_size = tuple(train_ds.element_spec["labels"].shape[1:]) - no_text = jnp.zeros((local_batch_size,) + text_size, jnp.int32) + no_text = jnp.zeros((bs,) + text_size, jnp.int32) params = flax.core.unfreeze(model.init(rng, no_image, no_text))["params"] return params rng, rng_init = jax.random.split(rng) - params_cpu = init(rng_init) + with u.log_timing(mw, "z/secs/init"): + params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_leaves(params_cpu)) - parameter_overview.log_parameter_overview(params_cpu) + parameter_overview.log_parameter_overview(params_cpu, msg="init params") mw.measure("num_params", num_params) - tx, sched_fns = bv_optax.make( - config, - params_cpu, - sched_kw=dict( - global_batch_size=batch_size, - total_steps=total_steps, - steps_per_epoch=steps_per_epoch)) + write_note(f"Initializing {config.optax_name} optimizer...") + tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) + + # We jit this, such that the arrays are created on the CPU, not device[0]. opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] - @partial(jax.pmap, axis_name="batch", donate_argnums=(0,)) - def update_fn(params, opt, batch, rng): + @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1)) + def update_fn(params, opt, rng, batch): """Update step.""" assert "mixup" not in config, "We still have to figure out mixup." @@ -233,9 +217,12 @@ def loss_fn(params, images, labels): us = jax.tree_leaves(updates) measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) - return params, opt, l, rng, measurements + return params, opt, rng, l, measurements # We require hashable function reference for evaluator. + # We do not jit/pmap this function, because it is passed to evaluator that + # does it later. We output as many intermediate tensors as possible for + # maximal flexibility. Later `jit` will prune out things that are not needed. def predict_fn(params, image=None, text=None, **unused_kwargs): del unused_kwargs # `unused_kwargs` is to be compatible with few-shot zimg, ztxt, out = model.apply({"params": params}, image, text) @@ -266,23 +253,23 @@ def predict_fn(params, image=None, text=None, **unused_kwargs): chrono.load(checkpoint["chrono"]) elif config.get("model_init"): write_note(f"Initialize model from {config.model_init}...") - params_cpu = model_mod.load(params_cpu, config.model_init, - config.get("model"), - **config.get("model_load", {})) + params_cpu = model_mod.load( + params_cpu, config.model_init, config.get("model"), + **config.get("model_load", {})) if jax.process_index() == 0: - info("Restored parameter overview:") - parameter_overview.log_parameter_overview(params_cpu) + parameter_overview.log_parameter_overview( + params_cpu, msg="restored params") write_note("Kicking off misc stuff...") first_step = bv_optax.get_count(opt_cpu) - chrono.inform(first_step, total_steps, batch_size, steps_per_epoch) + chrono.inform(first_step, total_steps, batch_size, ntrain_img / batch_size) prof = None # Keeps track of start/stop of profiler state. 
write_note(f"Replicating...\n{chrono.note}") params_repl = flax.jax_utils.replicate(params_cpu) opt_repl = flax.jax_utils.replicate(opt_cpu) - # Initialise evaluators later when they are first needed, so we can see + # Initializing evaluators later when they are first needed, so we can see # issues with training faster. evaluators = None @@ -295,70 +282,67 @@ def predict_fn(params, image=None, text=None, **unused_kwargs): # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. - for step, train_batch in zip( - range(first_step + 1, total_steps + 1), train_iter): + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): mw.step_start(step) with jax.profiler.StepTraceAnnotation("train_step", step_num=step): - params_repl, opt_repl, loss_value, rngs_loop, extra_measurements = ( - update_fn(params_repl, opt_repl, train_batch, rng=rngs_loop)) + with u.log_timing(mw, "z/secs/update0", noop=step > first_step + 1): + params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( + params_repl, opt_repl, rngs_loop, batch) # On the first host, let's always profile a handful of early steps. if jax.process_index() == 0: - prof = u.startstop_prof(prof, step, first_step, config.log_training_steps) + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + chrono.tick(step, mw.measure, write_note) + if not np.isfinite(l): + error = (f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + break # Checkpoint saving if (save_ckpt_path and - u.itstime(step, config.get("ckpt_steps"), total_steps, host=0)): - chrono.pause() + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): + chrono.pause(wait_for=(params_repl, opt_repl)) u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see (internal link)). Also, takes device 0's params only. - opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl) + opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None - if u.itstime(step, config.get("keep_ckpt_steps"), total_steps): + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): copy_step = step - # Checkpoint should be a nested dictionary or FLAX datataclasses from - # `flax.struct`. Both can be present in a checkpoint. 
- checkpoint = { - "params": params_cpu, - "opt": opt_cpu, - "chrono": chrono.save(), - } + ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": chrono.save()} ckpt_writer = pool.apply_async( - u.save_checkpoint, (checkpoint, save_ckpt_path, copy_step)) + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) chrono.resume() - # Report training progress - if u.itstime(step, config.log_training_steps, total_steps, host=0): - for i, sched_fn_cpu in enumerate(sched_fns_cpu): - mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) - l = mw.measure("training_loss", loss_value[0]) - for name, value in extra_measurements.items(): - mw.measure(name, value[0]) - chrono.tick(step, mw.measure, write_note) - if not np.isfinite(l): - error = (f"The loss became nan or inf somewhere within steps " - f"[{step - config.log_training_steps}, {step}]") - break - - chrono.pause() if evaluators is None: evaluators = eval_common.from_config( config, {"predict": predict_fn}, lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}") ) - for (name, evaluator, freq, prefix) in evaluators: - if u.itstime(step, freq, total_steps): + for (name, evaluator, log_steps, prefix) in evaluators: + if u.itstime(step, log_steps, total_steps): + chrono.pause(wait_for=params_repl) write_note(f"{name} evaluation...\n{chrono.note}") - for key, value in evaluator.run(params_repl): - mw.measure(f"{prefix}{key}", value) - chrono.resume() + with u.log_timing(mw, f"z/secs/eval/{name}"): + for key, value in evaluator.run(params_repl): + mw.measure(f"{prefix}{key}", value) + chrono.resume() mw.step_end() # Always give a chance to stop the profiler, no matter how things ended. @@ -377,13 +361,13 @@ def predict_fn(params, image=None, text=None, **unused_kwargs): mw.close() # Make sure all hosts stay up until the end of main. - u.sync_all_hosts() + u.sync() # Before cleanup, as cleanup should only run for successful jobs. if error is not None: raise RuntimeError(error) - u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, logging.info) + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) if __name__ == "__main__": diff --git a/big_vision/trainers/proj/uvim/train.py b/big_vision/trainers/proj/uvim/train.py index e02af25..bbaec20 100644 --- a/big_vision/trainers/proj/uvim/train.py +++ b/big_vision/trainers/proj/uvim/train.py @@ -14,7 +14,7 @@ """Train loop for training the stage-II model.""" # pylint: disable=consider-using-from-import -from functools import partial +import functools import importlib import multiprocessing.pool import os @@ -23,6 +23,7 @@ from absl import flags from absl import logging from big_vision import input_pipeline +import big_vision.datasets.core as ds_core import big_vision.evaluators.common as eval_common import big_vision.models.proj.uvim.decode as decode import big_vision.optax as bv_optax @@ -35,6 +36,7 @@ from ml_collections import config_flags import numpy as np import optax + import tensorflow.io.gfile as gfile @@ -51,6 +53,7 @@ FLAGS = flags.FLAGS ONE_HOT_AXIS = -2 +partial = functools.partial def get_model(config): @@ -100,11 +103,8 @@ def main(argv): "writing to workdir %s.\u001b[0m", jax.process_index(), jax.local_device_count(), jax.device_count(), workdir) - assert not config.get("grad_accum_steps"), ( - "Gradient accumulation not supported anymore.") - save_ckpt_path = None - if workdir and config.get("ckpt_steps"): + if workdir: # Always create if requested, even if we may not write into it. 
gfile.makedirs(workdir) save_ckpt_path = os.path.join(workdir, "checkpoint.npz") @@ -132,48 +132,46 @@ def write_note(note): if jax.process_index() == 0: info("%s", note) - # Verify settings to make sure no checkpoints are accidentally missed. - if config.get("keep_ckpt_steps"): - assert config.get("ckpt_steps"), "Specify `ckpt_steps`." - assert config.keep_ckpt_steps % config.ckpt_steps == 0, ( - f"`keep_ckpt_steps` ({config.ckpt_steps}) should be" - f"divisible by `ckpt_steps ({config.ckpt_steps}).`") + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) # First thing after above sanity checks, so we can log "start" ticks. - mw = u.BigVisionMetricWriter(xid, wid, workdir) + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) chrono = u.Chrono() write_note("Initializing train dataset...") - batch_size = config.batch_size + train_data = ds_core.get(**config.input.data) train_ds = input_pipeline.make_for_train( - dataset=config.dataset, - split=config.train_split, + data=train_data.get_tfdata(ordered=False), batch_size=batch_size, - preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train), - shuffle_buffer_size=config.get("shuffle_buffer_size"), - cache_raw=config.get("cache_raw", False), - data_dir=config.get("dataset_dir"), + preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")), + shuffle_buffer_size=config.input.get("shuffle_buffer_size"), + cache_raw=config.input.get("cache_raw", False), + filter_fn=config.input.get("filter_fn"), ) - ntrain_examples = input_pipeline.get_num_examples( - config.dataset, config.train_split, - data_dir=config.get("dataset_dir")) # Start prefetching already. n_prefetch = config.get("prefetch_to_device", 1) train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) - steps_per_epoch = ntrain_examples / batch_size + ntrain_img = train_data.total_examples - if config.get("total_epochs"): - total_steps = int(config.total_epochs * steps_per_epoch) - assert not config.get("total_steps"), "Set only one of total_(epochs|steps)" - else: - total_steps = config.get("total_steps", 0) + def get_steps(name, default=ValueError): # partial doesn't work well here. + return u.steps(name, config, ntrain_img, batch_size, default) + total_steps = get_steps("total") - logging.info( - "Running for %d steps, that means %f epochs and %f steps per epoch", - total_steps, total_steps * batch_size / ntrain_examples, steps_per_epoch) + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) - write_note("Initializing model...") + write_note(f"Initializing {config.model_name} model...") model, model_mod = get_model(config) encode_labels, decode_labels, predict_outputs_fn, task_params = ( @@ -201,15 +199,9 @@ def init(rng): parameter_overview.log_parameter_overview(params_cpu, msg="init params") mw.measure("num_params", num_params) - write_note("Initializing optimizer...") - # Load the optimizer either from our folder or from flax. 
- tx, sched_fns = bv_optax.make( - config, - params_cpu, - sched_kw=dict( - global_batch_size=batch_size, - total_steps=total_steps, - steps_per_epoch=steps_per_epoch)) + write_note(f"Initializing {config.optax_name} optimizer...") + tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) # We jit this, such that the arrays are created on the CPU, not device[0]. opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) @@ -290,9 +282,13 @@ def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): logits = decode_labels(task_params, seqs, batch) return predict_outputs_fn(logits, **extra) - evaluators = eval_common.from_config( - config, {"predict": predict_fn, "validation": validation_fn}, - lambda s: write_note(f"Initializing evaluator: {s}")) + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, {"predict": predict_fn, "validation": validation_fn}, + lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}") + ) # Decide how to initialize training. The order is important. # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. @@ -316,21 +312,20 @@ def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): loaded = u.load_checkpoint(checkpoint_tree, resume_ckpt_path) # bfloat16 type gets lost when data is saved to disk, so we recover it. checkpoint = jax.tree_map(u.recover_dtype, loaded) - params_cpu = checkpoint["params"] - opt_cpu = checkpoint["opt"] + params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"] chrono.load(checkpoint["chrono"]) - else: - if config.get("model_init"): - write_note(f"Initialize model from {config.model_init}...") - params_cpu = model_mod.load(params_cpu, config.model_init, config.model, - **config.get("model_init_kwargs", {})) - if jax.process_index() == 0: - parameter_overview.log_parameter_overview( - params_cpu, msg="loaded params") + elif config.get("model_init"): + write_note(f"Initialize model from {config.model_init}...") + params_cpu = model_mod.load( + params_cpu, config.model_init, config.model, + **config.get("model_load", {})) + if jax.process_index() == 0: + parameter_overview.log_parameter_overview( + params_cpu, msg="restored params") write_note("Kicking off misc stuff...") first_step = bv_optax.get_count(opt_cpu) - chrono.inform(first_step, total_steps, batch_size, steps_per_epoch) + chrono.inform(first_step, total_steps, batch_size, ntrain_img / batch_size) prof = None # Keeps track of start/stop of profiler state. write_note(f"Replicating...\n{chrono.note}") @@ -339,31 +334,47 @@ def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): task_params = flax.jax_utils.replicate(task_params) update_rngs = flax.jax_utils.replicate(rng) - write_note(f"First step compilations...\n{chrono.note}") ckpt_writer = None + + write_note(f"First step compilations...\n{chrono.note}") error = None # For exiting with an error after cleanup. Avoids indentation. + # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. 
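# Illustration (not part of this change): the comment above about using a
# Python integer for `step`. After replication the optimizer's step count
# lives on the devices, so reading it every iteration would force a
# device-to-host transfer; instead it is read once on the host (via
# `bv_optax.get_count` above) and the loop is driven by a plain `range`.
import jax.numpy as jnp

device_step = jnp.int32(41)             # stand-in for the optimizer's counter
first_step = int(device_step)           # one transfer, once, on the host
total_steps = 45
for step in range(first_step + 1, total_steps + 1):
  pass  # `step` is a cheap Python int: easy to log, compare and format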
- for step, train_batch in zip(range(first_step + 1, total_steps + 1), - train_iter): + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): mw.step_start(step) with jax.profiler.StepTraceAnnotation("train_step", step_num=step): - params_repl, opt_repl, loss_value, update_rngs, extra_measurements = ( + params_repl, opt_repl, loss_value, update_rngs, measurements = ( update_fn( params_repl, opt_repl, - train_batch, + batch, update_rng=update_rngs, task_params=task_params)) # On the first host, let's always profile a handful of early steps. if jax.process_index() == 0: - prof = u.startstop_prof(prof, step, first_step, config.log_training_steps) + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + chrono.tick(step, mw.measure, write_note) + if not np.isfinite(l): + error = (f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + break # Checkpoint saving if (save_ckpt_path and - u.itstime(step, config.get("ckpt_steps"), total_steps, host=0)): + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): chrono.pause(wait_for=(params_repl, opt_repl)) u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) # We need to transfer the weights over now or else we risk keeping them @@ -374,44 +385,23 @@ def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): # Check whether we want to keep a copy of the current checkpoint. copy_step = None - if u.itstime(step, config.get("keep_ckpt_steps"), total_steps): + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): copy_step = step - # Checkpoint should be a nested dictionary or FLAX datataclasses from - # `flax.struct`. Both can be present in a checkpoint. 
- checkpoint = { - "params": params_cpu, - "opt": opt_cpu, - "chrono": chrono.save(), - } + ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": chrono.save()} ckpt_writer = pool.apply_async( - u.save_checkpoint, (checkpoint, save_ckpt_path, copy_step)) + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) chrono.resume() - # Report training progress - if (u.itstime(step, config.log_training_steps, total_steps, host=0) - or chrono.warmup and jax.process_index() == 0): - for i, sched_fn_cpu in enumerate(sched_fns_cpu): - mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) - l = mw.measure("training_loss", loss_value[0]) - for name, value in extra_measurements.items(): - mw.measure(name, value[0]) - chrono.tick(step, mw.measure, write_note) - if not np.isfinite(l): - error = (f"The loss became nan or inf somewhere within steps " - f"[{step - config.log_training_steps}, {step}]") - break - - chrono.pause(wait_for=params_repl) - for (name, evaluator, log_steps, prefix) in evaluators: + for (name, evaluator, log_steps, prefix) in evaluators(): if u.itstime(step, log_steps, total_steps, first=log_steps < total_steps, last=False): + chrono.pause(wait_for=(params_repl, task_params)) write_note(f"{name} evaluation...\n{chrono.note}") for key, value in evaluator.run( {"params": params_repl, "task_params": task_params}): mw.measure(f"{prefix}{key}", value) - chrono.resume() - + chrono.resume() mw.step_end() # Always give a chance to stop the profiler, no matter how things ended. @@ -420,7 +410,7 @@ def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): u.startstop_prof(prof) # Run final evalution, also used for eval only jobs (when total_steps == 0). - for (name, evaluator, _, prefix) in evaluators: + for (name, evaluator, _, prefix) in evaluators(): write_note(f"{name} evaluation...\n{chrono.note}") for key, value in evaluator.run( {"params": params_repl, "task_params": task_params}): @@ -437,13 +427,13 @@ def predict_fn(params, batch, seed=0, temperature=1e-7, **extra): mw.close() # Make sure all hosts stay up until the end of main. - u.sync_all_hosts() + u.sync() # Before cleanup, as cleanup should only run for successful jobs. 
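# Illustration (not part of this change): a minimal sketch of the
# asynchronous checkpointing pattern above. The host copy of the weights is
# handed to a thread pool so the disk write overlaps with training, and the
# next save first waits (with a timeout) for the previous one to finish.
# `write_npz` and the /tmp paths are hypothetical stand-ins for the real
# checkpoint writer and workdir.
import multiprocessing.pool
import numpy as np

pool = multiprocessing.pool.ThreadPool(1)   # one background writer suffices

def write_npz(tree, path):
  np.savez(path, **tree)

writer = None
for step in (1, 2, 3):
  params_cpu = {"w": np.full((4, 4), float(step))}
  if writer is not None:
    writer.wait(timeout=60)                 # don't queue up unbounded writes
  writer = pool.apply_async(write_npz, (params_cpu, f"/tmp/ckpt_{step}.npz"))
writer.get()                                # make sure the last write landed
pool.close()
pool.join()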
if error is not None: raise RuntimeError(error) - u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, logging.info) + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) if __name__ == "__main__": diff --git a/big_vision/trainers/proj/uvim/vqvae.py b/big_vision/trainers/proj/uvim/vqvae.py index 15c4610..13467d0 100644 --- a/big_vision/trainers/proj/uvim/vqvae.py +++ b/big_vision/trainers/proj/uvim/vqvae.py @@ -14,7 +14,7 @@ """Train loop for training the stage-I model.""" # pylint: disable=consider-using-from-import -from functools import partial +import functools import importlib import multiprocessing.pool import os @@ -23,6 +23,7 @@ from absl import flags from absl import logging from big_vision import input_pipeline +import big_vision.datasets.core as ds_core import big_vision.evaluators.common as eval_common import big_vision.optax as bv_optax import big_vision.pp.builder as pp_builder @@ -34,11 +35,12 @@ from ml_collections import config_flags import numpy as np import optax + import tensorflow.io.gfile as gfile SG = jax.lax.stop_gradient - +partial = functools.partial config_flags.DEFINE_config_file( "config", None, "Training configuration.", lock_config=True) @@ -69,7 +71,7 @@ def main(argv): predict_outputs_fn = partial(task_module.predict_outputs, config=config) save_ckpt_path = None - if workdir and config.get("ckpt_steps"): + if workdir: # Always create if requested, even if we may not write into it. gfile.makedirs(workdir) save_ckpt_path = os.path.join(workdir, "checkpoint.npz") @@ -97,46 +99,44 @@ def write_note(note): if jax.process_index() == 0: info("%s", note) - # Verify settings to make sure no checkpoints are accidentally missed. - if config.get("keep_ckpt_steps"): - assert config.get("ckpt_steps"), "Specify `ckpt_steps`." - assert config.keep_ckpt_steps % config.ckpt_steps == 0, ( - f"`keep_ckpt_steps` ({config.ckpt_steps}) should be" - f"divisible by `ckpt_steps ({config.ckpt_steps}).`") + write_note("Initializing...") + + batch_size = config.input.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f"Batch size ({batch_size}) must " + f"be divisible by device number ({jax.device_count()})") + info("Global batch size %d on %d hosts results in %d local batch size. With " + "%d dev per host (%d dev total), that's a %d per-device batch size.", + batch_size, jax.process_count(), batch_size // jax.process_count(), + jax.local_device_count(), jax.device_count(), + batch_size // jax.device_count()) # First thing after above sanity checks, so we can log "start" ticks. 
- mw = u.BigVisionMetricWriter(xid, wid, workdir) + mw = u.BigVisionMetricWriter(xid, wid, workdir, config) chrono = u.Chrono() write_note("Initializing train dataset...") - batch_size = config.batch_size + train_data = ds_core.get(**config.input.data) train_ds = input_pipeline.make_for_train( - dataset=config.dataset, - split=config.train_split, + data=train_data.get_tfdata(ordered=False), batch_size=batch_size, - preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train), - shuffle_buffer_size=config.get("shuffle_buffer_size"), - cache_raw=config.get("cache_raw", False), - data_dir=config.get("dataset_dir"), + preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")), + shuffle_buffer_size=config.input.get("shuffle_buffer_size"), + cache_raw=config.input.get("cache_raw", False), + filter_fn=config.input.get("filter_fn"), ) - ntrain_examples = input_pipeline.get_num_examples( - config.dataset, config.train_split, - data_dir=config.get("dataset_dir")) # Start prefetching already. n_prefetch = config.get("prefetch_to_device", 1) train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) - steps_per_epoch = ntrain_examples / batch_size + ntrain_img = train_data.total_examples - if config.get("total_epochs"): - total_steps = int(config.total_epochs * steps_per_epoch) - assert not config.get("total_steps"), "Set only one of total_(epochs|steps)" - else: - total_steps = config.get("total_steps", 0) + def get_steps(name, default=ValueError): # partial doesn't work well here. + return u.steps(name, config, ntrain_img, batch_size, default) + total_steps = get_steps("total") - logging.info( - "Running for %d steps, that means %f epochs and %f steps per epoch", - total_steps, total_steps * batch_size / ntrain_examples, steps_per_epoch) + info("Running for %d steps, that means %f epochs", + total_steps, total_steps * batch_size / ntrain_img) write_note(f"Initializing {config.model_name} model...") model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") @@ -171,14 +171,9 @@ def init(rng): parameter_overview.log_parameter_overview(params_cpu, msg="init params") mw.measure("num_params", num_params) - # Load the optimizer either from our folder or from flax. - tx, sched_fns = bv_optax.make( - config, - params_cpu, - sched_kw=dict( - global_batch_size=batch_size, - total_steps=total_steps, - steps_per_epoch=steps_per_epoch)) + write_note(f"Initializing {config.optax_name} optimizer...") + tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( + total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) # We jit this, such that the arrays are created on the CPU, not device[0]. opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) @@ -251,9 +246,13 @@ def predict_fn(params, batch): outputs = predict_outputs_fn(logits) return outputs - evaluators = eval_common.from_config( - config, {"predict": predict_fn, "validation": validation_fn}, - lambda s: write_note(f"Initializing evaluator: {s}")) + # Only initialize evaluators when they are first needed. + @functools.lru_cache(maxsize=None) + def evaluators(): + return eval_common.from_config( + config, {"predict": predict_fn, "validation": validation_fn}, + lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}") + ) # Decide how to initialize training. The order is important. # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. 
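# Illustration (not part of this change): the `@functools.lru_cache(maxsize=None)`
# trick used for `evaluators()` above. Wrapping a zero-argument factory this
# way makes it a lazily-initialized singleton, so evaluator construction cost
# is paid on first use inside the training loop rather than before training
# starts.
import functools

@functools.lru_cache(maxsize=None)
def expensive_resource():
  print("building...")                      # runs exactly once
  return {"ready": True}

assert expensive_resource() is expensive_resource()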
@@ -286,14 +285,14 @@ def predict_fn(params, batch): params_cpu, state_cpu = model_mod.load( {"params": params_cpu, "state": state_cpu}, config.model_init, config.model, - **config.get("model_init_kwargs", {})) + **config.get("model_load", {})) if jax.process_index() == 0: parameter_overview.log_parameter_overview( - params_cpu, msg="loaded params") + params_cpu, msg="restored params") write_note("Kicking off misc stuff...") first_step = bv_optax.get_count(opt_cpu) - chrono.inform(first_step, total_steps, batch_size, steps_per_epoch) + chrono.inform(first_step, total_steps, batch_size, ntrain_img / batch_size) prof = None # Keeps track of start/stop of profiler state. write_note(f"Replicating...\n{chrono.note}") @@ -307,29 +306,44 @@ def predict_fn(params, batch): write_note(f"First step compilations...\n{chrono.note}") error = None # For exiting with an error after cleanup. Avoids indentation. + # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. - for step, train_batch in zip( - range(first_step + 1, total_steps + 1), train_iter): + for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): mw.step_start(step) with jax.profiler.StepTraceAnnotation("train_step", step_num=step): - (params_repl, opt_repl, state_repl, loss_value, rngs_loop, - extra_measurements) = update_fn( - params_repl, - opt_repl, - state_repl, - train_batch, - rngs_loop, - not config.get("freeze_dict", True)) + params_repl, opt_repl, state_repl, loss_value, rngs_loop, measurements = ( + update_fn( + params_repl, + opt_repl, + state_repl, + batch, + rngs_loop, + not config.get("freeze_dict", True))) # On the first host, let's always profile a handful of early steps. if jax.process_index() == 0: - prof = u.startstop_prof(prof, step, first_step, config.log_training_steps) + prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) + + # Report training progress + if (u.itstime(step, get_steps("log_training"), total_steps, host=0) + or chrono.warmup and jax.process_index() == 0): + for i, sched_fn_cpu in enumerate(sched_fns_cpu): + mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) + l = mw.measure("training_loss", loss_value[0]) + for name, value in measurements.items(): + mw.measure(name, value[0]) + chrono.tick(step, mw.measure, write_note) + if not np.isfinite(l): + error = (f"The loss became nan or inf somewhere within steps " + f"[{step - get_steps('log_training')}, {step}]") + break # Checkpoint saving if (save_ckpt_path and - u.itstime(step, config.get("ckpt_steps"), total_steps, host=0)): + (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or + u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): chrono.pause(wait_for=(params_repl, opt_repl, state_repl)) u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) # We need to transfer the weights over now or else we risk keeping them @@ -340,43 +354,27 @@ def predict_fn(params, batch): # Check whether we want to keep a copy of the current checkpoint. copy_step = None - if u.itstime(step, config.get("keep_ckpt_steps"), total_steps): + if u.itstime(step, get_steps("keep_ckpt", None), total_steps): copy_step = step - # Checkpoint should be a nested dictionary or FLAX datataclasses from - # `flax.struct`. Both can be present in a checkpoint. 
- checkpoint = { + ckpt = { "params": params_cpu, "state": state_cpu, "opt": opt_cpu, "chrono": chrono.save(), } ckpt_writer = pool.apply_async( - u.save_checkpoint, (checkpoint, save_ckpt_path, copy_step)) + u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) chrono.resume() - # Report training progress - if (u.itstime(step, config.log_training_steps, total_steps, host=0) - or chrono.warmup and jax.process_index() == 0): - for i, sched_fn_cpu in enumerate(sched_fns_cpu): - mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) - l = mw.measure("training_loss", loss_value[0]) - for name, value in extra_measurements.items(): - mw.measure(name, value[0]) - chrono.tick(step, mw.measure, write_note) - if not np.isfinite(l): - error = (f"The loss became nan or inf somewhere within steps " - f"[{step - config.log_training_steps}, {step}]") - break - - chrono.pause(wait_for=params_repl) - for (name, evaluator, log_steps, prefix) in evaluators: + for (name, evaluator, log_steps, prefix) in evaluators(): if u.itstime(step, log_steps, total_steps): + chrono.pause(wait_for=(params_repl, state_repl)) write_note(f"{name} evaluation...\n{chrono.note}") for key, value in evaluator.run( {"params": params_repl, "state": state_repl}): mw.measure(f"{prefix}{key}", value) - chrono.resume() + chrono.resume() mw.step_end() # Always give a chance to stop the profiler, no matter how things ended. @@ -386,7 +384,7 @@ def predict_fn(params, batch): # Support eval only runs: run evaluation if total_steps (or num_epochs) is 0. if total_steps == 0: - for (name, evaluator, _, prefix) in evaluators: + for (name, evaluator, _, prefix) in evaluators(): write_note(f"{name} evaluation...\n{chrono.note}") for key, value in evaluator.run( {"params": params_repl, "state": state_repl}): @@ -403,13 +401,13 @@ def predict_fn(params, batch): mw.close() # Make sure all hosts stay up until the end of main. - u.sync_all_hosts() + u.sync() # Before cleanup, as cleanup should only run for successful jobs. if error is not None: raise RuntimeError(error) - u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, logging.info) + u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) if __name__ == "__main__": diff --git a/big_vision/utils.py b/big_vision/utils.py index ea80d67..e4eaca3 100644 --- a/big_vision/utils.py +++ b/big_vision/utils.py @@ -39,7 +39,7 @@ import ml_collections as mlc import numpy as np -import tensorflow.io.gfile as gfile +import tensorflow.io.gfile as gfile # pylint: disable=consider-using-from-import Registry = pp_registry.Registry @@ -379,7 +379,7 @@ class Chrono: """ def __init__(self): - self.program_start_time = time.time() + self.program_start_time = time.monotonic() self.train_start_time = None self.train_start_step = None # When we started timing (after warmup) @@ -405,7 +405,8 @@ def inform(self, first_step, total_steps, global_bs, steps_per_epoch): def tick(self, step, measure, write_note): """A chronometer tick.""" - now = time.time() + now = time.monotonic() + measure("uptime", now - self.program_start_time) # We do always count examples, regardless of the timing-related warmup that # happens a few lines below. @@ -466,10 +467,10 @@ def tick(self, step, measure, write_note): def pause(self, wait_for=()): assert self.pause_start is None, "Don't pause twice." 
jax.block_until_ready(wait_for) - self.pause_start = time.time() + self.pause_start = time.monotonic() def resume(self): - self.paused_time += time.time() - self.pause_start + self.paused_time += time.monotonic() - self.pause_start self.pause_start = None def save(self): @@ -780,12 +781,12 @@ def steps(prefix, config, data_size=None, batch_size=None, default=ValueError): ValueError if there is no such duration in the config and no default is set. """ # Be helpful and make sure only one of _steps, _epochs, _examples is defined. + # Note that steps=0 is also a valid value (e.g. to only run evaluators). msg = f"Only one of {prefix}_(steps,examples,epochs) should be defined." - assert (int(f"{prefix}_steps" in config) + - int(f"{prefix}_examples" in config) + - int(f"{prefix}_epochs" in config)) <= 1, msg + assert ((f"{prefix}_steps" in config) + + (f"{prefix}_examples" in config) + + (f"{prefix}_epochs" in config) <= 1), msg - # Boy do I anticipate the walrus operator... if f"{prefix}_steps" in config: return config[f"{prefix}_steps"] @@ -909,12 +910,15 @@ def mul(a, b): # B * BHWC -> B111 * BHWC return rng, map(mix, things), {k: mix(v) for k, v in more_things.items()} -def sync_all_hosts(): - """Makes sure all hosts are synced.""" - if jax.process_count() > 1: - x = jnp.ones([jax.local_device_count()]) - x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, "i"), "i")(x)) - assert x[0] == jax.device_count() +def _sync(x): + return jax.lax.psum(x, "i") + + +def sync(): + """Syncs hosts and empties async computation queue.""" + x = jnp.ones([jax.local_device_count()]) + x = jax.device_get(jax.pmap(_sync, "i")(x)) + assert x[0] == jax.device_count() def check_and_compile_patterns(patterns): @@ -970,6 +974,20 @@ def profile(name, ttl=3 * 365 * 24 * 3600): startstop_prof_at_steps(sess, name=name, ttl=ttl) +@contextlib.contextmanager +def log_timing(mw, name, *, noop=False): + t0 = time.monotonic() + yield + dt = time.monotonic() - t0 + if not noop: + mw.measure(name, dt) + + +@jax.jit +def _squareit(x): + return x**2 + + def startstop_prof(sess, step=None, first_step=0, log_steps=1, surround=20, **kw): """Runs the profiler for `surround` steps around the next `log_steps`."""
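# Illustration (not part of this change): a short usage sketch for the
# `log_timing` context manager defined above, with a hypothetical stand-in
# metric writer. It measures the wall-clock time of the wrapped block with
# time.monotonic() (robust to wall-clock adjustments, matching the Chrono
# changes) and reports it under the given name unless `noop=True`.
import time
import big_vision.utils as u

class PrintWriter:                          # stand-in for BigVisionMetricWriter
  def measure(self, name, value):
    print(f"{name}: {value:.3f}s")

mw = PrintWriter()
with u.log_timing(mw, "z/secs/example"):
  time.sleep(0.05)
with u.log_timing(mw, "z/secs/skipped", noop=True):
  time.sleep(0.05)                          # timed but not reported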