
[MRG] Allow building NNs in A2C from a string equal to the function's name #180

Merged (14 commits) on May 6, 2022
17 changes: 14 additions & 3 deletions rlberry/agents/torch/a2c/a2c.py
@@ -10,6 +10,7 @@
from rlberry.agents.torch.utils.models import default_value_net_fn
from rlberry.utils.torch import choose_device
from rlberry.wrappers.uncertainty_estimator_wrapper import UncertaintyEstimatorWrapper
+from rlberry.utils.factory import load

logger = logging.getLogger(__name__)

@@ -108,9 +109,19 @@ def __init__(
self.state_dim = self.env.observation_space.shape[0]
self.action_dim = self.env.action_space.n

-#
-self.policy_net_fn = policy_net_fn or default_policy_net_fn
-self.value_net_fn = value_net_fn or default_value_net_fn
+if isinstance(policy_net_fn, str):
+    self.policy_net_fn = load(policy_net_fn)
+elif policy_net_fn is None:
+    self.policy_net_fn = default_policy_net_fn
+else:
+    self.policy_net_fn = policy_net_fn
+
+if isinstance(value_net_fn, str):
+    self.value_net_fn = load(value_net_fn)
+elif value_net_fn is None:
+    self.value_net_fn = default_value_net_fn
+else:
+    self.value_net_fn = value_net_fn

self.optimizer_kwargs = {"optimizer_type": optimizer_type, "lr": learning_rate}

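The new string option relies on rlberry.utils.factory.load to turn a dotted path into the corresponding callable. The sketch below is an assumption about what such a loader does (split the path into a module part and an attribute name, import the module, return the attribute); it is not the actual rlberry source.

import importlib


def load(path):
    # Split "pkg.module.attr" into a module path and an attribute name,
    # import the module, and return the attribute (e.g. a network constructor).
    module_path, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Example: the factory used in the test below would be resolved as
# load("rlberry.agents.torch.utils.training.model_factory_from_env")
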
38 changes: 38 additions & 0 deletions rlberry/agents/torch/tests/test_a2c.py
@@ -64,3 +64,41 @@ def test_a2c():

output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
a2crlberry_stats.clear_output_dir()

# also test non-default networks built from a dotted-path string
env = "CartPole-v0"
mdp = make(env)
env_ctor = Wrapper
env_kwargs = dict(env=mdp)

a2crlberry_stats = AgentManager(
A2CAgent,
(env_ctor, env_kwargs),
fit_budget=int(2),
eval_kwargs=dict(eval_horizon=2),
init_kwargs=dict(
horizon=2,
policy_net_fn="rlberry.agents.torch.utils.training.model_factory_from_env",
policy_net_kwargs=dict(
type="MultiLayerPerceptron",
layer_sizes=(256,),
reshape=False,
is_policy=True,
),
value_net_fn="rlberry.agents.torch.utils.training.model_factory_from_env",
value_net_kwargs=dict(
type="MultiLayerPerceptron",
layer_sizes=[
512,
],
reshape=False,
out_size=1,
),
),
n_fit=1,
agent_name="A2C_rlberry_" + env,
)
a2crlberry_stats.fit()

output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
a2crlberry_stats.clear_output_dir()
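
For completeness, here is how the new option looks from a user's perspective outside AgentManager. This is a usage sketch: the imports and the A2CAgent call mirror the test above, and gym's make is assumed for constructing the environment.

from gym import make
from rlberry.agents.torch import A2CAgent

env = make("CartPole-v0")
agent = A2CAgent(
    env,
    # the constructor is given as a dotted-path string instead of a callable
    policy_net_fn="rlberry.agents.torch.utils.training.model_factory_from_env",
    policy_net_kwargs=dict(
        type="MultiLayerPerceptron",
        layer_sizes=(256,),
        reshape=False,
        is_policy=True,
    ),
)
agent.fit(budget=100)  # small training budget; the string was resolved internally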