Skip to content

Commit

Permalink
Use kwargs to pass the action_mask_ph parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ChengYen-Tang committed Oct 8, 2019
1 parent 370e031 commit d69ca21
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
12 changes: 7 additions & 5 deletions stable_baselines/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def proba_distribution_from_flat(self, flat):
"""
return self.probability_distribution_class()(flat)

def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0):
def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0, **kwargs):
"""
returns the probability distribution from latent values
Expand Down Expand Up @@ -167,9 +167,10 @@ def probability_distribution_class(self):
def proba_distribution_from_flat(self, flat, action_mask=None):
    """
    Build a categorical probability distribution from flat logits.

    :param flat: (tf.Tensor) the flat logits of the distribution
    :param action_mask: (tf.Tensor) optional mask of invalid actions
    :return: (CategoricalProbabilityDistribution) the distribution instance
    """
    distribution = CategoricalProbabilityDistribution(flat, action_mask=action_mask)
    return distribution

def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, action_mask=None, init_scale=1.0, init_bias=0.0):
def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0,
                                   action_mask_ph=None, **kwargs):
    """
    Build the categorical probability distribution, flat logits and q-values from latent vectors.

    :param pi_latent_vector: (tf.Tensor) latent vector of the policy network
    :param vf_latent_vector: (tf.Tensor) latent vector of the value network
    :param init_scale: (float) scale of the weight initialization
    :param init_bias: (float) bias of the weight initialization
    :param action_mask_ph: (tf.Tensor) optional placeholder masking invalid actions.
        Made an explicit keyword argument instead of a silent ``kwargs.get('action_mask_ph')``:
        a misspelled keyword now stays visible in ``kwargs`` instead of silently yielding None.
        Backward compatible — callers already pass it by keyword.
    :param kwargs: additional keyword arguments, accepted for interface compatibility and ignored
    :return: (CategoricalProbabilityDistribution, tf.Tensor, tf.Tensor) the distribution,
        the flat logits ('pi' head) and the q-values ('q' head)
    """
    pdparam = linear(pi_latent_vector, 'pi', self.n_cat, init_scale=init_scale, init_bias=init_bias)
    q_values = linear(vf_latent_vector, 'q', self.n_cat, init_scale=init_scale, init_bias=init_bias)
    return self.proba_distribution_from_flat(pdparam, action_mask=action_mask_ph), pdparam, q_values

def param_shape(self):
Expand Down Expand Up @@ -200,9 +201,10 @@ def probability_distribution_class(self):
def proba_distribution_from_flat(self, flat, action_mask=None):
    """
    Build a multi-categorical probability distribution from flat logits.

    :param flat: (tf.Tensor) the flat logits of the distribution
    :param action_mask: (tf.Tensor) optional mask of invalid actions
    :return: (MultiCategoricalProbabilityDistribution) the distribution instance
    """
    distribution = MultiCategoricalProbabilityDistribution(self.n_vec, flat, action_mask=action_mask)
    return distribution

def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, action_mask=None, init_scale=1.0, init_bias=0.0):
def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0,
                                   action_mask_ph=None, **kwargs):
    """
    Build the multi-categorical probability distribution, flat logits and q-values from latent vectors.

    :param pi_latent_vector: (tf.Tensor) latent vector of the policy network
    :param vf_latent_vector: (tf.Tensor) latent vector of the value network
    :param init_scale: (float) scale of the weight initialization
    :param init_bias: (float) bias of the weight initialization
    :param action_mask_ph: (tf.Tensor) optional placeholder masking invalid actions.
        Made an explicit keyword argument instead of a silent ``kwargs.get('action_mask_ph')``:
        a misspelled keyword now stays visible in ``kwargs`` instead of silently yielding None.
        Backward compatible — callers already pass it by keyword.
    :param kwargs: additional keyword arguments, accepted for interface compatibility and ignored
    :return: (MultiCategoricalProbabilityDistribution, tf.Tensor, tf.Tensor) the distribution,
        the flat logits ('pi' head) and the q-values ('q' head)
    """
    # The heads are sized by the total number of discrete choices across all sub-spaces.
    pdparam = linear(pi_latent_vector, 'pi', sum(self.n_vec), init_scale=init_scale, init_bias=init_bias)
    q_values = linear(vf_latent_vector, 'q', sum(self.n_vec), init_scale=init_scale, init_bias=init_bias)
    return self.proba_distribution_from_flat(pdparam, action_mask=action_mask_ph), pdparam, q_values

def param_shape(self):
Expand Down Expand Up @@ -236,7 +238,7 @@ def proba_distribution_from_flat(self, flat):
"""
return self.probability_distribution_class()(flat)

def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0):
def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0, **kwargs):
mean = linear(pi_latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
Expand Down Expand Up @@ -265,7 +267,7 @@ def __init__(self, size):
def probability_distribution_class(self):
    """
    Return the probability distribution class associated with this distribution type.

    :return: (type) BernoulliProbabilityDistribution
    """
    distribution_cls = BernoulliProbabilityDistribution
    return distribution_cls

def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0, **kwargs):
    """
    Build the Bernoulli probability distribution, flat logits and q-values from latent vectors.

    :param pi_latent_vector: (tf.Tensor) latent vector of the policy network
    :param vf_latent_vector: (tf.Tensor) latent vector of the value network
    :param init_scale: (float) scale of the weight initialization
    :param init_bias: (float) bias of the weight initialization
    :param kwargs: extra keyword arguments (e.g. an action mask), accepted for
        interface compatibility with the other distribution types but not used here
    :return: (BernoulliProbabilityDistribution, tf.Tensor, tf.Tensor) the distribution,
        the flat logits ('pi' head) and the q-values ('q' head)
    """
    flat_logits = linear(pi_latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
    q_values = linear(vf_latent_vector, 'q', self.size, init_scale=init_scale, init_bias=init_bias)
    return self.proba_distribution_from_flat(flat_logits), flat_logits, q_values
Expand Down
12 changes: 5 additions & 7 deletions stable_baselines/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
self._action_mask_ph = tf.placeholder(dtype=tf.float32, shape=(n_batch, ac_space.n), name="action_mask_ph")
self._action_mask_ph = tf.placeholder_with_default(tf.zeros_like(self._action_mask_ph), shape=(n_batch, ac_space.n))
self._action_mask_shape = (n_env, ac_space.n)
elif isinstance(ac_space, Box):
self._action_mask_ph = tf.placeholder(dtype=tf.float32, shape=(n_batch, ac_space.shape[0]), name="action_mask_ph")
self._action_mask_ph = tf.placeholder_with_default(tf.zeros_like(self._action_mask_ph), shape=(n_batch, ac_space.shape[0]))
self._action_mask_shape = (n_batch, ac_space.shape[0])
else:
self._action_mask_ph = None

self.sess = sess
self.reuse = reuse
Expand Down Expand Up @@ -463,7 +461,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256
value_fn = linear(rnn_output, 'vf', 1)

self._proba_distribution, self._policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output, self.action_mask_ph)
self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output, action_mask_ph=self.action_mask_ph)

self._value_fn = value_fn
else: # Use the new net_arch parameter
Expand Down Expand Up @@ -530,7 +528,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256
self._value_fn = linear(latent_value, 'vf', 1)
# TODO: why not init_scale = 0.001 here like in the feedforward
self._proba_distribution, self._policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(latent_policy, latent_value, self.action_mask_ph)
self.pdtype.proba_distribution_from_latent(latent_policy, latent_value, action_mask_ph=self.action_mask_ph)
self._setup_init()

def step(self, obs, state=None, mask=None, deterministic=False, action_mask=None):
Expand Down Expand Up @@ -599,7 +597,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
self._value_fn = linear(vf_latent, 'vf', 1)

self._proba_distribution, self._policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, self.action_mask_ph, init_scale=0.01)
self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01, action_mask_ph=self.action_mask_ph)

self._setup_init()

Expand Down

0 comments on commit d69ca21

Please sign in to comment.