test: extend tests and docs (#12)
* Update docstrings

* Update tests

* Update environment readme
zombie-einstein authored Dec 11, 2024
1 parent 5c509c7 commit 8acf242
Showing 14 changed files with 162 additions and 78 deletions.
36 changes: 22 additions & 14 deletions docs/environments/search_and_rescue.md
@@ -2,29 +2,34 @@

[//]: # (TODO: Add animated plot)

Multi-agent environment, modelling a group of agents searching the environment
Multi-agent environment, modelling a group of agents searching a 2d environment
for multiple targets. Agents are individually rewarded for finding a target
that has not previously been detected.

Each agent visualises a local region around it, creating a simple segmented view
of locations of other agents in the vicinity. The environment is updated in the
following sequence:
Each agent visualises a local region around itself, represented as a simple segmented
view of locations of other agents and targets in the vicinity. The environment
is updated in the following sequence:

- The velocities of searching agents are updated, and consequently their positions.
- The positions of targets are updated.
- Agents are rewarded for being within a fixed range of targets, and the target
being within its view cone.
- Targets within detection range and an agent's view cone are marked as found.
- Agents are rewarded for locating previously unfound targets.
- Local views of the environment are generated for each search agent.

The agents are allotted a fixed number of steps to locate the targets. The search
space is a uniform space with unit dimensions, and wrapped at the boundaries.
space is a uniform square space, wrapped at the boundaries.

Many aspects of the environment can be customised (a configuration sketch follows this list):

- Agent observations can include targets as well as other searcher agents.
- Rewards can be shared between agents, or assigned individually to each agent.
- Target dynamics can be customised to model various search scenarios.
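
As an illustration of these options, here is a hedged configuration and rollout sketch assembled from the constructor arguments, fixtures, and defaults visible elsewhere in this commit. The parameter values simply mirror the test fixtures and are illustrative, not recommended defaults; the two-searcher action shape assumes the default generator.

```python
import jax
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue import SearchAndRescue, observations
from jumanji.environments.swarms.search_and_rescue.reward import SharedRewardFn

# Illustrative parameter values, mirroring the test fixtures in this commit.
env = SearchAndRescue(
    target_contact_range=0.05,
    searcher_max_rotate=0.2,
    searcher_max_accelerate=0.01,
    searcher_min_speed=0.01,
    searcher_max_speed=0.05,
    searcher_view_angle=0.5,
    time_limit=10,
    observation=observations.AgentAndAllTargetObservationFn(
        num_vision=32,
        vision_range=0.1,
        view_angle=0.5,
        agent_radius=0.01,
        env_size=1.0,
    ),
    reward_fn=SharedRewardFn(),
)

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

# The default generator creates 2 searcher agents; actions are per-agent
# (rotation, acceleration) pairs.
actions = jnp.zeros((2, 2))
state, timestep = env.step(state, actions)
```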

## Observations

- `searcher_views`: jax array (float) of shape `(num_searchers, num_vision)`. Each agent
generates an independent observation, an array of values representing the distance
along a ray from the agent to the nearest neighbour, with each cell representing a
- `searcher_views`: jax array (float) of shape `(num_searchers, channels, num_vision)`.
Each agent generates an independent observation, an array of values representing the distance
along a ray from the agent to the nearest neighbour or target, with each cell representing a
ray angle (with `num_vision` rays evenly distributed over the agent's field of vision).
For example, if an agent sees another agent straight ahead and `num_vision = 5` then
the observation array could be
@@ -34,11 +39,12 @@ space is a uniform space with unit dimensions, and wrapped at the boundaries.
```

where `-1.0` indicates there is no agent along that ray, and `0.5` is the normalised
distance to the other agent.
distance to the other agent. Channels in the segmented view are used to differentiate
between different agents/targets and can be customised. By default, the view has three
channels representing other agents, found targets, and unfound targets (see the sketch after this list).
- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
remaining to be detected (i.e. 1.0 when no targets have been found).
- `time_remaining`: float in the range `[0, 1]`. The normalised number of steps remaining
to locate the targets (i.e. 0.0 at the end of the episode).
- `step`: int in the range `[0, time_limit]`. The current simulation step.
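
To make the view format concrete, here is a hedged sketch of what a single searcher's view might look like, restating the example above in code; the values and the one-channel slice shown are illustrative.

```python
import jax.numpy as jnp

# One channel of one agent's view with num_vision = 5: the centre ray hits
# another agent at half the (normalised) vision range, all other rays are empty.
single_channel_view = jnp.array([-1.0, -1.0, 0.5, -1.0, -1.0])

# A full observation stacks one such row per channel and per searcher,
# giving shape (num_searchers, channels, num_vision).
```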

## Actions

@@ -64,4 +70,6 @@ Once applied, agent speeds are clipped to velocities within a fixed range of spe
## Rewards

Jax array (float) of shape `(num_searchers,)`. Rewards are generated for each agent individually.
Agents are rewarded 1.0 for locating a target that has not already been detected.
Agents are rewarded +1 for locating a target that has not already been detected. Multiple agents
can detect a target within the same step; in this case the reward can either be shared between
the locating agents, or each agent can receive the full reward.
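
The following is a minimal sketch of the shared-reward behaviour described here, not the library's `SharedRewardFn` implementation; it assumes a boolean detection array of shape `(num_searchers, num_targets)` like the one produced inside the environment step.

```python
import jax.numpy as jnp


def shared_rewards_sketch(targets_found: jnp.ndarray) -> jnp.ndarray:
    """Split +1 per newly found target equally between the agents that found it."""
    # Number of searchers detecting each target this step, shape (num_targets,).
    detections_per_target = jnp.sum(targets_found, axis=0)
    # Each detecting agent's share of the +1 reward (0 where nothing was found).
    share = jnp.where(
        detections_per_target > 0, 1.0 / jnp.maximum(detections_per_target, 1), 0.0
    )
    # Per-agent reward, shape (num_searchers,).
    return jnp.sum(targets_found * share[jnp.newaxis, :], axis=1)
```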
16 changes: 8 additions & 8 deletions jumanji/environments/swarms/common/updates.py
@@ -24,15 +24,15 @@

@esquilax.transforms.amap
def update_velocity(
_: chex.PRNGKey,
_key: chex.PRNGKey,
params: types.AgentParams,
x: Tuple[chex.Array, types.AgentState],
) -> Tuple[float, float]:
) -> Tuple[chex.Numeric, chex.Numeric]:
"""
Get the updated agent heading and speeds from actions
Args:
_: Dummy JAX random key.
_key: Dummy JAX random key.
params: Agent parameters.
x: Agent rotation and acceleration actions.
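
A hedged sketch of the update this function describes follows. It is not the esquilax-decorated implementation; the `AgentParams` field names (`max_rotate`, `max_accelerate`, `min_speed`, `max_speed`) and the `[-1, 1]` action scaling are assumptions based on the constructor arguments described elsewhere in this diff.

```python
import jax.numpy as jnp


def update_velocity_sketch(heading, speed, rotation, acceleration, params):
    # Assumed: a rotation action in [-1, 1] turns the heading by at most
    # max_rotate * pi radians per step.
    new_heading = (heading + rotation * params.max_rotate * jnp.pi) % (2.0 * jnp.pi)
    # Assumed: an acceleration action in [-1, 1] changes speed by at most
    # max_accelerate, with the result clipped to [min_speed, max_speed].
    new_speed = jnp.clip(
        speed + acceleration * params.max_accelerate,
        params.min_speed,
        params.max_speed,
    )
    return new_heading, new_speed
```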
@@ -105,10 +105,10 @@ def update_state(

def view_reduction(view_a: chex.Array, view_b: chex.Array) -> chex.Array:
"""
Binary view reduction function.
Binary view reduction function for use in Esquilax spatial transformation.
Handles reduction where a value of -1.0 indicates no
agent in view-range. Returns the min value of they
agent in view-range. Returns the min value if they
are both positive, but the max value if one or both of
the values is -1.0.
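
As a hedged restatement of that behaviour (not necessarily the implementation in this file), the reduction could be written as:

```python
import jax.numpy as jnp


def view_reduction_sketch(view_a: jnp.ndarray, view_b: jnp.ndarray) -> jnp.ndarray:
    # -1.0 means "nothing seen along this ray": take the min of two positive
    # distances, but fall back to the max (i.e. the non -1.0 entry) when
    # either value is -1.0.
    either_empty = jnp.logical_or(view_a < 0.0, view_b < 0.0)
    return jnp.where(either_empty, jnp.maximum(view_a, view_b), jnp.minimum(view_a, view_b))
```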
@@ -137,7 +137,7 @@ def angular_width(
env_size: float,
) -> Tuple[chex.Array, chex.Array, chex.Array]:
"""
Get the normalised distance, and left and right angles to another agent.
Get the normalised distance, and angles to edges of another agent.
Args:
viewing_pos: Co-ordinates of the viewing agent
@@ -175,10 +175,10 @@ def view(
Simple view model where the agent's view angle is subdivided
into an array of values representing the distance from
the agent along a rays from the agent, with rays evenly distributed.
the agent along rays from the agent, with rays evenly distributed
across the agent's field of view. The limit of vision is set at 1.0.
The default value if no object is within range is -1.0.
Currently, this model assumes the viewed objects are circular.
Currently, this model assumes the viewed agent/objects are circular.
Args:
_key: Dummy JAX random key, required by esquilax API, but
2 changes: 1 addition & 1 deletion jumanji/environments/swarms/common/viewer.py
@@ -48,7 +48,7 @@ def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver:
def format_plot(
fig: Figure, ax: Axes, env_dims: Tuple[float, float], border: float = 0.01
) -> Tuple[Figure, Axes]:
"""Format a flock/swarm plot, remove ticks and bound to the unit interval
"""Format a flock/swarm plot, remove ticks and bound to the environment dimensions.
Args:
fig: Matplotlib figure.
46 changes: 45 additions & 1 deletion jumanji/environments/swarms/search_and_rescue/conftest.py
@@ -16,7 +16,7 @@
import jax.random
import pytest

from jumanji.environments.swarms.search_and_rescue import SearchAndRescue
from jumanji.environments.swarms.search_and_rescue import SearchAndRescue, observations


@pytest.fixture
@@ -32,6 +32,50 @@ def env() -> SearchAndRescue:
)


class FixtureRequest:
"""Just used for typing"""

param: observations.ObservationFn


@pytest.fixture(
params=[
observations.AgentObservationFn(
num_vision=32,
vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
),
observations.AgentAndTargetObservationFn(
num_vision=32,
vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
),
observations.AgentAndAllTargetObservationFn(
num_vision=32,
vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
),
]
)
def multi_obs_env(request: FixtureRequest) -> SearchAndRescue:
return SearchAndRescue(
target_contact_range=0.05,
searcher_max_rotate=0.2,
searcher_max_accelerate=0.01,
searcher_min_speed=0.01,
searcher_max_speed=0.05,
searcher_view_angle=0.5,
time_limit=10,
observation=request.param,
)


@pytest.fixture
def key() -> chex.PRNGKey:
return jax.random.PRNGKey(101)
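
A hypothetical test using the parametrised fixture above might look like the following; the test name is invented, and the assertion assumes the `(num_searchers, channels, num_vision)` shape of `searcher_views` described in the environment docs.

```python
import chex

from jumanji.environments.swarms.search_and_rescue import SearchAndRescue


def test_reset_with_each_observation_fn(
    multi_obs_env: SearchAndRescue, key: chex.PRNGKey
) -> None:
    # Runs once per observation function in the fixture's params list.
    state, timestep = multi_obs_env.reset(key)
    # Assumed shape: (num_searchers, channels, num_vision).
    assert timestep.observation.searcher_views.ndim == 3
```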
20 changes: 12 additions & 8 deletions jumanji/environments/swarms/search_and_rescue/dynamics.py
@@ -23,11 +23,15 @@
class TargetDynamics(abc.ABC):
@abc.abstractmethod
def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
"""Interface for target position update function.
"""Interface for target state update function.
NOTE: Target positions should be bound to environment
area (generally wrapped around at the boundaries).
Args:
key: random key.
key: Random key.
targets: Current target states.
env_size: Environment size.
Returns:
Updated target states.
@@ -37,23 +41,23 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) ->
class RandomWalk(TargetDynamics):
def __init__(self, step_size: float):
"""
Random walk target dynamics.
Simple random walk target dynamics.
Target positions are updated with random
steps, sampled uniformly from the range
[-step-size, step-size].
Target positions are updated with random steps, sampled uniformly
from the range `[-step-size, step-size]`.
Args:
step_size: Maximum random step-size
step_size: Maximum random step-size in each axis.
"""
self.step_size = step_size

def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
"""Update target positions.
"""Update target state.
Args:
key: random key.
targets: Current target states.
env_size: Environment size.
Returns:
Updated target states.
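
To illustrate the `TargetDynamics` interface shown above, here is a hedged sketch of a custom implementation; the `TargetState` import path and its `replace` method are assumptions (it is treated here as a chex/flax-style dataclass with a `pos` field), not details confirmed by this diff.

```python
import chex
import jax

from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics
# Assumed import path for TargetState:
from jumanji.environments.swarms.search_and_rescue.types import TargetState


class DriftingWalk(TargetDynamics):
    """Illustrative dynamics: constant drift plus a uniform random step."""

    def __init__(self, drift: chex.Array, step_size: float):
        self.drift = drift
        self.step_size = step_size

    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
        steps = jax.random.uniform(
            key, targets.pos.shape, minval=-self.step_size, maxval=self.step_size
        )
        # Wrap positions at the boundaries, as the interface requires.
        pos = (targets.pos + self.drift + steps) % env_size
        return targets.replace(pos=pos)  # assumes a dataclass-style replace method
```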
44 changes: 24 additions & 20 deletions jumanji/environments/swarms/search_and_rescue/env.py
@@ -29,7 +29,7 @@
from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics
from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator
from jumanji.environments.swarms.search_and_rescue.observations import (
AgentAndTargetObservationFn,
AgentAndAllTargetObservationFn,
ObservationFn,
)
from jumanji.environments.swarms.search_and_rescue.reward import RewardFn, SharedRewardFn
@@ -46,8 +46,9 @@ class SearchAndRescue(Environment):
for a set of targets in a 2d environment. Agents are rewarded
(individually) for coming within a fixed range of a target that has
not already been detected. Agents visualise their local environment
(i.e. the location of other agents) via a simple segmented view model.
The environment consists of a uniform space with wrapped boundaries.
(i.e. the location of other agents and targets) via a simple segmented
view model. The environment area is a uniform square space with wrapped
boundaries.
An episode will terminate if all targets have been located by the team of
searching agents.
@@ -58,12 +59,12 @@ class SearchAndRescue(Environment):
channels can be used to differentiate between agents and targets.
Each entry in the view indicates the distance to another agent/target
along a ray from the agent, and is -1.0 if nothing is in range along the ray.
The view model can be customised using an `ObservationFn` implementation.
The view model can be customised using an `ObservationFn` implementation, e.g.
the view can include all agents and targets, or just other agents.
targets_remaining: (float) Number of targets remaining to be found from
the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates
all the targets are still to be found).
time_remaining: (float) Steps remaining to find agents, scaled to the
range [0,1] (i.e. the value is 0 when time runs out).
step: (int) current simulation step.
- action: jax array (float) of shape (num_searchers, 2)
Array of individual agent actions. Each agent's actions rotate and
@@ -80,13 +81,14 @@ class SearchAndRescue(Environment):
- state: `State`
- searchers: `AgentState`
- pos: jax array (float) of shape (num_searchers, 2) in the range [0, 1].
- pos: jax array (float) of shape (num_searchers, 2) in the range [0, env_size].
- heading: jax array (float) of shape (num_searchers,) in
the range [0, 2pi].
- speed: jax array (float) of shape (num_searchers,) in the
range [min_speed, max_speed].
- targets: `TargetState`
- pos: jax array (float) of shape (num_targets, 2) in the range [0, 1].
- pos: jax array (float) of shape (num_targets, 2) in the range [0, env_size].
- vel: jax array (float) of shape (num_targets, 2).
- found: jax array (bool) of shape (num_targets,) flag indicating if
target has been located by an agent.
- key: jax array (uint32) of shape (2,)
@@ -127,8 +129,8 @@ def __init__(
searcher_max_rotate: Maximum rotation searcher agents can
turn within a step. Should be a value from [0,1]
representing a fraction of pi radians.
searcher_max_accelerate: Maximum acceleration/deceleration
a searcher agent can apply within a step.
searcher_max_accelerate: Magnitude of the maximum
acceleration/deceleration a searcher agent can apply within a step.
searcher_min_speed: Minimum speed a searcher agent can move at.
searcher_max_speed: Maximum speed a searcher agent can move at.
searcher_view_angle: Searcher agent local view angle. Should be
@@ -145,8 +147,11 @@ def __init__(
with 20 targets and 10 searchers.
reward_fn: Reward aggregation function. Defaults to `SharedRewardFn` where
agents share rewards if they locate a target simultaneously.
observation: Agent observation view generation function. Defaults to
`AgentAndAllTargetObservationFn` where all targets (found and unfound)
and other agents are included in the generated view.
"""
# self.searcher_vision_range = searcher_vision_range

self.target_contact_range = target_contact_range

self.searcher_params = AgentParams(
@@ -161,7 +166,7 @@ def __init__(
self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2)
self._viewer = viewer or SearchAndRescueViewer()
self._reward_fn = reward_fn or SharedRewardFn()
self._observation = observation or AgentAndTargetObservationFn(
self._observation = observation or AgentAndAllTargetObservationFn(
num_vision=64,
vision_range=0.1,
view_angle=searcher_view_angle,
@@ -190,7 +195,7 @@ def __repr__(self) -> str:
)

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""Initialise searcher positions and velocities, and target positions.
"""Initialise searcher and target initial states.
Args:
key: Random key used to reset the environment.
@@ -217,7 +222,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
state: Updated searcher and target positions and velocities.
timestep: Transition timestep with individual agent local observations.
"""
# Note: only one new key is needed for the targets, as all other
# Note: only one new key is needed for the target updates, as all other
# keys are just dummy values required by Esquilax
key, target_key = jax.random.split(state.key, num=2)
searchers = update_state(
@@ -228,22 +233,21 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser

# Searchers return an array of flags of any targets they are in range of,
# and that have not already been located, result shape here is (n-searcher, n-targets)
n_targets = targets.pos.shape[0]
targets_found = spatial(
utils.searcher_detect_targets,
reduction=jnp.logical_or,
default=jnp.zeros((n_targets,), dtype=bool),
default=jnp.zeros((self.generator.num_targets,), dtype=bool),
i_range=self.target_contact_range,
dims=self.generator.env_size,
)(
key,
self.searcher_params.view_angle,
searchers,
(jnp.arange(n_targets), targets),
(jnp.arange(self.generator.num_targets), targets),
pos=searchers.pos,
pos_b=targets.pos,
env_size=self.generator.env_size,
n_targets=n_targets,
n_targets=self.generator.num_targets,
)

rewards = self._reward_fn(targets_found, state.step, self.time_limit)
Expand Down Expand Up @@ -352,14 +356,14 @@ def render(self, state: State) -> None:
"""Render a frame of the environment for a given state using matplotlib.
Args:
state: State object containing the current dynamics of the environment.
state: State object containing the current state of the environment.
"""
self._viewer.render(state)

def animate(
self,
states: Sequence[State],
interval: int = 200,
interval: int = 100,
save_path: Optional[str] = None,
) -> FuncAnimation:
"""Create an animation from a sequence of environment states.