Updates input pipelines & adds LiT-B16B_2 config. (#15)
andsteing authored Aug 18, 2022
1 parent 1c6f5aa commit 8921d51
Showing 44 changed files with 1,196 additions and 1,017 deletions.
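
The recurring change across these configs: the flat data-pipeline fields (config.dataset, config.train_split, config.batch_size, config.cache_raw, config.shuffle_buffer_size, config.pp_train) move under a nested config.input dict, and each eval now carries its own data=dict(name=..., split=...) built by a small per-file helper instead of being merged from a shared eval_common dict. A minimal sketch of the new layout, condensed from the bit_i1k.py hunks below (pp strings shortened):

import ml_collections as mlc

config = mlc.ConfigDict()

# New nested input pipeline; replaces the former top-level config.dataset,
# config.train_split, config.batch_size, config.cache_raw,
# config.shuffle_buffer_size and config.pp_train fields.
config.input = dict()
config.input.data = dict(name='imagenet2012', split='train[:99%]')
config.input.batch_size = 4096
config.input.cache_raw = True  # Needs up to 120GB of RAM!
config.input.shuffle_buffer_size = 250_000
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr|...'

# Evals now name their data source explicitly instead of inheriting a
# top-level dataset/split pair.
config.evals = {}
config.evals.val = dict(
    type='classification',
    data=dict(name='imagenet2012', split='validation'),
    pp_fn='decode|resize_small(256)|central_crop(224)|...',
    loss_name='softmax_xent',
    log_steps=1000,
)
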
56 changes: 29 additions & 27 deletions big_vision/configs/bit_i1k.py
@@ -31,20 +31,23 @@ def get_config(runlocal=False):
"""Config for training on ImageNet-1k."""
config = mlc.ConfigDict()

config.dataset = 'imagenet2012'
config.train_split = 'train[:99%]'
config.cache_raw = not runlocal # Needs up to 120GB of RAM!
config.shuffle_buffer_size = 250_000 if not runlocal else 10_000 # Per host.
config.seed = 0
config.total_epochs = 90
config.num_classes = 1000
config.loss = 'softmax_xent'

config.seed = 0
config.batch_size = 4096 if not runlocal else 32
config.total_epochs = 90
config.input = dict()
config.input.data = dict(
name='imagenet2012',
split='train[:99%]',
)
config.input.batch_size = 4096 if not runlocal else 32
config.input.cache_raw = not runlocal # Needs up to 120GB of RAM!
config.input.shuffle_buffer_size = 250_000 if not runlocal else 10_000 # Per host.

pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
pp_common += '|value_range(-1, 1)|keep("image", "labels")'
config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common

config.log_training_steps = 50
@@ -62,30 +65,29 @@ def get_config(runlocal=False):
config.grad_clip_norm = 1.0

# linear scaling rule. Don't forget to sweep if sweeping batch_size.
config.wd = (1e-4 / 256) * config.batch_size
config.lr = (0.1 / 256) * config.batch_size
config.wd = (1e-4 / 256) * config.input.batch_size
config.lr = (0.1 / 256) * config.input.batch_size
config.schedule = dict(decay_type='cosine', warmup_steps=1000)

# Eval section
eval_common = dict(
type='classification',
dataset='imagenet2012',
pp_fn=pp_eval.format(lbl='label'),
loss_name=config.loss,
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
)
def get_eval(split, dataset='imagenet2012'):
return dict(
type='classification',
data=dict(name=dataset, split=split),
pp_fn=pp_eval.format(lbl='label'),
loss_name=config.loss,
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
cache_final=not runlocal,
)
config.evals = {}
config.evals.train = {**eval_common, 'split': 'train[:2%]'}
config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
config.evals.val = {**eval_common, 'split': 'validation'}
config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}

config.evals.real = dict(**eval_common)
config.evals.real.dataset = 'imagenet2012_real'
config.evals.real.split = 'validation'
config.evals.train = get_eval('train[:2%]')
config.evals.minival = get_eval('train[99%:]')
config.evals.val = get_eval('validation')
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')

# config.fewshot = get_fewshot_lsr()
# config.fewshot.log_steps = 1000
# config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
# config.evals.fewshot.log_steps = 1000

return config
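
A hedged usage sketch (not part of the diff): loading the updated config locally and reading the nested fields. The import path mirrors the file path above and assumes big_vision is importable as a package.

from big_vision.configs import bit_i1k

config = bit_i1k.get_config(runlocal=True)
print(config.input.data.name)      # 'imagenet2012'
print(config.input.batch_size)     # 32 in the runlocal setting
print(config.evals.v2.data.name)   # 'imagenet_v2'
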
47 changes: 25 additions & 22 deletions big_vision/configs/bit_i21k.py
@@ -28,23 +28,25 @@ def get_config():
"""Config for training on imagenet-21k."""
config = mlc.ConfigDict()

config.dataset = 'imagenet21k'
config.train_split = 'full[51200:]'
config.seed = 0
config.total_epochs = 90
config.num_classes = 21843
config.init_head_bias = -10.0
config.loss = 'sigmoid_xent'

config.trial = 0
config.batch_size = 4096
config.total_epochs = 90
config.input = dict()
config.input.data = dict(
name='imagenet21k',
split='full[51200:]',
)
config.input.batch_size = 4096
config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.

pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common_i21k
pp_eval_i1k = 'decode|resize_small(256)|central_crop(224)' + pp_common_i1k
config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
pp_eval = 'decode|resize_small(256)|central_crop(224)'

config.log_training_steps = 50
config.ckpt_steps = 1000
@@ -58,22 +60,23 @@ def get_config():
config.grad_clip_norm = 1.0

# linear scaling rule. Don't forget to sweep if sweeping batch_size.
config.lr = (0.03 / 256) * config.batch_size
config.wd = (3e-5 / 256) * config.batch_size
config.lr = (0.03 / 256) * config.input.batch_size
config.wd = (3e-5 / 256) * config.input.batch_size
config.schedule = dict(decay_type='cosine', warmup_steps=5000)

# Eval section
eval_common = dict(
type='classification',
dataset=config.dataset,
pp_fn=pp_eval,
loss_name=config.loss,
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
)
# Evaluations on i21k itself.
def eval_i21k(split):
return dict(
type='classification',
data={**config.input.data, 'split': split},
pp_fn=pp_eval + pp_common_i21k,
loss_name=config.loss,
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
)
config.evals = {}
config.evals.test = {**eval_common, 'split': 'full[:25_600]'}
config.evals.val = {**eval_common, 'split': 'full[25_600:51_200]'}
config.evals.train = {**eval_common, 'split': 'full[51_200:76_800]'}
config.evals.test = eval_i21k('full[:25_600]')
config.evals.val = eval_i21k('full[25_600:51_200]')
config.evals.train = eval_i21k('full[51_200:76_800]')

# Few-shot evaluators
config.evals.fewshot = get_fewshot_lsr()
2 changes: 1 addition & 1 deletion big_vision/configs/common_fewshot.py
@@ -39,7 +39,7 @@ def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
}
config.pp_train = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")'
config.pp_eval = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")'
config.shots = [1, 5, 10, 25]
config.shots = (1, 5, 10, 25)
config.l2_reg = 2.0 ** 10
config.num_seeds = 3
config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
1 change: 1 addition & 0 deletions big_vision/configs/load_and_eval.py
@@ -36,6 +36,7 @@

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import lit_eval
import ml_collections as mlc


61 changes: 31 additions & 30 deletions big_vision/configs/mlp_mixer_i1k.py
@@ -31,37 +31,39 @@ def get_config(mode=None):
"""Config for training Mixer on i1k."""
config = mlc.ConfigDict()

config.dataset = 'imagenet2012'
config.train_split = 'train[:99%]'
config.cache_raw = True # Needs up to 120GB of RAM!
config.seed = 0
config.total_epochs = 300
config.num_classes = 1000
config.init_head_bias = -6.9
config.loss = 'sigmoid_xent'
config.init_head_bias = -6.9

config.input = dict()
config.input.data = dict(
name='imagenet2012',
split='train[:99%]',
)
config.input.batch_size = 4096
config.input.cache_raw = True # Needs up to 120GB of RAM!
config.input.shuffle_buffer_size = 250_000

config.pp_train = (
config.input.pp = (
'decode_jpeg_and_inception_crop(224)'
'|flip_lr'
'|randaug(2,15)'
'|value_range(-1, 1)'
'|onehot(1000, key="label", key_result="labels")'
'|keep("image", "labels")'
)
ppv = (
pp_eval = (
'decode'
'|resize_small(256)|central_crop(224)'
'|value_range(-1, 1)'
'|onehot(1000, key="{lbl}", key_result="labels")'
'|keep("image", "labels")'
)

config.batch_size = 4096
config.total_epochs = 300

config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.

config.log_training_steps = 50
config.ckpt_steps = 1000
config.ckpt_timeout = 1

config.prefetch_to_device = 2

@@ -86,30 +88,29 @@ def get_config(mode=None):
)

# Eval section
eval_common = dict(
type='classification',
dataset='imagenet2012',
pp_fn=ppv.format(lbl='label'),
loss_name=config.loss,
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
)
def get_eval(split, dataset='imagenet2012'):
return dict(
type='classification',
data=dict(name=dataset, split=split),
pp_fn=pp_eval.format(lbl='label'),
loss_name=config.loss,
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
cache_final=mode != 'gpu8',
)
config.evals = {}
config.evals.train = {**eval_common, 'split': 'train[:2%]'}
config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
config.evals.val = {**eval_common, 'split': 'validation'}
config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}

config.evals.real = dict(**eval_common)
config.evals.real.dataset = 'imagenet2012_real'
config.evals.real.split = 'validation'
config.evals.real.pp_fn = ppv.format(lbl='real_label')
config.evals.train = get_eval('train[:2%]')
config.evals.minival = get_eval('train[99%:]')
config.evals.val = get_eval('validation')
config.evals.v2 = get_eval('test', dataset='imagenet_v2')
config.evals.real = get_eval('validation', dataset='imagenet2012_real')
config.evals.real.pp_fn = pp_eval.format(lbl='real_label')

config.fewshot = get_fewshot_lsr()

if mode == 'gpu8':
config.total_epochs = 60
config.batch_size = 512
config.cache_raw = False
config.input.batch_size = 512
config.input.cache_raw = False
if mode == 'regression_test':
config.total_epochs = 60
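
A hedged usage sketch of the mode switch above (import path mirrors the file path; assumes big_vision is importable): 'gpu8' shrinks the run as in the hunk, anything else keeps the 4096-batch defaults.

from big_vision.configs import mlp_mixer_i1k

config = mlp_mixer_i1k.get_config(mode='gpu8')
assert config.total_epochs == 60
assert config.input.batch_size == 512   # down from 4096
assert not config.input.cache_raw       # skip the ~120GB raw cache
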

71 changes: 38 additions & 33 deletions big_vision/configs/proj/distill/bigsweep_flowers_pet.py
@@ -50,18 +50,21 @@ def get_config(arg=None):
arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
config = mlc.ConfigDict()

config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
config.cache_raw = True
config.input = {}
config.input.data = dict(
name=dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data],
split=dict(flowers='train', pet='train[:90%]')[arg.data],
)
config.input.batch_size = 512
config.input.cache_raw = True
config.input.shuffle_buffer_size = 50_000
config.prefetch_to_device = 4
config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
config.num_classes = NCLS[arg.data]

config.batch_size = 512
config.num_epochs = {
config.num_classes = NCLS[arg.data]
config.total_epochs = {
'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
}[arg.data][arg.variant]
config.shuffle_buffer_size = 50_000

config.log_training_steps = 100
config.ckpt_steps = 2500
@@ -81,7 +84,7 @@ def get_config(arg=None):
f'|onehot({config.num_classes}, key="label", key_result="labels")'
'|keep("image", "labels")'
)
config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common
ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common

config.mixup = dict(p=1.0, n=2)
@@ -118,18 +121,19 @@ def get_config(arg=None):
val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
test_split = 'test' if not arg.runlocal else 'test[:16]'

base = dict(
type='classification',
pred='student_fwd',
dataset=config.dataset,
pp_fn=ppv,
loss_name='softmax_xent',
log_steps=500,
)
def get_eval(split):
return dict(
type='classification',
pred='student_fwd',
data=dict(name=config.input.data.name, split=split),
pp_fn=ppv,
loss_name='softmax_xent',
log_steps=500,
)
config.evals = {}
config.evals.student_train = {**base, 'split': minitrain_split}
config.evals.student_val = {**base, 'split': val_split}
config.evals.student_test = {**base, 'split': test_split}
config.evals.student_train = get_eval(minitrain_split)
config.evals.student_val = get_eval(val_split)
config.evals.student_test = get_eval(test_split)

# Teacher is fixed, so rare evals.
teacher = dict(log_steps=100_000, pred='prof_m_fwd')
@@ -138,22 +142,23 @@ def get_config(arg=None):
config.evals.teacher_test = {**config.evals.student_test, **teacher}

# Could in principle also look at agreement on other datasets!
dist = dict(
type='proj.distill.distance',
pred='student_prof_m_fwd',
dataset=config.dataset,
pp_fn=ppv + '|keep("image")',
log_steps=1000,
distances=({'kind': 'kl'}, {'kind': 'euclidean'},
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
)
config.evals.dist_train = {**dist, 'split': minitrain_split}
config.evals.dist_val = {**dist, 'split': val_split}
config.evals.dist_test = {**dist, 'split': test_split}
def get_dist(split):
return dict(
type='proj.distill.distance',
pred='student_prof_m_fwd',
data=dict(name=config.input.data.name, split=split),
pp_fn=ppv + '|keep("image")',
log_steps=1000,
distances=({'kind': 'kl'}, {'kind': 'euclidean'},
{'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
)
config.evals.dist_train = get_dist(minitrain_split)
config.evals.dist_val = get_dist(val_split)
config.evals.dist_test = get_dist(test_split)

# Make a few things much smaller for quick local debugging testruns.
if arg.runlocal:
config.shuffle_buffer_size = 10
config.batch_size = 8
config.input.shuffle_buffer_size = 10
config.input.batch_size = 8

return config
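
A hedged usage sketch for the arg-parametrized config above. The comma-separated key=value string accepted by bvcc.parse_arg is an assumption here; the keys and defaults (runlocal, data, variant, crop) come from the parse_arg call shown in the hunk.

from big_vision.configs.proj.distill import bigsweep_flowers_pet

config = bigsweep_flowers_pet.get_config('data=pet,variant=fast')
print(config.input.data.name)   # 'oxford_iiit_pet'
print(config.total_epochs)      # 1000 for the pet/fast combination
print(config.num_classes)       # NCLS['pet'], defined earlier in the file
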