Skip to content

Commit

Permalink
Merge branch 'main' into feat/add_sliding_tile_puzzle_environment
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Jan 10, 2024
2 parents 1f9f404 + f6c9ef3 commit 3fa7677
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ instance/

# Sphinx documentation
docs/_build/
# MkDocs documentation
docs_public/

# PyBuilder
.pybuilder/
Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
<p align="center">
<a href="docs/img/jumanji_logo.png">
<img src="docs/img/jumanji_logo.png" alt="Jumanji logo" width="50%"/>
</a>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="docs/img/jumanji_logo_dm.png">
<source media="(prefers-color-scheme: light)" srcset="docs/img/jumanji_logo.png">
<img alt="Jumanji Logo" src="docs/img/jumanji_logo.png" width="50%">
</picture>
</p>

[![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/)
Expand Down
8 changes: 8 additions & 0 deletions docs/api/environments/tetris.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
::: jumanji.environments.packing.tetris.env.Tetris
selection:
members:
- __init__
- reset
- step
- observation_spec
- action_spec
Binary file added docs/img/jumanji_logo_dm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions jumanji/environments/routing/robot_warehouse/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""File adapted from [Rware](https://github.com/semitable/robotic-warehouse). More specifically,
the rendering code is copied from the original Rware environment and is not covered by the
copyright notice above."""

# flake8: noqa: CCR001

from typing import Callable, Optional, Sequence, Tuple
Expand Down
4 changes: 2 additions & 2 deletions jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ def test_shape_element_type_error(self) -> None:

def test_dtype_type_error(self) -> None:
with pytest.raises(TypeError):
specs.Array((1, 2, 3), "32")
specs.Array((1, 2, 3), "32") # type: ignore

def test_scalar_shape(self) -> None:
specs.Array((), jnp.int32)

def test_string_dtype_error(self) -> None:
specs.Array((1, 2, 3), "int32")
specs.Array((1, 2, 3), "int32") # type: ignore

def test_dtype(self) -> None:
specs.Array((1, 2, 3), int)
Expand Down
11 changes: 8 additions & 3 deletions jumanji/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import field
from typing import TYPE_CHECKING, Dict, Generic, Optional, Sequence, TypeVar, Union

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
Expand Down Expand Up @@ -71,14 +72,14 @@ class TimeStep(Generic[Observation]):
extras: environment metric(s) or information returned by the environment but
not observed by the agent (hence not in the observation). For example, it
could be whether an invalid action was taken. In most environments, extras
is None.
is an empty dictionary.
"""

step_type: StepType
reward: Array
discount: Array
observation: Observation
extras: Optional[Dict] = None
extras: Dict = field(default_factory=dict)

def first(self) -> Array:
return self.step_type == StepType.FIRST
Expand Down Expand Up @@ -110,6 +111,7 @@ def restart(
Returns:
TimeStep identified as a reset.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.FIRST,
reward=jnp.zeros(shape, dtype=float),
Expand Down Expand Up @@ -144,6 +146,7 @@ def transition(
TimeStep identified as a transition.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
extras = extras or {}
return TimeStep(
step_type=StepType.MID,
reward=reward,
Expand Down Expand Up @@ -175,6 +178,7 @@ def termination(
Returns:
TimeStep identified as the termination of an episode.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
reward=reward,
Expand Down Expand Up @@ -208,6 +212,7 @@ def truncation(
TimeStep identified as the truncation of an episode.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
reward=reward,
Expand All @@ -228,4 +233,4 @@ def get_valid_dtype(dtype: Union[jnp.dtype, type]) -> jnp.dtype:
Returns:
dtype converted to the correct type precision.
"""
return jnp.empty((), dtype).dtype
return jnp.empty((), dtype).dtype # type: ignore
8 changes: 8 additions & 0 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ def action_spec(self) -> specs.Spec:
"""Returns the action spec."""
return self._env.action_spec()

def reward_spec(self) -> specs.Array:
"""Returns the reward spec."""
return self._env.reward_spec()

def discount_spec(self) -> specs.BoundedArray:
"""Returns the discount spec."""
return self._env.discount_spec()

def render(self, state: State) -> Any:
"""Compute render frames during initialisation of the environment.
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ chex>=0.1.3
dm-env>=1.5
gym>=0.22.0
jax>=0.2.26
matplotlib>=3.3.4
matplotlib~=3.7.4
numpy>=1.19.5
Pillow>=9.0.0
typing-extensions>=4.0.0

0 comments on commit 3fa7677

Please sign in to comment.