diff --git a/rlzoo/algorithms/sac/default.py b/rlzoo/algorithms/sac/default.py index 7ee309e..913db21 100644 --- a/rlzoo/algorithms/sac/default.py +++ b/rlzoo/algorithms/sac/default.py @@ -55,6 +55,7 @@ def classic_control(env, default_seed=True): with tf.name_scope('Policy'): policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space, hidden_dim_list=num_hidden_layer * [hidden_dim], + output_activation=None, state_conditioned=True) net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net] alg_params['net_list'] = net_list @@ -110,6 +111,7 @@ def box2d(env, default_seed=True): with tf.name_scope('Policy'): policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space, hidden_dim_list=num_hidden_layer * [hidden_dim], + output_activation=None, state_conditioned=True) net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net] alg_params['net_list'] = net_list @@ -165,6 +167,7 @@ def mujoco(env, default_seed=True): with tf.name_scope('Policy'): policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space, hidden_dim_list=num_hidden_layer * [hidden_dim], + output_activation=None, state_conditioned=True) net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net] alg_params['net_list'] = net_list @@ -220,6 +223,7 @@ def robotics(env, default_seed=True): with tf.name_scope('Policy'): policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space, hidden_dim_list=num_hidden_layer * [hidden_dim], + output_activation=None, state_conditioned=True) net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net] alg_params['net_list'] = net_list @@ -275,6 +279,7 @@ def dm_control(env, default_seed=True): with tf.name_scope('Policy'): policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space, hidden_dim_list=num_hidden_layer * [hidden_dim], + output_activation=None, state_conditioned=True) net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net] alg_params['net_list'] = net_list @@ -330,6 +335,7 @@ def rlbench(env, default_seed=True): with tf.name_scope('Policy'): policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space, hidden_dim_list=num_hidden_layer * [hidden_dim], + output_activation=None, state_conditioned=True) net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net] alg_params['net_list'] = net_list