fix(zlx): make random_collect compatible with the episode collector (opendilab#190)
* fixbug(zlx): make random_collect compatible with the episode collector; replace the corresponding inline code in the serial entries with the random_collect helper

* fixbug(zlx): address review comments from nyz; add a unit test for the random_collect function

* polish(pu): polish serial entry for td3_vae and reward_model_ngu

* format(zlx): flake8 style

Co-authored-by: zlx-sensetime <[email protected]>
Co-authored-by: puyuan1996 <[email protected]>
3 people authored Jan 19, 2022
1 parent cdf7cde commit 8068628
Showing 17 changed files with 228 additions and 118 deletions.
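
The helper that replaces the duplicated blocks below lives in the new ding/entry/utils.py, whose diff is not loaded on this page. Here is a minimal sketch of what it plausibly contains, reconstructed from the call sites in this commit; the exact episode/sample dispatch and config keys are assumptions, not the actual implementation:

```python
from typing import Optional, Callable, List, Any

from ding.policy import PolicyFactory


def random_collect(
        policy_cfg,          # cfg.policy (EasyDict)
        policy,              # the created policy
        collector,           # serial sample- or episode-collector
        collector_env,       # env manager, used to query the action space
        commander,           # BaseSerialCommander, supplies collect kwargs (e.g. eps)
        replay_buffer,       # buffer that receives the random data
        postprocess_data_fn: Optional[Callable[[List[Any]], List[Any]]] = None,
) -> None:
    assert policy_cfg.random_collect_size > 0
    # Keep the collect-mode policy when transitions must carry policy data;
    # otherwise switch the collector to a pure random policy.
    if policy_cfg.get('transition_with_policy_data', False):
        collector.reset_policy(policy.collect_mode)
    else:
        action_space = collector_env.env_info().act_space
        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
        collector.reset_policy(random_policy)
    collect_kwargs = commander.step()
    # Episode-collector compatibility: request episodes instead of samples
    # when an episode collector is configured (the config key here is an assumption).
    if policy_cfg.collect.get('collector', {}).get('type', 'sample') == 'episode':
        new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
    else:
        new_data = collector.collect(n_sample=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
    if postprocess_data_fn is not None:
        new_data = postprocess_data_fn(new_data)
    replay_buffer.push(new_data, cur_collector_envstep=0)
    collector.reset_policy(policy.collect_mode)
```

Each entry below then reduces to a single call to this helper, optionally passing a postprocess_data_fn hook (see the mark_not_expert and mark_warm_up sketches further down).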
2 changes: 1 addition & 1 deletion .flake8
@@ -1,4 +1,4 @@
[flake8]
ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122
ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122,E731
max-line-length=120
statistics
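
E731 ("do not assign a lambda expression, use a def") is newly added to the ignore list, presumably so that short post-processing hooks can be bound to a name without lint noise, as in the commented-out line in serial_entry_td3_vae.py further down. An illustrative, hypothetical example of the pattern the relaxed rule permits:

```python
# With E731 ignored, binding a small hook to a name passes flake8:
postprocess_data_fn = lambda data: [dict(item, is_expert=0) for item in data]
```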
1 change: 1 addition & 0 deletions ding/entry/__init__.py
@@ -20,3 +20,4 @@
from .application_entry_trex_collect_data import trex_collecting_data, collect_episodic_demo_data_for_trex
from .serial_entry_guided_cost import serial_pipeline_guided_cost
from .serial_entry_gail import serial_pipeline_gail
from .utils import random_collect
14 changes: 3 additions & 11 deletions ding/entry/serial_entry.py
@@ -9,8 +9,9 @@
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline(
@@ -80,16 +81,7 @@ def serial_pipeline(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
if cfg.policy.get('transition_with_policy_data', False):
collector.reset_policy(policy.collect_mode)
else:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
17 changes: 6 additions & 11 deletions ding/entry/serial_entry_dqfd.py
@@ -6,15 +6,16 @@
import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.model import DQN
from copy import deepcopy
from .utils import random_collect, mark_not_expert
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import main_config, create_config # for testing


@@ -143,15 +144,9 @@ def serial_pipeline_dqfd(
learner.priority_info = {}
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
for i in range(len(new_data)):
new_data[i]['is_expert'] = 0 # set is_expert flag(expert 1, agent 0)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
collector.reset_policy(policy.collect_mode)
random_collect(
cfg.policy, policy, collector, collector_env, commander, replay_buffer, postprocess_data_fn=mark_not_expert
)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
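
The removed loop that tagged random transitions with is_expert = 0 is now provided by the mark_not_expert hook imported from .utils (a file not shown on this page). A sketch of that helper, assuming it simply reproduces the removed loop:

```python
def mark_not_expert(ori_data: list) -> list:
    # Tag each randomly collected transition as agent data (expert: 1, agent: 0),
    # mirroring the inline loop removed from this entry.
    for i in range(len(ori_data)):
        ori_data[i]['is_expert'] = 0
    return ori_data
```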
12 changes: 4 additions & 8 deletions ding/entry/serial_entry_gail.py
@@ -3,17 +3,18 @@
import logging
from functools import partial
from tensorboardX import SummaryWriter
import numpy as np

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import compile_config, read_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from ding.entry import collect_demo_data
from ding.utils import save_file
import numpy as np
from .utils import random_collect


def save_reward_model(path, reward_model, weights_name='best'):
@@ -112,12 +113,7 @@ def serial_pipeline_gail(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
best_reward = -np.inf
for _ in range(max_iterations):
# Evaluate policy performance
12 changes: 4 additions & 8 deletions ding/entry/serial_entry_guided_cost.py
@@ -9,13 +9,15 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.utils import set_pkg_seed, save_file
from ding.reward_model import create_reward_model
from .utils import random_collect


def serial_pipeline_guided_cost(
@@ -107,13 +109,7 @@ def serial_pipeline_guided_cost(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
12 changes: 3 additions & 9 deletions ding/entry/serial_entry_mbrl.py
@@ -11,8 +11,9 @@
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.utils import set_pkg_seed, read_file, save_file
from .utils import random_collect


def save_ckpt_fn(learner, env_model, envstep):
@@ -129,14 +130,7 @@ def serial_pipeline_mbrl(

# Accumulate plenty of data before the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(policy.collect_mode)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)

random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
# Train
batch_size = learner.policy.get_attribute('batch_size')
real_ratio = model_based_cfg['real_ratio']
17 changes: 5 additions & 12 deletions ding/entry/serial_entry_r2d3.py
@@ -4,17 +4,17 @@
from copy import deepcopy
from functools import partial
from typing import Union, Optional, List, Any, Tuple

import numpy as np
import torch
from tensorboardX import SummaryWriter

from ding.config import read_config, compile_config
from ding.envs import get_vec_env_setting, create_env_manager
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from .utils import random_collect, mark_not_expert


def serial_pipeline_r2d3(
@@ -156,16 +156,9 @@ def serial_pipeline_r2d3(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
for i in range(len(new_data)):
# set is_expert flag(expert 1, agent 0)
new_data[i]['is_expert'] = 0
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
collector.reset_policy(policy.collect_mode)
random_collect(
cfg.policy, policy, collector, collector_env, commander, replay_buffer, postprocess_data_fn=mark_not_expert
)

for _ in range(max_iterations):
collect_kwargs = commander.step()
11 changes: 3 additions & 8 deletions ding/entry/serial_entry_reward_model.py
@@ -9,9 +9,10 @@
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline_reward_model(
@@ -83,13 +84,7 @@ def serial_pipeline_reward_model(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
21 changes: 12 additions & 9 deletions ding/entry/serial_entry_reward_model_ngu.py
@@ -2,7 +2,6 @@
import os
from functools import partial
from typing import Union, Optional, List, Any, Tuple

import torch
from tensorboardX import SummaryWriter

@@ -17,6 +16,7 @@
from ding.worker import BaseLearner, BaseSerialCommander, create_buffer, create_serial_collector
from ding.worker.collector.base_serial_evaluator_ngu import BaseSerialEvaluatorNGU as BaseSerialEvaluator # TODO
import copy
from .utils import random_collect


def serial_pipeline_reward_model_ngu(
@@ -96,14 +96,17 @@ def serial_pipeline_reward_model_ngu(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
# collect_kwargs.update({'action_shape':cfg.policy.model.action_shape}) # todo
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
# backup
# action_space = collector_env.env_info().act_space
# random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
# collector.reset_policy(random_policy)
# collect_kwargs = commander.step()
# # collect_kwargs.update({'action_shape':cfg.policy.model.action_shape}) # todo
# new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
# replay_buffer.push(new_data, cur_collector_envstep=0)
# collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)

estimate_cnt = 0
for iter in range(max_iterations):
collect_kwargs = commander.step() # {'eps': 0.95}
11 changes: 3 additions & 8 deletions ding/entry/serial_entry_reward_model_onpolicy.py
@@ -9,9 +9,10 @@
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline_reward_model_onpolicy(
@@ -83,13 +84,7 @@ def serial_pipeline_reward_model_onpolicy(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
for iter in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
11 changes: 3 additions & 8 deletions ding/entry/serial_entry_sqil.py
@@ -10,8 +10,9 @@
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline_sqil(
@@ -117,13 +118,7 @@ def serial_pipeline_sqil(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
36 changes: 24 additions & 12 deletions ding/entry/serial_entry_td3_vae.py
@@ -12,6 +12,7 @@
from ding.policy import create_policy, PolicyFactory
from ding.utils import set_pkg_seed
import copy
from .utils import random_collect, mark_not_expert, mark_warm_up


def serial_pipeline_td3_vae(
@@ -83,18 +84,29 @@ def serial_pipeline_td3_vae(

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
if cfg.policy.get('transition_with_policy_data', False):
collector.reset_policy(policy.collect_mode)
else:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
for item in new_data:
item['warm_up'] = True
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
# backup
# if cfg.policy.get('transition_with_policy_data', False):
# collector.reset_policy(policy.collect_mode)
# else:
# action_space = collector_env.env_info().act_space
# random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
# collector.reset_policy(random_policy)
# collect_kwargs = commander.step()
# new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
# for item in new_data:
# item['warm_up'] = True
# replay_buffer.push(new_data, cur_collector_envstep=0)
# collector.reset_policy(policy.collect_mode)
# postprocess_data_fn = lambda x: mark_warm_up(mark_not_expert(x))
random_collect(
cfg.policy,
policy,
collector,
collector_env,
commander,
replay_buffer,
postprocess_data_fn=lambda x: mark_warm_up(mark_not_expert(x)) # postprocess_data_fn
)
# warm_up
# Learn policy from collected data
for i in range(cfg.policy.learn.warm_up_update):
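
mark_warm_up presumably restores the removed item['warm_up'] = True tagging and is composed with mark_not_expert via the lambda passed above (a pattern the new E731 ignore allows when bound to a name). A sketch under that assumption:

```python
from ding.entry.utils import mark_not_expert  # helper added by this commit


def mark_warm_up(ori_data: list) -> list:
    # Flag random data so the warm-up update phase can recognize it,
    # as the removed inline loop did.
    for item in ori_data:
        item['warm_up'] = True
    return ori_data


# Composition, as passed to random_collect in this entry:
postprocess_data_fn = lambda x: mark_warm_up(mark_not_expert(x))
```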
(Diffs for the remaining changed files were not loaded on this page.)
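
Those remaining files presumably include the new ding/entry/utils.py and the unit test mentioned in the commit message. A hypothetical pytest sketch of what such a test could look like; the config import, entry, and keyword names are illustrative assumptions rather than the actual test:

```python
import pytest
from copy import deepcopy

from ding.entry import serial_pipeline
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config


@pytest.mark.unittest
def test_random_collect():
    # Force the random-collection branch and run a single training iteration.
    config = [deepcopy(main_config), deepcopy(create_config)]
    config[0].policy.random_collect_size = 64
    try:
        serial_pipeline(config, seed=0, max_iterations=1)
    except Exception:
        pytest.fail("random_collect broke the serial pipeline")
```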
