-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
65 lines (54 loc) · 2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import hydra
import omegaconf
import pickle
from pprint import pprint
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
import logging
import os
from random import randint
import config
def _save_pipeline(kp, path: str) -> None:
    """Pickle the pipeline object ``kp`` to ``path``, overwriting the file."""
    with open(path, "wb") as f:
        pickle.dump(kp, f)


@hydra.main(config_path="config",
            config_name="default_kp_config",
            version_base=None)
def main(cfg: omegaconf.DictConfig) -> None:
    """Fit a pykoop pipeline on preprocessed data and pickle the result.

    The pipeline is instantiated from ``cfg.pykoop_pipeline``, fitted on the
    ``X_train`` entry of the data unpickled from
    ``build/preprocessed_data/<robot>/variance_<variance>.bin`` and written to
    ``build/pykoop_objects/<robot>/variance_<variance>/kp_<regressor>_<lifting>.bin``,
    where ``<regressor>`` and ``<lifting>`` are the Hydra runtime choices for
    the ``regressors`` and ``lifting_functions`` config groups.

    When ``cfg.variance`` is below 0.1, a trajectory predicted from the
    ``x0_valid`` / ``u_valid`` entries is stored on the pipeline as ``x_pred``
    and the output file is rewritten with it.

    Parameters
    ----------
    cfg : omegaconf.DictConfig
        Hydra configuration. This function reads ``pykoop_pipeline``,
        ``robot`` and ``variance``.

    Raises
    ------
    ValueError
        If ``cfg.robot`` is not one of the supported robots.
    """
    # Fail early with a clear message instead of hitting an unbound ``path``
    # further down when the robot name is not recognized.
    supported_robots = ('nl_msd', 'soft_robot')
    if cfg.robot not in supported_robots:
        raise ValueError(
            "Unsupported robot {!r}; expected one of {}.".format(
                cfg.robot, supported_robots))

    # Build the pipeline object described by the config
    kp = hydra.utils.instantiate(cfg.pykoop_pipeline, _convert_='all')
    choices = HydraConfig.get().runtime.choices

    # Output location (identical layout for every supported robot)
    path = "build/pykoop_objects/{}/variance_{}/kp_{}_{}.bin".format(
        cfg.robot, cfg.variance,
        choices['regressors@pykoop_pipeline'],
        choices['lifting_functions@pykoop_pipeline'])
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # Get preprocessed data.
    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # must only ever read artifacts produced by this project's own build.
    with open(
            "build/preprocessed_data/{}/variance_{}.bin".format(
                cfg.robot, cfg.variance), "rb") as f:
        data = pickle.load(f)

    # Train model
    kp.fit(data.pykoop_dict['X_train'],
           n_inputs=data.pykoop_dict['n_inputs'],
           episode_feature=True)

    # Save right after fitting so the trained model is on disk even if the
    # prediction step below fails.
    _save_pipeline(kp, path)

    # Predict. Note that predictions are only good at low noise.
    if cfg.variance < 0.1:
        kp.x_pred = kp.predict_trajectory(
            data.pykoop_dict['x0_valid'],
            data.pykoop_dict['u_valid'],
            relift_state=True,
            return_lifted=False,
        )
        # Rewrite the file so it also carries the predicted trajectory
        _save_pipeline(kp, path)
# Script entry point: Hydra parses the command line and supplies ``cfg``.
if "__main__" == __name__:
    main()