Optimize garage.torch (#2170)
Remove `torch.tensor` from tensor conversions

* Method `torch.tensor` always copies the data

* This makes it slower than other tensor conversion
  methods

Use `from_numpy` for NumPy-to-tensor conversion

* Method `from_numpy` has lower overhead than `as_tensor`

* Supply the dtype during tensor conversion; this is
  faster than calling `tensor.float()` afterwards
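
A minimal sketch of the conversion pattern these changes aim at (a standalone example, not code from this diff):

```python
import numpy as np
import torch

arr = np.zeros((1024, 1024), dtype=np.float32)

# torch.tensor always copies the underlying buffer.
copied = torch.tensor(arr)

# torch.from_numpy shares memory with the ndarray, so no copy is made.
shared = torch.from_numpy(arr)
assert shared.data_ptr() == arr.__array_interface__['data'][0]

# Supplying the dtype up front avoids the second pass over the data
# that a later .float() call would require.
as_float = torch.as_tensor([1, 2, 3], dtype=torch.float32)
```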

Allow enabling of cuDNN benchmarks for CNN

* Enabled benchmarks let cuDNN select the fastest
  convolution algorithms (see the sketch after this list)
* Remove suppressed argument in CNNModule
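
A minimal sketch of what enabling the cuDNN benchmark flag amounts to; the `cudnn_benchmarks` name here is illustrative, not the CNNModule argument itself:

```python
import torch

cudnn_benchmarks = True  # hypothetical switch exposed to the CNN model

if cudnn_benchmarks and torch.backends.cudnn.is_available():
    # cuDNN times the available convolution algorithms on the first
    # forward pass and caches the fastest one; this pays off when input
    # shapes are fixed, since changing shapes re-triggers the benchmark.
    torch.backends.cudnn.benchmark = True
```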

Use a common function to convert data to tensors

* Modify `np_to_torch` to only perform
  type conversion when needed
* Use it for all NumPy-to-tensor conversions
* Add a common list-to-tensor function

This optimizes conversion of data to tensors
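
A small usage sketch of the two helpers as they appear in the `_functions.py` diff below (assuming the default CPU device):

```python
import numpy as np

from garage.torch import np_to_torch
from garage.torch._functions import list_to_tensor

# Already float32: np_to_torch wraps the array without copying or casting.
obs = np_to_torch(np.ones((8, 4), dtype=np.float32))

# Plain Python sequences (e.g. a list of rewards) go through list_to_tensor.
rewards = list_to_tensor([0.0, 1.0, 0.5])
```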

Add a custom zero_grad optimizer function

* This allows setting gradients to None
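
A small usage sketch of the `zero_optim_grads` helper added in `_functions.py` below:

```python
import torch

from garage.torch._functions import zero_optim_grads

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters())

loss = model(torch.randn(3, 4)).sum()
loss.backward()

# Drop the gradient tensors instead of zeroing them in place; the next
# backward() recreates them, skipping a memset per parameter.
zero_optim_grads(optim, set_to_none=True)
assert all(p.grad is None for p in model.parameters())
```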

Resolve conversion of observations to tensors

Set gradients to None in single-task RL algorithms

MAML: set grads to None

* Update MAML's gradient clearing to set
  gradients to None

* This is after checking that the optimization
  does not deteriorate its performance

Fix errors in calling zero_grad

* Use the non-deprecated `param.add_(Tensor, alpha)` signature
  to clear the deprecation warning (see the sketch after this list)

* Set grads to None for PEARL

* Fix differentiable_sgd runtime error

* Add tensor-conversion optimizations for VPG
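
The signature change behind the deprecation fix, shown on standalone tensors (the actual update site is in files not expanded here, so the variables are illustrative):

```python
import torch

param = torch.zeros(3)
step = torch.ones(3)
lr = 0.1

# Deprecated ordering, which triggers a UserWarning in recent PyTorch:
# param.add_(lr, step)

# Non-deprecated form used instead:
param.add_(step, alpha=lr)
```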

Fix pre-commit errors

Fix file permissions from 100644 to 100755
mugoh authored Apr 13, 2021
1 parent cb88ffa commit 6700814
Showing 17 changed files with 163 additions and 94 deletions.
6 changes: 2 additions & 4 deletions benchmarks/src/garage_benchmarks/run_benchmarks.py
@@ -14,10 +14,8 @@

import click

from garage_benchmarks import (benchmark_algos,
benchmark_auto,
benchmark_baselines,
benchmark_policies,
from garage_benchmarks import (benchmark_algos, benchmark_auto,
benchmark_baselines, benchmark_policies,
benchmark_q_functions)

# yapf: enable
26 changes: 13 additions & 13 deletions src/garage/torch/__init__.py
@@ -1,34 +1,34 @@
"""PyTorch-backed modules and algorithms."""
# yapf: disable
from garage.torch._functions import (as_torch, as_torch_dict,
compute_advantages, expand_var,
filter_valids, flatten_batch,
from garage.torch._functions import (as_torch_dict, compute_advantages,
expand_var, filter_valids, flatten_batch,
flatten_to_single_vector, global_device,
NonLinearity, output_height_2d,
output_width_2d, pad_to_last, prefer_gpu,
NonLinearity, np_to_torch,
output_height_2d, output_width_2d,
pad_to_last, prefer_gpu,
product_of_gaussians, set_gpu_mode,
soft_update_model, state_dict_to,
torch_to_np, update_module_params)

# yapf: enable
__all__ = [
'compute_advantages',
'NonLinearity',
'as_torch_dict',
'compute_advantages',
'expand_var',
'filter_valids',
'flatten_batch',
'flatten_to_single_vector',
'global_device',
'as_torch',
'np_to_torch',
'output_height_2d',
'output_width_2d',
'pad_to_last',
'prefer_gpu',
'product_of_gaussians',
'set_gpu_mode',
'soft_update_model',
'state_dict_to',
'torch_to_np',
'update_module_params',
'NonLinearity',
'flatten_to_single_vector',
'output_width_2d',
'output_height_2d',
'expand_var',
'state_dict_to',
]
44 changes: 41 additions & 3 deletions src/garage/torch/_functions.py
@@ -22,6 +22,27 @@
_GPU_ID = 0


def zero_optim_grads(optim, set_to_none=True):
    """Sets the gradient of all optimized tensors to None.

    This is an optimization alternative to calling `optimizer.zero_grad()`.

    Args:
        optim (torch.optim.Optimizer): The optimizer instance
            to zero parameter gradients.
        set_to_none (bool): Set gradients to None
            instead of calling `zero_grad()`, which
            sets them to 0.

    """
    if not set_to_none:
        optim.zero_grad()
        return

    for group in optim.param_groups:
        for param in group['params']:
            param.grad = None


def compute_advantages(discount, gae_lambda, max_episode_length, baselines,
rewards):
"""Calculate advantages.
@@ -132,7 +153,7 @@ def filter_valids(tensor, valids):
return [tensor[i][:valid] for i, valid in enumerate(valids)]


def as_torch(array):
def np_to_torch(array):
"""Numpy arrays to PyTorch tensors.
Args:
@@ -142,7 +163,24 @@ def as_torch(array):
torch.Tensor: float tensor on the global device.
"""
return torch.as_tensor(array).float().to(global_device())
tensor = torch.from_numpy(array)

if tensor.dtype != torch.float32:
tensor = tensor.float()

return tensor.to(global_device())


def list_to_tensor(data):
    """Convert a list to a PyTorch tensor.

    Args:
        data (list): Data to convert to tensor.

    Returns:
        torch.Tensor: A float tensor.

    """
    return torch.as_tensor(data, dtype=torch.float32, device=global_device())


def as_torch_dict(array_dict):
@@ -158,7 +196,7 @@ def as_torch_dict(array_dict):
"""
for key, value in array_dict.items():
array_dict[key] = as_torch(value)
array_dict[key] = np_to_torch(value)
return array_dict


6 changes: 3 additions & 3 deletions src/garage/torch/algos/bc.py
@@ -11,7 +11,7 @@
from garage.np.algos.rl_algorithm import RLAlgorithm
from garage.np.policies import Policy
from garage.sampler import Sampler
from garage.torch import as_torch
from garage.torch import np_to_torch

# yapf: enable

@@ -130,8 +130,8 @@ def _train_once(self, trainer, epoch):
minibatches = np.array_split(indices, self._minibatches_per_epoch)
losses = []
for minibatch in minibatches:
observations = as_torch(batch.observations[minibatch])
actions = as_torch(batch.actions[minibatch])
observations = np_to_torch(batch.observations[minibatch])
actions = np_to_torch(batch.actions[minibatch])
self._optimizer.zero_grad()
loss = self._compute_loss(observations, actions)
loss.backward()
5 changes: 3 additions & 2 deletions src/garage/torch/algos/ddpg.py
@@ -10,6 +10,7 @@
obtain_evaluation_episodes)
from garage.np.algos import RLAlgorithm
from garage.torch import as_torch_dict, torch_to_np
from garage.torch._functions import zero_optim_grads

# yapf: enable

@@ -254,14 +255,14 @@ def optimize_policy(self, samples_data):
qval = self._qf(inputs, actions)
qf_loss = torch.nn.MSELoss()
qval_loss = qf_loss(qval, y_target)
self._qf_optimizer.zero_grad()
zero_optim_grads(self._qf_optimizer)
qval_loss.backward()
self._qf_optimizer.step()

# optimize actor
actions = self.policy(inputs)
action_loss = -1 * self._qf(inputs, actions).mean()
self._policy_optimizer.zero_grad()
zero_optim_grads(self._policy_optimizer)
action_loss.backward()
self._policy_optimizer.step()

15 changes: 8 additions & 7 deletions src/garage/torch/algos/dqn.py
@@ -10,7 +10,8 @@
from garage import _Default, log_performance, make_optimizer
from garage._functions import obtain_evaluation_episodes
from garage.np.algos import RLAlgorithm
from garage.torch import as_torch, global_device
from garage.torch import global_device, np_to_torch
from garage.torch._functions import zero_optim_grads


class DQN(RLAlgorithm):
@@ -243,12 +244,12 @@ def _optimize_qf(self, timesteps):
qval: Q-value predicted by the Q-network.
"""
observations = as_torch(timesteps.observations)
rewards = as_torch(timesteps.rewards).reshape(-1, 1)
observations = np_to_torch(timesteps.observations)
rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
rewards *= self._reward_scale
actions = as_torch(timesteps.actions)
next_observations = as_torch(timesteps.next_observations)
terminals = as_torch(timesteps.terminals).reshape(-1, 1)
actions = np_to_torch(timesteps.actions)
next_observations = np_to_torch(timesteps.next_observations)
terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)

next_inputs = next_observations
inputs = observations
@@ -279,7 +280,7 @@ def _optimize_qf(self, timesteps):
selected_qs = torch.sum(qvals * actions, axis=1)
qval_loss = F.smooth_l1_loss(selected_qs, y_target)

self._qf_optimizer.zero_grad()
zero_optim_grads(self._qf_optimizer)
qval_loss.backward()

# optionally clip the gradients
12 changes: 7 additions & 5 deletions src/garage/torch/algos/maml.py
@@ -11,6 +11,7 @@
make_optimizer)
from garage.np import discount_cumsum
from garage.torch import update_module_params
from garage.torch._functions import np_to_torch, zero_optim_grads
from garage.torch.optimizers import (ConjugateGradientOptimizer,
DifferentiableSGD)

@@ -121,7 +122,7 @@ def _train_once(self, trainer, all_samples, all_params):

meta_objective = self._compute_meta_loss(all_samples, all_params)

self._meta_optimizer.zero_grad()
zero_optim_grads(self._meta_optimizer)
meta_objective.backward()

self._meta_optimize(all_samples, all_params)
@@ -165,12 +166,13 @@ def _train_value_function(self, paths):

obs = np.concatenate([path['observations'] for path in paths], axis=0)
returns = np.concatenate([path['returns'] for path in paths])
obs = torch.Tensor(obs)
returns = torch.Tensor(returns)

obs = np_to_torch(obs)
returns = np_to_torch(returns.astype(np.float32))

vf_loss = self._value_function.compute_loss(obs, returns)
# pylint: disable=protected-access
self._inner_algo._vf_optimizer.zero_grad()
zero_optim_grads(self._inner_algo._vf_optimizer._optimizer)
vf_loss.backward()
# pylint: disable=protected-access
self._inner_algo._vf_optimizer.step()
@@ -232,7 +234,7 @@ def _adapt(self, batch_samples, set_grad=True):
loss = self._inner_algo._compute_loss(*batch_samples[1:])

# Update policy parameters with one SGD step
self._inner_optimizer.zero_grad()
self._inner_optimizer.set_grads_none()
loss.backward(create_graph=set_grad)

with torch.set_grad_enabled(set_grad):
27 changes: 14 additions & 13 deletions src/garage/torch/algos/pearl.py
@@ -16,6 +16,7 @@
from garage.replay_buffer import PathBuffer
from garage.sampler import DefaultWorker
from garage.torch import global_device
from garage.torch._functions import np_to_torch, zero_optim_grads
from garage.torch.embeddings import MLPEncoder
from garage.torch.policies import ContextConditionedPolicy

@@ -354,14 +355,14 @@ def _optimize_policy(self, indices):
target_v_values = self.target_vf(next_obs, task_z)

# KL constraint on z if probabilistic
self.context_optimizer.zero_grad()
zero_optim_grads(self.context_optimizer)
if self._use_information_bottleneck:
kl_div = self._policy.compute_kl_div()
kl_loss = self._kl_lambda * kl_div
kl_loss.backward(retain_graph=True)

self.qf1_optimizer.zero_grad()
self.qf2_optimizer.zero_grad()
zero_optim_grads(self.qf1_optimizer)
zero_optim_grads(self.qf2_optimizer)

rewards_flat = rewards.view(self._batch_size * num_tasks, -1)
rewards_flat = rewards_flat * self._reward_scale
@@ -384,7 +385,7 @@ def _optimize_policy(self, indices):
# optimize vf
v_target = min_q - log_pi
vf_loss = self.vf_criterion(v_pred, v_target.detach())
self.vf_optimizer.zero_grad()
zero_optim_grads(self.vf_optimizer)
vf_loss.backward()
self.vf_optimizer.step()
self._update_target_network()
Expand All @@ -402,7 +403,7 @@ def _optimize_policy(self, indices):
pre_activation_reg_loss)
policy_loss = policy_loss + policy_reg_loss

self._policy_optimizer.zero_grad()
zero_optim_grads(self._policy_optimizer)
policy_loss.backward()
self._policy_optimizer.step()

@@ -498,11 +499,11 @@ def _sample_data(self, indices):
no = np.vstack((no, batch['next_observations'][np.newaxis]))
d = np.vstack((d, batch['dones'][np.newaxis]))

o = torch.as_tensor(o, device=global_device()).float()
a = torch.as_tensor(a, device=global_device()).float()
r = torch.as_tensor(r, device=global_device()).float()
no = torch.as_tensor(no, device=global_device()).float()
d = torch.as_tensor(d, device=global_device()).float()
o = np_to_torch(o)
a = np_to_torch(a)
r = np_to_torch(r)
no = np_to_torch(no)
d = np_to_torch(d)

return o, a, r, no, d

Expand Down Expand Up @@ -541,8 +542,8 @@ def _sample_context(self, indices):
else:
final_context = np.vstack((final_context, context[np.newaxis]))

final_context = torch.as_tensor(final_context,
device=global_device()).float()
final_context = np_to_torch(final_context)

if len(indices) == 1:
final_context = final_context.unsqueeze(0)

@@ -614,7 +615,7 @@ def adapt_policy(self, exploration_policy, exploration_episodes):
a = exploration_episodes.actions
r = exploration_episodes.rewards.reshape(total_steps, 1)
ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1)
context = torch.as_tensor(ctxt, device=global_device()).float()
context = np_to_torch(ctxt)
self._policy.infer_posterior(context)

return self._policy