[D] Stable baselines integration (#384)
* refactored reinforcement learning agent to accept marketplace

* adapted test_exampleprinter.py to marketplace initialization

* add market option to accept continuous actions

* fixed action space check

* initial stable baselines integration

* Agent init by env (#390)

* introduced self.network in actorcritic_agent

* added network_architecture in QLearningAgent

* changed actorcritic_agent to network_architecture

* set back training_scenario

* am_configuration initialize rl-agent via marketplace

* added final analyse to stable baselines training

* added more stable baselines algorithms

* added ppo algorithm

* introduced stable_baselines_folder

* renamed training to callback

* satisfied linter

* fixed loading problem

* try to make tqdm run in stable_baselines

* make tqdm running

* reduced pmonitoring episodes in sb training

* save model only if significantly better

* fixed too long test time bug

* moved back to 250 episodes testing

* set timeout to 15 minutes

* added first batch of fixes to @NikkelM feedback

* added type annotations and asserts in stable_baselines_model

* added sbtraining to training_scenario

* applied comments in am_configuration

* solved .dat problem and fixed crashing asserts

* reintroduced _end_of_training

* removed deprecated if

* Moved '.dat' to function call instead of appending within function

* Fixed assert

* fixed model file ending bug

* Add short explanation docstring

Co-authored-by: Johann Schulze Tast <[email protected]>

* fixed wrong docstring

* Fixed tests

Co-authored-by: NikkelM <[email protected]>
Co-authored-by: Johann Schulze Tast <[email protected]>
3 people authored Apr 8, 2022
1 parent cd4f491 commit f8bc162
Showing 12 changed files with 303 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -6,7 +6,7 @@ jobs:
build-linux:
runs-on: [self-hosted, ubuntu-20.04]
name: CI on Ubuntu
- timeout-minutes: 10
+ timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v2
4 changes: 2 additions & 2 deletions recommerce/configuration/environment_config.py
@@ -181,8 +181,8 @@ def _parse_and_set_agents(self, agent_list: list, needs_modelfile: bool) -> None
if needs_modelfile and issubclass(agent['agent_class'], (QLearningAgent, ActorCriticAgent)):
assert isinstance(agent['argument'], str), \
f'The "argument" field of this agent ({agent["name"]}) must be a string but was ({type(agent["argument"])})'
- assert agent['argument'].endswith('.dat'), \
- f'The "argument" field must contain a modelfile and therefore end in ".dat": {agent["argument"]}'
+ assert agent['argument'].endswith('.dat') or agent['argument'].endswith('.zip'), \
+ f'The "argument" field must contain a modelfile and therefore end in ".dat" or ".zip": {agent["argument"]}'
# Check that the modelfile exists. Taken from am_configuration::_get_modelfile_path()
full_path = os.path.abspath(os.path.join(PathManager.data_path, agent['argument']))
assert os.path.exists(full_path), f'the specified modelfile does not exist: {full_path}'
27 changes: 15 additions & 12 deletions recommerce/monitoring/agent_monitoring/am_configuration.py
@@ -7,12 +7,12 @@
import recommerce.market.circular.circular_sim_market as circular_market
import recommerce.market.linear.linear_sim_market as linear_market
import recommerce.market.sim_market as sim_market
- import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent
from recommerce.configuration.path_manager import PathManager
from recommerce.market.circular.circular_vendors import CircularAgent, FixedPriceCEAgent
from recommerce.market.linear.linear_vendors import LinearAgent
from recommerce.market.vendors import Agent, HumanPlayer, RuleBasedAgent
from recommerce.rl.q_learning.q_learning_agent import QLearningAgent
+ from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent


class Configurator():
@@ -52,7 +52,7 @@ def _get_modelfile_path(self, model_name: str) -> str:
Returns:
str: The full path to the modelfile.
"""
- model_name += '.dat'
+ assert model_name.endswith('.dat') or model_name.endswith('.zip'), f'Modelfiles must end in .dat or .zip: {model_name}'
full_path = os.path.join(PathManager.data_path, model_name)
assert os.path.exists(full_path), f'the specified modelfile does not exist: {full_path}'
return full_path
@@ -91,28 +91,30 @@ def _update_agents(self, agents) -> None:
# The custom_init takes two parameters: The class of the agent to be initialized and a list of arguments,
# e.g. for the fixed prices or names
self.agents.append(Agent.custom_init(current_agent[0], current_agent[1]))
- elif issubclass(current_agent[0], (QLearningAgent, actorcritic_agent.ActorCriticAgent)):
+ elif issubclass(current_agent[0], ReinforcementLearningAgent):
try:
assert (0 <= len(current_agent[1]) <= 2), 'the argument list for a RL-agent must have length between 0 and 2'
assert all(isinstance(argument, str) for argument in current_agent[1]), 'the arguments for a RL-agent must be of type str'

- agent_modelfile = f'{type(self.marketplace).__name__}_{current_agent[0].__name__}'
+ # Stablebaselines ends in .zip - we don't
+ agent_modelfile = f'{type(self.marketplace).__name__}_{current_agent[0].__name__}.dat'
agent_name = 'q_learning' if issubclass(current_agent[0], QLearningAgent) else 'actor_critic'
# no arguments
if len(current_agent[1]) == 0:
pass
+ # only modelfile argument
+ elif len(current_agent[1]) == 1 and \
+ (current_agent[1][0].endswith('.dat') or current_agent[1][0].endswith('.zip')):
+ agent_modelfile = current_agent[1][0]
# only name argument
- elif len(current_agent[1]) == 1 and not str.endswith(current_agent[1][0], '.dat'):
+ elif len(current_agent[1]) == 1:
# get implicit modelfile name
agent_name = current_agent[1][0]
- # only modelfile argument
- elif len(current_agent[1]) == 1 and str.endswith(current_agent[1][0], '.dat'):
- agent_modelfile = current_agent[1][0][:-4]
# both arguments, first must be the modelfile, second the name
elif len(current_agent[1]) == 2:
- assert str.endswith(current_agent[1][0], '.dat'), \
+ assert current_agent[1][0].endswith('.dat'), \
f'if two arguments are provided, the first one must be the modelfile. Arg1: {current_agent[1][0]}, Arg2: {current_agent[1][1]}'
- agent_modelfile = current_agent[1][0][:-4]
+ agent_modelfile = current_agent[1][0]
agent_name = current_agent[1][1]
# this should never happen due to the asserts before, but you never know
else: # pragma: no cover
@@ -137,7 +139,8 @@ def setup_monitoring(
plot_interval: int = None,
marketplace: sim_market.SimMarket = None,
agents: list = None,
- subfolder_name: str = None) -> None:
+ subfolder_name: str = None,
+ support_continuous_action_space: bool = False) -> None:
"""
Configure the current monitoring session.
@@ -171,7 +174,7 @@

if(marketplace is not None):
assert issubclass(marketplace, sim_market.SimMarket), 'the marketplace must be a subclass of SimMarket'
- self.marketplace = marketplace()
+ self.marketplace = marketplace(support_continuous_action_space)
# If the agents have not been changed, we reuse the old agents
if(agents is None):
print('Warning: Your agents are being overwritten by new instances of themselves!')
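For illustration, the agent tuples that _update_agents now accepts could look like the following sketch. The marketplace and file names are invented; only the argument rules (.dat for the PyTorch-based agents, .zip for Stable Baselines agents, optional display name) come from this diff:

agents = [
	(QLearningAgent, []),                                                        # no arguments: implicit modelfile and default name
	(QLearningAgent, ['CircularEconomy_QLearningAgent.dat']),                    # only a modelfile; PyTorch models end in '.dat'
	(StableBaselinesPPO, ['CircularEconomy_StableBaselinesPPO.zip']),            # only a modelfile; Stable Baselines models end in '.zip'
	(QLearningAgent, ['my_competitor']),                                         # only a display name
	(QLearningAgent, ['CircularEconomy_QLearningAgent.dat', 'my_competitor']),   # modelfile first, then name
]
# these tuples would then be passed on via setup_monitoring(..., agents=agents, ...)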
4 changes: 2 additions & 2 deletions recommerce/rl/actorcritic/actorcritic_agent.py
@@ -85,8 +85,8 @@ def save(self, model_path, model_name) -> None:
model_path (str): The path to the folder within 'trainedModels' where the model should be saved.
model_name (str): The name of the .dat file of this specific model.
"""
- model_name += '.dat'
-
+ assert model_name.endswith('.dat'), f'the modelname must end in ".dat": {model_name}'
+ assert os.path.exists(model_path), f'the specified path does not exist: {model_path}'
actor_path = os.path.join(model_path, f'actor_parameters{model_name}')
torch.save(self.best_interim_actor_net.state_dict(), actor_path)
torch.save(self.best_interim_critic_net.state_dict(), os.path.join(model_path, 'critic_parameters' + model_name))
6 changes: 3 additions & 3 deletions recommerce/rl/q_learning/q_learning_agent.py
@@ -123,16 +123,16 @@ def calc_loss(self, batch, device='cpu'):
def sync_to_best_interim(self):
self.best_interim_net.load_state_dict(self.net.state_dict())

- def save(self, model_path, model_name) -> None:
+ def save(self, model_path: str, model_name: str) -> None:
"""
Save a trained model to the specified folder within 'trainedModels'.
Args:
model_path (str): The path to the folder within 'trainedModels' where the model should be saved.
model_name (str): The name of the .dat file of this specific model.
"""
- model_name += '.dat'

+ assert model_name.endswith('.dat'), f'the modelname must end in ".dat": {model_name}'
assert os.path.exists(model_path), f'the specified path does not exist: {model_path}'
parameters_path = os.path.join(model_path, model_name)
torch.save(self.best_interim_net.state_dict(), parameters_path)
return parameters_path
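Because the '.dat' suffix is no longer appended inside save(), callers now pass it explicitly, as the training.py change further below shows. A minimal sketch of the new calling convention for QLearningAgent.save (path and file name are placeholders):

# hypothetical caller: model_path must already exist, model_name must end in '.dat'
parameters_path = agent.save(model_path='results/trainedModels/my_run', model_name='Training_00500.dat')
print(parameters_path)  # e.g. results/trainedModels/my_run/Training_00500.dat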
141 changes: 141 additions & 0 deletions recommerce/rl/stable_baselines/stable_baselines_callback.py
@@ -0,0 +1,141 @@
import os
import signal
import sys
import time
import warnings

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.results_plotter import load_results, ts2xy
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from tqdm.auto import trange

import recommerce.configuration.utils as ut
from recommerce.configuration.hyperparameter_config import config
from recommerce.configuration.path_manager import PathManager
from recommerce.market.sim_market import SimMarket
from recommerce.monitoring.agent_monitoring.am_evaluation import Evaluator
from recommerce.monitoring.agent_monitoring.am_monitoring import Monitor
from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent

warnings.filterwarnings('ignore')


class PerStepCheck(BaseCallback):
"""
Callback for saving a model (the check is done every `check_freq` steps)
based on the training reward (in practice, we recommend using `EvalCallback`).
"""
def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000, iteration_length=500):
assert issubclass(agent_class, ReinforcementLearningAgent)
assert issubclass(marketplace_class, SimMarket)
assert isinstance(log_dir_prepend, str), \
f'log_dir_prepend should be a string, but {log_dir_prepend} is {type(log_dir_prepend)}'
assert isinstance(training_steps, int) and training_steps > 0
assert isinstance(iteration_length, int) and iteration_length > 0
super(PerStepCheck, self).__init__(True)
self.best_mean_interim_reward = None
self.best_mean_overall_reward = None
self.marketplace_class = marketplace_class
self.agent_class = agent_class
self.iteration_length = iteration_length
self.tqdm_instance = trange(training_steps)
self.saved_parameter_paths = []
signal.signal(signal.SIGINT, self._signal_handler)

self.initialize_io_related(log_dir_prepend)

def _signal_handler(self, signum, frame) -> None: # pragma: no cover
"""
Handle any interruptions to the running process, such as a `KeyboardInterrupt`-event.
"""
print('\nAborting training...')
self._end_of_training()
sys.exit(0)

def initialize_io_related(self, log_dir_prepend) -> None:
"""
Initializes the local variables self.curr_time, self.signature, self.writer, self.save_path
and self.tmp_parameters which are needed for saving the models and writing to tensorboard
Args:
log_dir_prepend (str): A prefix that is written before the saved data
"""
ut.ensure_results_folders_exist()
self.curr_time = time.strftime('%b%d_%H-%M-%S')
self.signature = 'Stable_Baselines_Training'
self.writer = SummaryWriter(log_dir=os.path.join(PathManager.results_path, 'runs', f'{log_dir_prepend}training_{self.curr_time}'))
path_name = f'{self.signature}_{self.curr_time}'
self.save_path = os.path.join(PathManager.results_path, 'trainedModels', log_dir_prepend + path_name)
os.makedirs(os.path.abspath(self.save_path), exist_ok=True)
self.tmp_parameters = os.path.join(self.save_path, 'tmp_model.zip')

def _on_step(self) -> bool:
"""
This method is called at every step by the stable baselines agents.
"""
self.tqdm_instance.update()
if (self.num_timesteps - 1) % config.episode_length != 0 or self.num_timesteps <= config.episode_length:
return True
self.tqdm_instance.refresh()
finished_episodes = self.num_timesteps // config.episode_length
x, y = ts2xy(load_results(self.save_path), 'timesteps')
assert len(x) > 0 and len(x) == len(y)
mean_reward = np.mean(y[-100:])

# consider print info
if (finished_episodes) % 10 == 0:
tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_reward:.3f}')

# consider update best model
if self.best_mean_interim_reward is None or mean_reward > self.best_mean_interim_reward + 15:
self.model.save(self.tmp_parameters)
self.best_mean_interim_reward = mean_reward
if self.best_mean_overall_reward is None or self.best_mean_interim_reward > self.best_mean_overall_reward:
if self.best_mean_overall_reward is not None:
tqdm.write(f'Best overall reward updated {self.best_mean_overall_reward:.3f} -> {self.best_mean_interim_reward:.3f}')
self.best_mean_overall_reward = self.best_mean_interim_reward

# consider save model
if (finished_episodes % self.iteration_length == 0 and finished_episodes > 0) and self.best_mean_interim_reward is not None:
self.save_parameters(finished_episodes)

return True

def _on_training_end(self) -> None:
self.tqdm_instance.close()
if self.best_mean_interim_reward is not None:
finished_episodes = self.num_timesteps // config.episode_length
self.save_parameters(finished_episodes)

# analyze trained agents
if len(self.saved_parameter_paths) == 0:
print('No agents saved! Nothing to monitor.')
return
monitor = Monitor()
agent_list = [(self.agent_class, [parameter_path]) for parameter_path in self.saved_parameter_paths]
monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list, support_continuous_action_space=True)
rewards = monitor.run_marketplace()
episode_numbers = [int(parameter_path[-9:][:5]) for parameter_path in self.saved_parameter_paths]
Evaluator(monitor.configurator).evaluate_session(rewards, episode_numbers)

def save_parameters(self, finished_episodes: int):
assert isinstance(finished_episodes, int)
path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.zip')
os.rename(self.tmp_parameters, path_to_parameters)
self.saved_parameter_paths.append(path_to_parameters)
tqdm.write(f'I write the interim model after {finished_episodes} episodes to the disk.')
tqdm.write(f'You can find the parameters here: {path_to_parameters}.')
tqdm.write(f'This model achieved a mean reward of {self.best_mean_interim_reward}.')
self.best_mean_interim_reward = None

def _end_of_training(self) -> None:
"""
Inform the user of the best_mean_overall_reward the agent achieved during training.
"""
if self.best_mean_overall_reward is None:
print('The `best_mean_overall_reward` has never been set. Is this expected?')
else:
print(f'The best mean reward reached by the agent was {self.best_mean_overall_reward:.3f}')
print('The models were saved to:')
print(os.path.abspath(self.save_path))
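Condensed, the saving criterion that _on_step applies once per finished episode amounts to the sketch below. It assumes the marketplace is wrapped in a stable_baselines3 Monitor that writes its episode results into save_path, as train_agent in the next file does:

import numpy as np
from stable_baselines3.common.results_plotter import load_results, ts2xy

x, y = ts2xy(load_results(save_path), 'timesteps')   # episode end timesteps and episode returns
mean_reward = np.mean(y[-100:])                      # mean return over the last 100 finished episodes
if best_mean_interim_reward is None or mean_reward > best_mean_interim_reward + 15:
	model.save(tmp_parameters)                       # keep the candidate only if it is significantly (> 15) better
	best_mean_interim_reward = mean_reward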
97 changes: 97 additions & 0 deletions recommerce/rl/stable_baselines/stable_baselines_model.py
@@ -0,0 +1,97 @@
import numpy as np
import stable_baselines3.common.monitor
from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise

from recommerce.market.circular.circular_vendors import CircularAgent
from recommerce.market.linear.linear_vendors import LinearAgent
from recommerce.market.sim_market import SimMarket
from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent
from recommerce.rl.stable_baselines.stable_baselines_callback import PerStepCheck


class StableBaselinesAgent(ReinforcementLearningAgent, LinearAgent, CircularAgent):
def __init__(self, marketplace=None, optim=None, load_path=None, name=None):
assert marketplace is not None
assert isinstance(marketplace, SimMarket), \
f'if marketplace is provided, marketplace must be a SimMarket, but is {type(marketplace)}'
assert optim is None
assert load_path is None or isinstance(load_path, str)
assert name is None or isinstance(name, str)

self.marketplace = marketplace
if load_path is None:
self._initialize_model(marketplace)
print(f'I initiate {self.name}-agent using {self.model.device} device')
if load_path is not None:
self._load(load_path)
print(f'I load {self.name}-agent using {self.model.device} device from {load_path}')

if name is not None:
self.name = name

def policy(self, observation: np.array) -> np.array:
assert isinstance(observation, np.ndarray), f'{observation}: this is a {type(observation)}, not a np ndarray'
return self.model.predict(observation)[0]

def synchronize_tgt_net(self): # pragma: no cover
assert False, 'This method may never be used in a StableBaselinesAgent!'

def train_agent(self, training_steps=100000, iteration_length=500):
callback = PerStepCheck(type(self), type(self.marketplace), training_steps=training_steps, iteration_length=iteration_length)
self.model.set_env(stable_baselines3.common.monitor.Monitor(self.marketplace, callback.save_path))
self.model.learn(training_steps, callback=callback)


class StableBaselinesDDPG(StableBaselinesAgent):
name = 'Stable_Baselines_DDPG'

def _initialize_model(self, marketplace):
n_actions = marketplace.get_actions_dimension()
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions))
self.model = DDPG('MlpPolicy', marketplace, action_noise=action_noise, verbose=False)

def _load(self, load_path):
self.model = DDPG.load(load_path)


class StableBaselinesTD3(StableBaselinesAgent):
name = 'Stable_Baselines_TD3'

def _initialize_model(self, marketplace):
n_actions = marketplace.get_actions_dimension()
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions))
self.model = TD3('MlpPolicy', marketplace, action_noise=action_noise, verbose=False)

def _load(self, load_path):
self.model = TD3.load(load_path)


class StableBaselinesA2C(StableBaselinesAgent):
name = 'Stable_Baselines_A2C'

def _initialize_model(self, marketplace):
self.model = A2C('MlpPolicy', marketplace, verbose=False)

def _load(self, load_path):
self.model = A2C.load(load_path)


class StableBaselinesPPO(StableBaselinesAgent):
name = 'Stable_Baselines_PPO'

def _initialize_model(self, marketplace):
self.model = PPO('MlpPolicy', marketplace, verbose=False)

def _load(self, load_path):
self.model = PPO.load(load_path)


class StableBaselinesSAC(StableBaselinesAgent):
name = 'Stable_Baselines_SAC'

def _initialize_model(self, marketplace):
self.model = SAC('MlpPolicy', marketplace, verbose=False)

def _load(self, load_path):
self.model = SAC.load(load_path)
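Taken together with training_scenario.py below, a typical round trip with these wrappers might look like the following sketch. The load path is a placeholder; the marketplace flag and the default training arguments are taken from this diff:

from recommerce.market.circular.circular_sim_market import CircularEconomyRebuyPriceOneCompetitor
from recommerce.rl.stable_baselines.stable_baselines_model import StableBaselinesPPO

# the marketplace is created with continuous-action support, which the wrapped algorithms are used with here
marketplace = CircularEconomyRebuyPriceOneCompetitor(True)

# train from scratch; the PerStepCheck callback saves interim '.zip' models and runs the final analysis
StableBaselinesPPO(marketplace).train_agent(training_steps=100000, iteration_length=500)

# reload a saved model later, e.g. for monitoring
agent = StableBaselinesPPO(marketplace, load_path='<path to a saved .zip>', name='trained_ppo')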
2 changes: 1 addition & 1 deletion recommerce/rl/training.py
@@ -93,7 +93,7 @@ def consider_print_info(self, frame_idx: int, episode_number: int, averaged_info

def consider_save_model(self, episodes_idx: int, force: bool = False) -> None:
if ((episodes_idx % 500 == 0 and episodes_idx > 0) or force) and self.best_mean_interim_reward is not None:
- path_to_parameters = self.RL_agent.save(model_path=self.model_path, model_name=f'{self.signature}_{episodes_idx:05d}')
+ path_to_parameters = self.RL_agent.save(model_path=self.model_path, model_name=f'{self.signature}_{episodes_idx:05d}.dat')
tqdm.write(f'I write the interim model after {episodes_idx} episodes to the disk.')
tqdm.write(f'You can find the parameters here: {path_to_parameters}.')
tqdm.write(f'This model achieved a mean reward of {self.best_mean_interim_reward}.')
5 changes: 5 additions & 0 deletions recommerce/rl/training_scenario.py
@@ -3,6 +3,7 @@
import recommerce.market.sim_market as sim_market
import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent
import recommerce.rl.q_learning.q_learning_agent as q_learning_agent
import recommerce.rl.stable_baselines.stable_baselines_model as sbmodel
from recommerce.configuration.environment_config import EnvironmentConfigLoader, TrainingEnvironmentConfig
from recommerce.market.circular.circular_vendors import CircularAgent
from recommerce.rl.actorcritic.actorcritic_training import ActorCriticTrainer
@@ -52,6 +53,10 @@ def train_continuos_a2c_circular_economy_rebuy():
run_training_session(circular_market.CircularEconomyRebuyPriceOneCompetitor, actorcritic_agent.ContinuosActorCriticAgentFixedOneStd)


def train_stable_baselines_ppo():
sbmodel.StableBaselinesPPO(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent()


def train_from_config():
"""
Use the `environment_config_training.json` file to decide on the training parameters.
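The new scenario can then be started like the existing ones, e.g. by a direct call (a hypothetical invocation; any further wiring lives outside the files shown here):

import recommerce.rl.training_scenario as training_scenario

training_scenario.train_stable_baselines_ppo()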