[D] Stable baselines integration (#384)
* refactored reinforcement learning agent to accept marketplace
* adapted test_exampleprinter.py to marketplace initialization
* add market option to accept continuous actions
* fixed action space check
* initial stable baselines integration
* Agent init by env (#390)
* introduced self.network in actorcritic_agent
* added network_architecture in QLearningAgent
* changed actorcritic_agent to network_architecture
* set back training_scenario
* am_configuration initializes rl-agent via marketplace
* added final analysis to stable baselines training
* added more stable baselines algorithms
* added ppo algorithm
* introduced stable_baselines_folder
* renamed training to callback
* satisfied linter
* fixed loading problem
* tried to make tqdm run in stable_baselines
* made tqdm run
* reduced monitoring episodes in sb training
* save model only if significantly better
* fixed too long test time bug
* moved back to 250 episodes testing
* set timeout to 15 minutes
* added first batch of fixes for @NikkelM feedback
* added type annotations and asserts in stable_baselines_model
* added sbtraining to training_scenario
* applied comments in am_configuration
* solved .dat problem and fixed crashing asserts
* reintroduced _end_of_training
* removed deprecated if
* moved '.dat' to function call instead of appending within function
* fixed assert
* fixed model file ending bug
* added short explanation docstring (Co-authored-by: Johann Schulze Tast <[email protected]>)
* fixed wrong docstring
* fixed tests

Co-authored-by: NikkelM <[email protected]>
Co-authored-by: Johann Schulze Tast <[email protected]>
1 parent cd4f491, commit f8bc162
Showing 12 changed files with 303 additions and 22 deletions.
recommerce/rl/stable_baselines/stable_baselines_callback.py (141 additions, 0 deletions)
@@ -0,0 +1,141 @@
import os
import signal
import sys
import time
import warnings

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.results_plotter import load_results, ts2xy
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from tqdm.auto import trange

import recommerce.configuration.utils as ut
from recommerce.configuration.hyperparameter_config import config
from recommerce.configuration.path_manager import PathManager
from recommerce.market.sim_market import SimMarket
from recommerce.monitoring.agent_monitoring.am_evaluation import Evaluator
from recommerce.monitoring.agent_monitoring.am_monitoring import Monitor
from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent

warnings.filterwarnings('ignore')

class PerStepCheck(BaseCallback):
	"""
	Callback that checks once per episode (i.e. every `config.episode_length` steps) whether the model should be saved,
	based on the training reward (in practice, stable-baselines3 recommends using `EvalCallback`).
	"""
	def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000, iteration_length=500):
		assert issubclass(agent_class, ReinforcementLearningAgent)
		assert issubclass(marketplace_class, SimMarket)
		assert isinstance(log_dir_prepend, str), \
			f'log_dir_prepend should be a string, but {log_dir_prepend} is {type(log_dir_prepend)}'
		assert isinstance(training_steps, int) and training_steps > 0
		assert isinstance(iteration_length, int) and iteration_length > 0
		super(PerStepCheck, self).__init__(True)
		self.best_mean_interim_reward = None
		self.best_mean_overall_reward = None
		self.marketplace_class = marketplace_class
		self.agent_class = agent_class
		self.iteration_length = iteration_length
		self.tqdm_instance = trange(training_steps)
		self.saved_parameter_paths = []
		signal.signal(signal.SIGINT, self._signal_handler)

		self.initialize_io_related(log_dir_prepend)

	def _signal_handler(self, signum, frame) -> None:  # pragma: no cover
		"""
		Handle any interruptions to the running process, such as a `KeyboardInterrupt`-event.
		"""
		print('\nAborting training...')
		self._end_of_training()
		sys.exit(0)

	def initialize_io_related(self, log_dir_prepend) -> None:
		"""
		Initializes the instance variables self.curr_time, self.signature, self.writer, self.save_path
		and self.tmp_parameters, which are needed for saving the models and writing to tensorboard.

		Args:
			log_dir_prepend (str): A prefix that is written before the saved data.
		"""
		ut.ensure_results_folders_exist()
		self.curr_time = time.strftime('%b%d_%H-%M-%S')
		self.signature = 'Stable_Baselines_Training'
		self.writer = SummaryWriter(log_dir=os.path.join(PathManager.results_path, 'runs', f'{log_dir_prepend}training_{self.curr_time}'))
		path_name = f'{self.signature}_{self.curr_time}'
		self.save_path = os.path.join(PathManager.results_path, 'trainedModels', log_dir_prepend + path_name)
		os.makedirs(os.path.abspath(self.save_path), exist_ok=True)
		self.tmp_parameters = os.path.join(self.save_path, 'tmp_model.zip')

	def _on_step(self) -> bool:
		"""
		This method is called by the stable baselines agent after every step.
		"""
		self.tqdm_instance.update()
		# only evaluate once per episode, and only after the first full episode
		if (self.num_timesteps - 1) % config.episode_length != 0 or self.num_timesteps <= config.episode_length:
			return True
		self.tqdm_instance.refresh()
		finished_episodes = self.num_timesteps // config.episode_length
		x, y = ts2xy(load_results(self.save_path), 'timesteps')
		assert len(x) > 0 and len(x) == len(y)
		# mean reward over the last 100 episodes
		mean_reward = np.mean(y[-100:])

		# print info every 10 episodes
		if finished_episodes % 10 == 0:
			tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_reward:.3f}')

		# update the interim best model only if the reward improved significantly
		if self.best_mean_interim_reward is None or mean_reward > self.best_mean_interim_reward + 15:
			self.model.save(self.tmp_parameters)
			self.best_mean_interim_reward = mean_reward
			if self.best_mean_overall_reward is None or self.best_mean_interim_reward > self.best_mean_overall_reward:
				if self.best_mean_overall_reward is not None:
					tqdm.write(f'Best overall reward updated {self.best_mean_overall_reward:.3f} -> {self.best_mean_interim_reward:.3f}')
				self.best_mean_overall_reward = self.best_mean_interim_reward

		# persist the current interim best model once per iteration
		if (finished_episodes % self.iteration_length == 0 and finished_episodes > 0) and self.best_mean_interim_reward is not None:
			self.save_parameters(finished_episodes)

		return True

	def _on_training_end(self) -> None:
		self.tqdm_instance.close()
		if self.best_mean_interim_reward is not None:
			finished_episodes = self.num_timesteps // config.episode_length
			self.save_parameters(finished_episodes)

		# analyze the trained agents
		if len(self.saved_parameter_paths) == 0:
			print('No agents saved! Nothing to monitor.')
			return
		monitor = Monitor()
		agent_list = [(self.agent_class, [parameter_path]) for parameter_path in self.saved_parameter_paths]
		monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list, support_continuous_action_space=True)
		rewards = monitor.run_marketplace()
		episode_numbers = [int(parameter_path[-9:][:5]) for parameter_path in self.saved_parameter_paths]
		Evaluator(monitor.configurator).evaluate_session(rewards, episode_numbers)

	def save_parameters(self, finished_episodes: int):
		assert isinstance(finished_episodes, int)
		path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.zip')
		os.rename(self.tmp_parameters, path_to_parameters)
		self.saved_parameter_paths.append(path_to_parameters)
		tqdm.write(f'Saving the interim model after {finished_episodes} episodes to disk.')
		tqdm.write(f'You can find the parameters here: {path_to_parameters}.')
		tqdm.write(f'This model achieved a mean reward of {self.best_mean_interim_reward}.')
		self.best_mean_interim_reward = None

	def _end_of_training(self) -> None:
		"""
		Inform the user of the best_mean_overall_reward the agent achieved during training.
		"""
		if self.best_mean_overall_reward is None:
			print('The `best_mean_overall_reward` has never been set. Is this expected?')
		else:
			print(f'The best mean reward reached by the agent was {self.best_mean_overall_reward:.3f}')
			print('The models were saved to:')
			print(os.path.abspath(self.save_path))
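For orientation between the two new files: this callback is usually not constructed by hand; `StableBaselinesAgent.train_agent` (in the next file) wires it up. Below is a purely hypothetical sketch of that wiring, not part of the commit; `SomeMarketplace` is a stand-in for a concrete `SimMarket` subclass.

# Hypothetical wiring sketch, not part of the commit.
# `SomeMarketplace` stands in for a concrete SimMarket subclass;
# StableBaselinesPPO is defined in the next file.
from stable_baselines3 import PPO
import stable_baselines3.common.monitor

marketplace = SomeMarketplace()  # assumed: a gym-compatible SimMarket instance

# the callback needs the agent class (for the final monitoring run) and the marketplace class
callback = PerStepCheck(StableBaselinesPPO, type(marketplace), training_steps=100000, iteration_length=500)

model = PPO('MlpPolicy', marketplace, verbose=False)
# the Monitor wrapper logs episode returns to callback.save_path,
# which is where _on_step() reads them back via load_results()/ts2xy()
model.set_env(stable_baselines3.common.monitor.Monitor(marketplace, callback.save_path))
model.learn(100000, callback=callback)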
@@ -0,0 +1,97 @@
import numpy as np
import stable_baselines3.common.monitor
from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise

from recommerce.market.circular.circular_vendors import CircularAgent
from recommerce.market.linear.linear_vendors import LinearAgent
from recommerce.market.sim_market import SimMarket
from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent
from recommerce.rl.stable_baselines.stable_baselines_callback import PerStepCheck


class StableBaselinesAgent(ReinforcementLearningAgent, LinearAgent, CircularAgent):
	def __init__(self, marketplace=None, optim=None, load_path=None, name=None):
		assert marketplace is not None
		assert isinstance(marketplace, SimMarket), \
			f'if marketplace is provided, marketplace must be a SimMarket, but is {type(marketplace)}'
		assert optim is None
		assert load_path is None or isinstance(load_path, str)
		assert name is None or isinstance(name, str)

		self.marketplace = marketplace
		if load_path is None:
			self._initialize_model(marketplace)
			print(f'Initializing {self.name} agent on device {self.model.device}')
		else:
			self._load(load_path)
			print(f'Loading {self.name} agent on device {self.model.device} from {load_path}')

		if name is not None:
			self.name = name

	def policy(self, observation: np.ndarray) -> np.ndarray:
		assert isinstance(observation, np.ndarray), f'{observation}: this is a {type(observation)}, not a np.ndarray'
		return self.model.predict(observation)[0]

	def synchronize_tgt_net(self):  # pragma: no cover
		assert False, 'This method may never be used in a StableBaselinesAgent!'

	def train_agent(self, training_steps=100000, iteration_length=500):
		callback = PerStepCheck(type(self), type(self.marketplace), training_steps=training_steps, iteration_length=iteration_length)
		# wrap the marketplace in a stable-baselines3 Monitor so episode returns are logged to callback.save_path
		self.model.set_env(stable_baselines3.common.monitor.Monitor(self.marketplace, callback.save_path))
		self.model.learn(training_steps, callback=callback)

class StableBaselinesDDPG(StableBaselinesAgent):
	name = 'Stable_Baselines_DDPG'

	def _initialize_model(self, marketplace):
		n_actions = marketplace.get_actions_dimension()
		action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions))
		self.model = DDPG('MlpPolicy', marketplace, action_noise=action_noise, verbose=False)

	def _load(self, load_path):
		self.model = DDPG.load(load_path)


class StableBaselinesTD3(StableBaselinesAgent):
	name = 'Stable_Baselines_TD3'

	def _initialize_model(self, marketplace):
		n_actions = marketplace.get_actions_dimension()
		action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions))
		self.model = TD3('MlpPolicy', marketplace, action_noise=action_noise, verbose=False)

	def _load(self, load_path):
		self.model = TD3.load(load_path)


class StableBaselinesA2C(StableBaselinesAgent):
	name = 'Stable_Baselines_A2C'

	def _initialize_model(self, marketplace):
		self.model = A2C('MlpPolicy', marketplace, verbose=False)

	def _load(self, load_path):
		self.model = A2C.load(load_path)


class StableBaselinesPPO(StableBaselinesAgent):
	name = 'Stable_Baselines_PPO'

	def _initialize_model(self, marketplace):
		self.model = PPO('MlpPolicy', marketplace, verbose=False)

	def _load(self, load_path):
		self.model = PPO.load(load_path)


class StableBaselinesSAC(StableBaselinesAgent):
	name = 'Stable_Baselines_SAC'

	def _initialize_model(self, marketplace):
		self.model = SAC('MlpPolicy', marketplace, verbose=False)

	def _load(self, load_path):
		self.model = SAC.load(load_path)
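Finally, a hedged end-to-end usage sketch for the new agent wrappers. `CircularEconomyMarketplace` and its `support_continuous_action_space` keyword are assumed example names standing in for whatever marketplace the experiment actually uses, and the `.zip` path is only illustrative (it follows the naming scheme from `save_parameters` above).

# Hypothetical usage sketch, not part of the commit.
# CircularEconomyMarketplace is an assumed SimMarket subclass; substitute the real one.
marketplace = CircularEconomyMarketplace(support_continuous_action_space=True)

# training: internally builds a PerStepCheck callback and calls model.learn()
agent = StableBaselinesPPO(marketplace=marketplace)
agent.train_agent(training_steps=100000, iteration_length=500)

# later: reload a saved interim model (a .zip written by save_parameters) and query its policy
trained = StableBaselinesPPO(marketplace=marketplace, load_path='<save_path>/Stable_Baselines_Training_00500.zip')
observation = marketplace.reset()  # assumes the classic gym reset() that returns only the observation
action = trained.policy(observation)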