Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/architecture typo fix. #410

Merged
merged 4 commits on Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@


def main(_: Any) -> None:
"""Run example.

Args:
_ (Any): None
"""

# Environment.
environment_factory = functools.partial(
Expand All @@ -61,7 +66,7 @@ def main(_: Any) -> None:
mad4pg.make_default_networks,
vmin=-10,
vmax=50,
archecture_type=ArchitectureType.recurrent,
architecture_type=ArchitectureType.recurrent,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@


def main(_: Any) -> None:
"""Run example.

Args:
_ (Any): None
"""

# Environment.
environment_factory = functools.partial(
Expand All @@ -58,7 +63,7 @@ def main(_: Any) -> None:

# Networks.
network_factory = lp_utils.partial_kwargs(
maddpg.make_default_networks, archecture_type=ArchitectureType.recurrent
maddpg.make_default_networks, architecture_type=ArchitectureType.recurrent
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@


def main(_: Any) -> None:
"""Run example.

Args:
_ (Any): None
"""

# Environment.
environment_factory = functools.partial(
Expand All @@ -60,7 +65,7 @@ def main(_: Any) -> None:

# Networks.
network_factory = lp_utils.partial_kwargs(
maddpg.make_default_networks, archecture_type=ArchitectureType.recurrent
maddpg.make_default_networks, architecture_type=ArchitectureType.recurrent
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@


def main(_: Any) -> None:
"""Run example.

Args:
_ (Any): None
"""

# Environment.
environment_factory = functools.partial(
pettingzoo_utils.make_environment,
Expand All @@ -60,7 +66,7 @@ def main(_: Any) -> None:

# Networks.
network_factory = lp_utils.partial_kwargs(
maddpg.make_default_networks, archecture_type=ArchitectureType.recurrent
maddpg.make_default_networks, architecture_type=ArchitectureType.recurrent
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
Expand Down
8 changes: 7 additions & 1 deletion examples/robocup/recurrent/state_based/run_mad4pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,19 @@


def main(_: Any) -> None:
"""Run example.

Args:
_ (Any): None
"""

# Environment.
environment_factory = lp_utils.partial_kwargs(robocup_utils.make_environment)

# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
archecture_type=ArchitectureType.recurrent,
architecture_type=ArchitectureType.recurrent,
vmin=-5,
vmax=5,
)
Expand Down
6 changes: 3 additions & 3 deletions mava/systems/tf/mad4pg/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def make_default_networks(
policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None,
critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
sigma: float = 0.3,
archecture_type: ArchitectureType = ArchitectureType.feedforward,
architecture_type: ArchitectureType = ArchitectureType.feedforward,
num_atoms: int = 51,
seed: Optional[int] = None,
) -> Mapping[str, types.TensorTransformation]:
Expand All @@ -54,7 +54,7 @@ def make_default_networks(
critic_networks_layer_sizes: size of critic networks.
sigma: hyperparameters used to add Gaussian noise
for simple exploration. Defaults to 0.3.
archecture_type: archecture used
architecture_type: architecture used
for agent networks. Can be feedforward or recurrent.
Defaults to ArchitectureType.feedforward.

Expand All @@ -77,7 +77,7 @@ def make_default_networks(
policy_networks_layer_sizes=policy_networks_layer_sizes,
critic_networks_layer_sizes=critic_networks_layer_sizes,
sigma=sigma,
archecture_type=archecture_type,
architecture_type=architecture_type,
vmin=vmin,
vmax=vmax,
num_atoms=num_atoms,
Expand Down
12 changes: 6 additions & 6 deletions mava/systems/tf/maddpg/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_default_networks(
policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None,
critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
sigma: float = 0.3,
archecture_type: ArchitectureType = ArchitectureType.feedforward,
architecture_type: ArchitectureType = ArchitectureType.feedforward,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
num_atoms: Optional[int] = None,
Expand All @@ -57,7 +57,7 @@ def make_default_networks(
critic_networks_layer_sizes: size of critic networks.
sigma: hyperparameters used to add Gaussian noise
for simple exploration. Defaults to 0.3.
archecture_type: archecture used
architecture_type: architecture used
for agent networks. Can be feedforward or recurrent.
Defaults to ArchitectureType.feedforward.

Expand All @@ -70,15 +70,15 @@ def make_default_networks(
"""
# Set Policy function and layer size
# Default size per arch type.
if archecture_type == ArchitectureType.feedforward:
if architecture_type == ArchitectureType.feedforward:
if not policy_networks_layer_sizes:
policy_networks_layer_sizes = (
256,
256,
256,
)
policy_network_func = snt.Sequential
elif archecture_type == ArchitectureType.recurrent:
elif architecture_type == ArchitectureType.recurrent:
if not policy_networks_layer_sizes:
policy_networks_layer_sizes = (128, 128)
policy_network_func = snt.DeepRNN
Expand Down Expand Up @@ -130,13 +130,13 @@ def make_default_networks(
# An optional network to process observations
observation_network = tf2_utils.to_sonnet_module(tf.identity)
# Create the policy network.
if archecture_type == ArchitectureType.feedforward:
if architecture_type == ArchitectureType.feedforward:
policy_network = [
networks.LayerNormMLP(
policy_networks_layer_sizes[key], activate_final=True, seed=seed
),
]
elif archecture_type == ArchitectureType.recurrent:
elif architecture_type == ArchitectureType.recurrent:
policy_network = [
networks.LayerNormMLP(
policy_networks_layer_sizes[key][:-1],
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/tf/madqn/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def make_default_networks(
agent_net_keys: specifies what network each agent uses.
net_spec_keys: specifies the specs of each network.
value_networks_layer_sizes: size of value networks.
archecture_type: archecture used
architecture_type: architecture used
for agent networks. Can be feedforward or recurrent.
Defaults to ArchitectureType.feedforward.
seed: random seed for network initialization.
Expand Down
2 changes: 1 addition & 1 deletion tests/systems/mad4pg_system_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_recurrent_mad4pg_on_debugging_env(self) -> None:
# networks
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
archecture_type=ArchitectureType.recurrent,
architecture_type=ArchitectureType.recurrent,
policy_networks_layer_sizes=(32, 32),
vmin=-10,
vmax=50,
Expand Down
2 changes: 1 addition & 1 deletion tests/systems/maddpg_system_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_recurrent_maddpg_on_debugging_env(self) -> None:
# networks
network_factory = lp_utils.partial_kwargs(
maddpg.make_default_networks,
archecture_type=ArchitectureType.recurrent,
architecture_type=ArchitectureType.recurrent,
policy_networks_layer_sizes=(32, 32),
)

Expand Down