Skip to content

Commit

Permalink
Allow tuples in set_attr and move docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tristandeleu committed Aug 29, 2021
1 parent 5b24bc6 commit 27d04f6
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 38 deletions.
21 changes: 12 additions & 9 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,11 @@ def call_async(self, name, *args, **kwargs):
name : string
Name of the method or property to call.
args, kwargs :
Arguments and keyword arguments to apply to the method call.
*args
Arguments to apply to the method call.
**kwargs
Keywoard arguments to apply to the method call.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
Expand Down Expand Up @@ -362,18 +365,18 @@ def set_attr(self, name, values):
name : string
Name of the property to be set in each individual environment.
values : list of object
Values of the property to be set to. If `values` is a list, then
it corresponds to the values for each individual environment,
otherwise a single value is set for all environments.
values : list, tuple, or object
Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
"""
self._assert_is_running()
if not isinstance(values, list):
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"The values must be a list of length the number "
f"of environments. Got `{len(values)}` values for "
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)

Expand Down
32 changes: 3 additions & 29 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,6 @@ def step_wait(self):
)

def call(self, name, *args, **kwargs):
"""
Parameters
----------
name : string
Name of the method or property to call.
args, kwargs :
Arguments and keyword arguments to apply to the method call.
Returns
-------
results : list
List of the results of the individual calls to the method or
property for each environment.
"""
results = []
for env in self.envs:
function = getattr(env, name)
Expand All @@ -122,23 +107,12 @@ def call(self, name, *args, **kwargs):
return tuple(results)

def set_attr(self, name, values):
"""
Parameters
----------
name : string
Name of the property to be set in each individual environment.
values : list of object
Values of the property to be set to. If `values` is a list, then
it corresponds to the values for each individual environment,
otherwise a single value is set for all environments.
"""
if not isinstance(values, list):
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"The values must be a list of length the number "
f"of environments. Got `{len(values)}` values for "
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)

Expand Down
38 changes: 38 additions & 0 deletions gym/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,51 @@ def call_wait(self, **kwargs):
raise NotImplementedError()

def call(self, name, *args, **kwargs):
"""Call a method, or get a property, from each sub-environment.
Parameters
----------
name : string
Name of the method or property to call.
*args
Arguments to apply to the method call.
**kwargs
Keywoard arguments to apply to the method call.
Returns
-------
results : list
List of the results of the individual calls to the method or
property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()

def get_attr(self, name):
"""Get a property from each sub-environment.
Parameters
----------
name : string
Name of the property to be get from each individual environment.
"""
return self.call(name)

def set_attr(self, name, values):
"""Set a property in each sub-environment.
Parameters
----------
name : string
Name of the property to be set in each individual environment.
values : list, tuple, or object
Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
"""
raise NotImplementedError()

def close_extras(self, **kwargs):
Expand Down

0 comments on commit 27d04f6

Please sign in to comment.