-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_pm_vae.py
113 lines (85 loc) · 3.17 KB
/
train_pm_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import os
import pickle
import random
import jax
import jax.numpy as jnp
import optax
from absl import app, flags
from bax import Trainer
from bax.callbacks import CheckpointCallback, LearningRateLoggerCallback
from ml_collections.config_flags import config_flags
from posterior_matching.models.vae import PosteriorMatchingVAE
from posterior_matching.utils import (
configure_environment,
load_datasets,
cyclical_annealing_schedule,
make_run_dir,
TensorBoardCallback,
)
# Project-specific environment setup, executed at import time before any flags
# are defined. NOTE(review): what exactly it configures lives in
# posterior_matching.utils and is not visible here -- confirm before relying on it.
configure_environment()
# Register the `--config` flag. lock_config=False leaves the loaded config
# mutable so that main() can fill in a default seed before locking it itself.
config_flags.DEFINE_config_file("config", lock_config=False)
def get_beta_schedule(config):
    """Build the schedule for beta, the weight applied to the KL term.

    Args:
        config: Mapping-like config supporting ``in`` and attribute access.
            If it has no ``schedule`` key, beta is constant. Otherwise
            ``schedule`` selects the schedule type:

            * ``"monotonic"``: reads ``low_value``, ``high_value``,
              ``transition_steps`` and ``transition_begin`` and passes them
              to ``optax.linear_schedule``.
            * ``"cyclic"``: reads ``low_value``, ``high_value``, ``period``
              and ``delay`` and passes them to
              ``cyclical_annealing_schedule``.

    Returns:
        A callable mapping the training step to the beta value.

    Raises:
        ValueError: If ``config.schedule`` is not a supported schedule name.
    """
    if "schedule" not in config:
        # No schedule requested: keep beta fixed at 1.0 for every step.
        return lambda x: 1.0

    if config.schedule == "monotonic":
        return optax.linear_schedule(
            config.low_value,
            config.high_value,
            config.transition_steps,
            config.transition_begin,
        )

    if config.schedule == "cyclic":
        return cyclical_annealing_schedule(
            config.low_value, config.high_value, config.period, config.delay
        )

    # Fail loudly here instead of returning None, which would only surface
    # later as an opaque "'NoneType' object is not callable" in the loss.
    raise ValueError(
        f"Unknown beta schedule {config.schedule!r}; "
        "expected 'monotonic' or 'cyclic'."
    )
def main(_):
    """Train a PosteriorMatchingVAE as described by the ``--config`` flag.

    Loads the datasets, builds the loss and optimizer, runs the ``bax``
    trainer with checkpointing / learning-rate / TensorBoard callbacks, and
    writes the resulting artifacts into a freshly created run directory.

    Args:
        _: Positional argument supplied by ``absl.app.run``; unused.
    """
    config = flags.FLAGS.config

    # Pick a seed when the config does not provide one, then freeze the
    # config so nothing below can modify it.
    if "seed" not in config:
        config.seed = random.randint(0, int(2e9))
    config.lock()

    train_dataset, val_dataset = load_datasets(config.data)

    # The batch key holding the model input depends on the dataset type.
    is_image_data = "image" in train_dataset.element_spec
    data_key = "image" if is_image_data else "features"

    # Both values depend only on the (locked) config, so build them once
    # instead of on every loss evaluation. This also surfaces a bad beta
    # config before any training work starts.
    beta_schedule = get_beta_schedule(config.get("beta", {}))
    matching_coef = config.get("matching_coef", 1.0)

    def loss_fn(step, is_training, batch):
        """Return the training loss and the batch-averaged model outputs."""
        model = PosteriorMatchingVAE.from_config(config.model)
        out = model(batch[data_key], batch["mask"], is_training=is_training)

        beta = beta_schedule(step)
        out["beta"] = beta  # Stored in the outputs so it is reported as well.

        # Beta-weighted ELBO plus the weighted posterior-matching term.
        elbo = jnp.mean(out["reconstruction_ll"] - beta * out["kl"])
        matching_loss = -jnp.mean(out["matching_ll"])
        loss = -elbo + matching_coef * matching_loss

        # jax.tree_map is deprecated; jax.tree_util.tree_map is the
        # equivalent that works across JAX versions.
        return loss, jax.tree_util.tree_map(jnp.mean, out)

    schedule = optax.exponential_decay(**config.lr_schedule)
    optimizer = optax.chain(
        optax.scale_by_adam(**config.get("adam", {})),
        optax.add_decayed_weights(
            config.get("weight_decay", 0.0),
            # Apply weight decay to every parameter except 1-D ones.
            mask=lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p),
        ),
        optax.scale_by_schedule(schedule),
        optax.scale(-1.0),  # Negate the update so the loss is minimized.
    )

    trainer = Trainer(loss_fn, optimizer, num_devices=1, seed=config.seed)

    run_dir = make_run_dir(prefix=f"pm-vae-{config.data.dataset}")
    print("Using run directory:", run_dir)

    callbacks = [
        CheckpointCallback(os.path.join(run_dir, "train_state.pkl")),
        LearningRateLoggerCallback(schedule),
        TensorBoardCallback(os.path.join(run_dir, "tb")),
    ]

    train_state = trainer.fit(
        train_dataset,
        config.steps,
        val_dataset=val_dataset,
        validation_freq=config.validation_freq,
        callbacks=callbacks,
    )

    # Optionally write the final state to the same path the checkpoint
    # callback was given.
    if config.get("save_final_state", False):
        with open(os.path.join(run_dir, "train_state.pkl"), "wb") as fp:
            pickle.dump(train_state, fp)

    # NOTE(review): the model config is written unconditionally here, on the
    # assumption that it is needed to rebuild the model from any saved state
    # -- confirm it was not meant to be gated on `save_final_state`.
    with open(os.path.join(run_dir, "model_config.json"), "w") as fp:
        json.dump(config.model.to_dict(), fp)
# Script entry point: absl parses the command-line flags (including --config)
# and then calls main.
if __name__ == "__main__":
    app.run(main)