From f8bc1622e763f3502d103b77174517383c8fa8a2 Mon Sep 17 00:00:00 2001 From: jannikgro <52510222+jannikgro@users.noreply.github.com> Date: Fri, 8 Apr 2022 13:20:42 +0200 Subject: [PATCH] [D] Stable baselines integration (#384) * refactored reinforcement learning agent to accept marketplace * adapted test_exampleprinter.py to marketplace initialization * add market option to accept continuous actions * fixed action space check * initial stable baselines integration * Agent init by env (#390) * introduced self.network in actorcritic_agent * added network_architecture in QLearningAgent * changed actorcritic_agent to network_architecture * set back training_scenario * am_configuration initialize rl-agent via marketplace * added final analysis to stable baselines training * added more stable baselines algorithms * added ppo algorithm * introduced stable_baselines_folder * renamed training to callback * satisfied linter * fixed loading problem * try to make tqdm run in stable_baselines * make tqdm running * reduced monitoring episodes in sb training * save model only if significantly better * fixed too long test time bug * moved back to 250 episodes testing * set timeout to 15 minutes * added first batch of fixes to @NikkelM feedback * added type annotations and asserts in stable_baselines_model * added sbtraining to training_scenario * applied comments in am_configuration * solved .dat problem and fixed crashing asserts * reintroduced _end_of_training * removed deprecated if * Moved '.dat' to function call instead of appending within function * Fixed assert * fixed model file ending bug * Add short explanation docstring Co-authored-by: Johann Schulze Tast <35633229+blackjack2693@users.noreply.github.com> * fixed wrong docstring * Fixed tests Co-authored-by: NikkelM Co-authored-by: Johann Schulze Tast <35633229+blackjack2693@users.noreply.github.com> --- .github/workflows/CI.yml | 2 +- .../configuration/environment_config.py | 4 +- .../agent_monitoring/am_configuration.py 
| 27 ++-- .../rl/actorcritic/actorcritic_agent.py | 4 +- recommerce/rl/q_learning/q_learning_agent.py | 6 +- .../stable_baselines_callback.py | 141 ++++++++++++++++++ .../stable_baselines_model.py | 97 ++++++++++++ recommerce/rl/training.py | 2 +- recommerce/rl/training_scenario.py | 5 + setup.cfg | 1 + .../test_am_configuration.py | 2 +- tests/test_stable_baselines_training.py | 34 +++++ 12 files changed, 303 insertions(+), 22 deletions(-) create mode 100644 recommerce/rl/stable_baselines/stable_baselines_callback.py create mode 100644 recommerce/rl/stable_baselines/stable_baselines_model.py create mode 100644 tests/test_stable_baselines_training.py diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 14f9d75a..ef0df005 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -6,7 +6,7 @@ jobs: build-linux: runs-on: [self-hosted, ubuntu-20.04] name: CI on Ubuntu - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: Checkout uses: actions/checkout@v2 diff --git a/recommerce/configuration/environment_config.py b/recommerce/configuration/environment_config.py index 5b9f8c6d..9f033641 100644 --- a/recommerce/configuration/environment_config.py +++ b/recommerce/configuration/environment_config.py @@ -181,8 +181,8 @@ def _parse_and_set_agents(self, agent_list: list, needs_modelfile: bool) -> None if needs_modelfile and issubclass(agent['agent_class'], (QLearningAgent, ActorCriticAgent)): assert isinstance(agent['argument'], str), \ f'The "argument" field of this agent ({agent["name"]}) must be a string but was ({type(agent["argument"])})' - assert agent['argument'].endswith('.dat'), \ - f'The "argument" field must contain a modelfile and therefore end in ".dat": {agent["argument"]}' + assert agent['argument'].endswith('.dat') or agent['argument'].endswith('.zip'), \ + f'The "argument" field must contain a modelfile and therefore end in ".dat" or ".zip": {agent["argument"]}' # Check that the modelfile exists. 
Taken from am_configuration::_get_modelfile_path() full_path = os.path.abspath(os.path.join(PathManager.data_path, agent['argument'])) assert os.path.exists(full_path), f'the specified modelfile does not exist: {full_path}' diff --git a/recommerce/monitoring/agent_monitoring/am_configuration.py b/recommerce/monitoring/agent_monitoring/am_configuration.py index 0ccf5a71..00854741 100644 --- a/recommerce/monitoring/agent_monitoring/am_configuration.py +++ b/recommerce/monitoring/agent_monitoring/am_configuration.py @@ -7,12 +7,12 @@ import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.linear.linear_sim_market as linear_market import recommerce.market.sim_market as sim_market -import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent from recommerce.configuration.path_manager import PathManager from recommerce.market.circular.circular_vendors import CircularAgent, FixedPriceCEAgent from recommerce.market.linear.linear_vendors import LinearAgent from recommerce.market.vendors import Agent, HumanPlayer, RuleBasedAgent from recommerce.rl.q_learning.q_learning_agent import QLearningAgent +from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent class Configurator(): @@ -52,7 +52,7 @@ def _get_modelfile_path(self, model_name: str) -> str: Returns: str: The full path to the modelfile. """ - model_name += '.dat' + assert model_name.endswith('.dat') or model_name.endswith('.zip'), f'Modelfiles must end in .dat or .zip: {model_name}' full_path = os.path.join(PathManager.data_path, model_name) assert os.path.exists(full_path), f'the specified modelfile does not exist: {full_path}' return full_path @@ -91,28 +91,30 @@ def _update_agents(self, agents) -> None: # The custom_init takes two parameters: The class of the agent to be initialized and a list of arguments, # e.g. 
for the fixed prices or names self.agents.append(Agent.custom_init(current_agent[0], current_agent[1])) - elif issubclass(current_agent[0], (QLearningAgent, actorcritic_agent.ActorCriticAgent)): + elif issubclass(current_agent[0], ReinforcementLearningAgent): try: assert (0 <= len(current_agent[1]) <= 2), 'the argument list for a RL-agent must have length between 0 and 2' assert all(isinstance(argument, str) for argument in current_agent[1]), 'the arguments for a RL-agent must be of type str' - agent_modelfile = f'{type(self.marketplace).__name__}_{current_agent[0].__name__}' + # Stablebaselines ends in .zip - we don't + agent_modelfile = f'{type(self.marketplace).__name__}_{current_agent[0].__name__}.dat' agent_name = 'q_learning' if issubclass(current_agent[0], QLearningAgent) else 'actor_critic' # no arguments if len(current_agent[1]) == 0: pass + # only modelfile argument + elif len(current_agent[1]) == 1 and \ + (current_agent[1][0].endswith('.dat') or current_agent[1][0].endswith('.zip')): + agent_modelfile = current_agent[1][0] # only name argument - elif len(current_agent[1]) == 1 and not str.endswith(current_agent[1][0], '.dat'): + elif len(current_agent[1]) == 1: # get implicit modelfile name agent_name = current_agent[1][0] - # only modelfile argument - elif len(current_agent[1]) == 1 and str.endswith(current_agent[1][0], '.dat'): - agent_modelfile = current_agent[1][0][:-4] # both arguments, first must be the modelfile, second the name elif len(current_agent[1]) == 2: - assert str.endswith(current_agent[1][0], '.dat'), \ + assert current_agent[1][0].endswith('.dat'), \ f'if two arguments are provided, the first one must be the modelfile. 
Arg1: {current_agent[1][0]}, Arg2: {current_agent[1][1]}' - agent_modelfile = current_agent[1][0][:-4] + agent_modelfile = current_agent[1][0] agent_name = current_agent[1][1] # this should never happen due to the asserts before, but you never know else: # pragma: no cover @@ -137,7 +139,8 @@ def setup_monitoring( plot_interval: int = None, marketplace: sim_market.SimMarket = None, agents: list = None, - subfolder_name: str = None) -> None: + subfolder_name: str = None, + support_continuous_action_space: bool = False) -> None: """ Configure the current monitoring session. @@ -171,7 +174,7 @@ def setup_monitoring( if(marketplace is not None): assert issubclass(marketplace, sim_market.SimMarket), 'the marketplace must be a subclass of SimMarket' - self.marketplace = marketplace() + self.marketplace = marketplace(support_continuous_action_space) # If the agents have not been changed, we reuse the old agents if(agents is None): print('Warning: Your agents are being overwritten by new instances of themselves!') diff --git a/recommerce/rl/actorcritic/actorcritic_agent.py b/recommerce/rl/actorcritic/actorcritic_agent.py index 33b121c0..f3dfbf54 100644 --- a/recommerce/rl/actorcritic/actorcritic_agent.py +++ b/recommerce/rl/actorcritic/actorcritic_agent.py @@ -85,8 +85,8 @@ def save(self, model_path, model_name) -> None: model_path (str): The path to the folder within 'trainedModels' where the model should be saved. model_name (str): The name of the .dat file of this specific model. 
""" - model_name += '.dat' - + assert model_name.endswith('.dat'), f'the modelname must end in ".dat": {model_name}' + assert os.path.exists(model_path), f'the specified path does not exist: {model_path}' actor_path = os.path.join(model_path, f'actor_parameters{model_name}') torch.save(self.best_interim_actor_net.state_dict(), actor_path) torch.save(self.best_interim_critic_net.state_dict(), os.path.join(model_path, 'critic_parameters' + model_name)) diff --git a/recommerce/rl/q_learning/q_learning_agent.py b/recommerce/rl/q_learning/q_learning_agent.py index 2c83991b..5b764b4e 100644 --- a/recommerce/rl/q_learning/q_learning_agent.py +++ b/recommerce/rl/q_learning/q_learning_agent.py @@ -123,7 +123,7 @@ def calc_loss(self, batch, device='cpu'): def sync_to_best_interim(self): self.best_interim_net.load_state_dict(self.net.state_dict()) - def save(self, model_path, model_name) -> None: + def save(self, model_path: str, model_name: str) -> None: """ Save a trained model to the specified folder within 'trainedModels'. @@ -131,8 +131,8 @@ def save(self, model_path, model_name) -> None: model_path (str): The path to the folder within 'trainedModels' where the model should be saved. model_name (str): The name of the .dat file of this specific model. 
""" - model_name += '.dat' - + assert model_name.endswith('.dat'), f'the modelname must end in ".dat": {model_name}' + assert os.path.exists(model_path), f'the specified path does not exist: {model_path}' parameters_path = os.path.join(model_path, model_name) torch.save(self.best_interim_net.state_dict(), parameters_path) return parameters_path diff --git a/recommerce/rl/stable_baselines/stable_baselines_callback.py b/recommerce/rl/stable_baselines/stable_baselines_callback.py new file mode 100644 index 00000000..6ad3e0d7 --- /dev/null +++ b/recommerce/rl/stable_baselines/stable_baselines_callback.py @@ -0,0 +1,141 @@ +import os +import signal +import sys +import time +import warnings + +import numpy as np +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.results_plotter import load_results, ts2xy +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from tqdm.auto import trange + +import recommerce.configuration.utils as ut +from recommerce.configuration.hyperparameter_config import config +from recommerce.configuration.path_manager import PathManager +from recommerce.market.sim_market import SimMarket +from recommerce.monitoring.agent_monitoring.am_evaluation import Evaluator +from recommerce.monitoring.agent_monitoring.am_monitoring import Monitor +from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent + +warnings.filterwarnings('ignore') + + +class PerStepCheck(BaseCallback): + """ + Callback for saving a model (the check is done every `check_freq` steps) + based on the training reward (in practice, we recommend using `EvalCallback`). 
+ """ + def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000, iteration_length=500): + assert issubclass(agent_class, ReinforcementLearningAgent) + assert issubclass(marketplace_class, SimMarket) + assert isinstance(log_dir_prepend, str), \ + f'log_dir_prepend should be a string, but {log_dir_prepend} is {type(log_dir_prepend)}' + assert isinstance(training_steps, int) and training_steps > 0 + assert isinstance(iteration_length, int) and iteration_length > 0 + super(PerStepCheck, self).__init__(True) + self.best_mean_interim_reward = None + self.best_mean_overall_reward = None + self.marketplace_class = marketplace_class + self.agent_class = agent_class + self.iteration_length = iteration_length + self.tqdm_instance = trange(training_steps) + self.saved_parameter_paths = [] + signal.signal(signal.SIGINT, self._signal_handler) + + self.initialize_io_related(log_dir_prepend) + + def _signal_handler(self, signum, frame) -> None: # pragma: no cover + """ + Handle any interruptions to the running process, such as a `KeyboardInterrupt`-event. 
+ """ + print('\nAborting training...') + self._end_of_training() + sys.exit(0) + + def initialize_io_related(self, log_dir_prepend) -> None: + """ + Initializes the local variables self.curr_time, self.signature, self.writer, self.save_path + and self.tmp_parameters which are needed for saving the models and writing to tensorboard + Args: + log_dir_prepend (str): A prefix that is written before the saved data + """ + ut.ensure_results_folders_exist() + self.curr_time = time.strftime('%b%d_%H-%M-%S') + self.signature = 'Stable_Baselines_Training' + self.writer = SummaryWriter(log_dir=os.path.join(PathManager.results_path, 'runs', f'{log_dir_prepend}training_{self.curr_time}')) + path_name = f'{self.signature}_{self.curr_time}' + self.save_path = os.path.join(PathManager.results_path, 'trainedModels', log_dir_prepend + path_name) + os.makedirs(os.path.abspath(self.save_path), exist_ok=True) + self.tmp_parameters = os.path.join(self.save_path, 'tmp_model.zip') + + def _on_step(self) -> bool: + """ + This method is called at every step by the stable baselines agents. 
+ """ + self.tqdm_instance.update() + if (self.num_timesteps - 1) % config.episode_length != 0 or self.num_timesteps <= config.episode_length: + return True + self.tqdm_instance.refresh() + finished_episodes = self.num_timesteps // config.episode_length + x, y = ts2xy(load_results(self.save_path), 'timesteps') + assert len(x) > 0 and len(x) == len(y) + mean_reward = np.mean(y[-100:]) + + # consider print info + if (finished_episodes) % 10 == 0: + tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_reward:.3f}') + + # consider update best model + if self.best_mean_interim_reward is None or mean_reward > self.best_mean_interim_reward + 15: + self.model.save(self.tmp_parameters) + self.best_mean_interim_reward = mean_reward + if self.best_mean_overall_reward is None or self.best_mean_interim_reward > self.best_mean_overall_reward: + if self.best_mean_overall_reward is not None: + tqdm.write(f'Best overall reward updated {self.best_mean_overall_reward:.3f} -> {self.best_mean_interim_reward:.3f}') + self.best_mean_overall_reward = self.best_mean_interim_reward + + # consider save model + if (finished_episodes % self.iteration_length == 0 and finished_episodes > 0) and self.best_mean_interim_reward is not None: + self.save_parameters(finished_episodes) + + return True + + def _on_training_end(self) -> None: + self.tqdm_instance.close() + if self.best_mean_interim_reward is not None: + finished_episodes = self.num_timesteps // config.episode_length + self.save_parameters(finished_episodes) + + # analyze trained agents + if len(self.saved_parameter_paths) == 0: + print('No agents saved! 
Nothing to monitor.') + return + monitor = Monitor() + agent_list = [(self.agent_class, [parameter_path]) for parameter_path in self.saved_parameter_paths] + monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list, support_continuous_action_space=True) + rewards = monitor.run_marketplace() + episode_numbers = [int(parameter_path[-9:][:5]) for parameter_path in self.saved_parameter_paths] + Evaluator(monitor.configurator).evaluate_session(rewards, episode_numbers) + + def save_parameters(self, finished_episodes: int): + assert isinstance(finished_episodes, int) + path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.zip') + os.rename(self.tmp_parameters, path_to_parameters) + self.saved_parameter_paths.append(path_to_parameters) + tqdm.write(f'I write the interim model after {finished_episodes} episodes to the disk.') + tqdm.write(f'You can find the parameters here: {path_to_parameters}.') + tqdm.write(f'This model achieved a mean reward of {self.best_mean_interim_reward}.') + self.best_mean_interim_reward = None + + def _end_of_training(self) -> None: + """ + Inform the user of the best_mean_overall_reward the agent achieved during training. + """ + if self.best_mean_overall_reward is None: + print('The `best_mean_overall_reward` has never been set. 
Is this expected?') + else: + print(f'The best mean reward reached by the agent was {self.best_mean_overall_reward:.3f}') + print('The models were saved to:') + print(os.path.abspath(self.save_path)) diff --git a/recommerce/rl/stable_baselines/stable_baselines_model.py b/recommerce/rl/stable_baselines/stable_baselines_model.py new file mode 100644 index 00000000..fa7b9f8a --- /dev/null +++ b/recommerce/rl/stable_baselines/stable_baselines_model.py @@ -0,0 +1,97 @@ +import numpy as np +import stable_baselines3.common.monitor +from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3 +from stable_baselines3.common.noise import NormalActionNoise + +from recommerce.market.circular.circular_vendors import CircularAgent +from recommerce.market.linear.linear_vendors import LinearAgent +from recommerce.market.sim_market import SimMarket +from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent +from recommerce.rl.stable_baselines.stable_baselines_callback import PerStepCheck + + +class StableBaselinesAgent(ReinforcementLearningAgent, LinearAgent, CircularAgent): + def __init__(self, marketplace=None, optim=None, load_path=None, name=None): + assert marketplace is not None + assert isinstance(marketplace, SimMarket), \ + f'if marketplace is provided, marketplace must be a SimMarket, but is {type(marketplace)}' + assert optim is None + assert load_path is None or isinstance(load_path, str) + assert name is None or isinstance(name, str) + + self.marketplace = marketplace + if load_path is None: + self._initialize_model(marketplace) + print(f'I initiate {self.name}-agent using {self.model.device} device') + if load_path is not None: + self._load(load_path) + print(f'I load {self.name}-agent using {self.model.device} device from {load_path}') + + if name is not None: + self.name = name + + def policy(self, observation: np.array) -> np.array: + assert isinstance(observation, np.ndarray), f'{observation}: this is a {type(observation)}, not a np ndarray' + 
return self.model.predict(observation)[0] + + def synchronize_tgt_net(self): # pragma: no cover + assert False, 'This method may never be used in a StableBaselinesAgent!' + + def train_agent(self, training_steps=100000, iteration_length=500): + callback = PerStepCheck(type(self), type(self.marketplace), training_steps=training_steps, iteration_length=iteration_length) + self.model.set_env(stable_baselines3.common.monitor.Monitor(self.marketplace, callback.save_path)) + self.model.learn(training_steps, callback=callback) + + +class StableBaselinesDDPG(StableBaselinesAgent): + name = 'Stable_Baselines_DDPG' + + def _initialize_model(self, marketplace): + n_actions = marketplace.get_actions_dimension() + action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions)) + self.model = DDPG('MlpPolicy', marketplace, action_noise=action_noise, verbose=False) + + def _load(self, load_path): + self.model = DDPG.load(load_path) + + +class StableBaselinesTD3(StableBaselinesAgent): + name = 'Stable_Baselines_TD3' + + def _initialize_model(self, marketplace): + n_actions = marketplace.get_actions_dimension() + action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=1 * np.ones(n_actions)) + self.model = TD3('MlpPolicy', marketplace, action_noise=action_noise, verbose=False) + + def _load(self, load_path): + self.model = TD3.load(load_path) + + +class StableBaselinesA2C(StableBaselinesAgent): + name = 'Stable_Baselines_A2C' + + def _initialize_model(self, marketplace): + self.model = A2C('MlpPolicy', marketplace, verbose=False) + + def _load(self, load_path): + self.model = A2C.load(load_path) + + +class StableBaselinesPPO(StableBaselinesAgent): + name = 'Stable_Baselines_PPO' + + def _initialize_model(self, marketplace): + self.model = PPO('MlpPolicy', marketplace, verbose=False) + + def _load(self, load_path): + self.model = PPO.load(load_path) + + +class StableBaselinesSAC(StableBaselinesAgent): + name = 'Stable_Baselines_SAC' + + def 
_initialize_model(self, marketplace): + self.model = SAC('MlpPolicy', marketplace, verbose=False) + + def _load(self, load_path): + self.model = SAC.load(load_path) diff --git a/recommerce/rl/training.py b/recommerce/rl/training.py index 4fcb8dcd..2e3a8821 100644 --- a/recommerce/rl/training.py +++ b/recommerce/rl/training.py @@ -93,7 +93,7 @@ def consider_print_info(self, frame_idx: int, episode_number: int, averaged_info def consider_save_model(self, episodes_idx: int, force: bool = False) -> None: if ((episodes_idx % 500 == 0 and episodes_idx > 0) or force) and self.best_mean_interim_reward is not None: - path_to_parameters = self.RL_agent.save(model_path=self.model_path, model_name=f'{self.signature}_{episodes_idx:05d}') + path_to_parameters = self.RL_agent.save(model_path=self.model_path, model_name=f'{self.signature}_{episodes_idx:05d}.dat') tqdm.write(f'I write the interim model after {episodes_idx} episodes to the disk.') tqdm.write(f'You can find the parameters here: {path_to_parameters}.') tqdm.write(f'This model achieved a mean reward of {self.best_mean_interim_reward}.') diff --git a/recommerce/rl/training_scenario.py b/recommerce/rl/training_scenario.py index 4c819e13..bf375c1a 100644 --- a/recommerce/rl/training_scenario.py +++ b/recommerce/rl/training_scenario.py @@ -3,6 +3,7 @@ import recommerce.market.sim_market as sim_market import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent import recommerce.rl.q_learning.q_learning_agent as q_learning_agent +import recommerce.rl.stable_baselines.stable_baselines_model as sbmodel from recommerce.configuration.environment_config import EnvironmentConfigLoader, TrainingEnvironmentConfig from recommerce.market.circular.circular_vendors import CircularAgent from recommerce.rl.actorcritic.actorcritic_training import ActorCriticTrainer @@ -52,6 +53,10 @@ def train_continuos_a2c_circular_economy_rebuy(): run_training_session(circular_market.CircularEconomyRebuyPriceOneCompetitor, 
actorcritic_agent.ContinuosActorCriticAgentFixedOneStd) +def train_stable_baselines_ppo(): + sbmodel.StableBaselinesPPO(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent() + + def train_from_config(): """ Use the `environment_config_training.json` file to decide on the training parameters. diff --git a/setup.cfg b/setup.cfg index 80f5c57b..4378e185 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ install_requires = pytest>=6.2.4 pytest-randomly>=3.11.0 pytest-xdist>=2.5.0 + stable-baselines3[extra]>=1.5.0 python_requires = >=3.8 [options.entry_points] diff --git a/tests/test_agent_monitoring/test_am_configuration.py b/tests/test_agent_monitoring/test_am_configuration.py index d0870229..a3eb77e1 100644 --- a/tests/test_agent_monitoring/test_am_configuration.py +++ b/tests/test_agent_monitoring/test_am_configuration.py @@ -44,7 +44,7 @@ def test_get_modelfile_path(): mock_exists.return_value = False with pytest.raises(AssertionError) as assertion_message: monitor.configurator._get_modelfile_path('non_existing_modelfile') - assert 'the specified modelfile does not exist' in str(assertion_message.value) + assert 'Modelfiles must end in .dat or .zip: non_existing_modelfile' in str(assertion_message.value) incorrect_update_agents_RL_testcases = [ diff --git a/tests/test_stable_baselines_training.py b/tests/test_stable_baselines_training.py new file mode 100644 index 00000000..c1cdec0e --- /dev/null +++ b/tests/test_stable_baselines_training.py @@ -0,0 +1,34 @@ +import pytest + +import recommerce.market.circular.circular_sim_market as circular_market +import recommerce.rl.stable_baselines.stable_baselines_model as sb_model + + +@pytest.mark.training +@pytest.mark.slow +def test_ddpg_training(): + sb_model.StableBaselinesDDPG(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent(1500, 30) + + +@pytest.mark.training +@pytest.mark.slow +def test_td3_training(): + 
sb_model.StableBaselinesTD3(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent(1500, 30) + + +@pytest.mark.training +@pytest.mark.slow +def test_a2c_training(): + sb_model.StableBaselinesA2C(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent(1500, 30) + + +@pytest.mark.training +@pytest.mark.slow +def test_ppo_training(): + sb_model.StableBaselinesPPO(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent(1500, 30) + + +@pytest.mark.training +@pytest.mark.slow +def test_sac_training(): + sb_model.StableBaselinesSAC(circular_market.CircularEconomyRebuyPriceOneCompetitor(True)).train_agent(1500, 30)