Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Aug 1, 2024
1 parent 05ecfd4 commit c21ff4d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
3 changes: 1 addition & 2 deletions torchrl/csrc/segment_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ class SegmentTree {
public:
// Builds a segment tree over `size` leaves, rounding the leaf capacity up to
// the next power of two so the tree is a complete binary tree stored in a
// flat array. Every slot starts at `identity_element` (the identity of the
// tree's reduction operation).
SegmentTree(int64_t size, const T& identity_element)
    : size_(size), identity_element_(identity_element) {
  // Smallest power of two strictly greater than `size`. The loop body is
  // intentionally empty; `{}` makes that explicit (avoids the bugprone
  // trailing-semicolon form flagged by clang-tidy).
  for (capacity_ = 1; capacity_ <= size; capacity_ <<= 1) {
  }
  // 1-based heap layout: internal nodes occupy [1, capacity_), leaves occupy
  // [capacity_, 2 * capacity_).
  values_.assign(2 * capacity_, identity_element_);
}

Expand Down
15 changes: 11 additions & 4 deletions torchrl/csrc/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,42 @@
// LICENSE file in the root directory of this source tree.
// utils.h
#include "utils.h"

#include <iostream>
/// Free-function convenience wrapper: tanh whose output is clamped into
/// [-(1 - eps), 1 - eps], with autograd support via SafeTanh.
torch::Tensor safetanh(torch::Tensor input, float eps) {
  auto clamped = SafeTanh::apply(input, eps);
  return clamped;
}
/// Free-function convenience wrapper: atanh of the input clamped into
/// [-(1 - eps), 1 - eps], with autograd support via SafeInvTanh.
torch::Tensor safeatanh(torch::Tensor input, float eps) {
  auto result = SafeInvTanh::apply(input, eps);
  return result;
}
// Forward pass: tanh(input), clamped to [-(1 - eps), 1 - eps].
//
// The clamped *output* (not the input) is saved on `ctx`, so backward
// computes 1 - out^2 from the clamped value. (Deduplicated: the pasted diff
// carried both the old single-line and new wrapped signature.)
torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
                                torch::Tensor input, float eps) {
  auto out = torch::tanh(input);
  auto lim = 1.0 - eps;  // keep |out| strictly below 1
  out = out.clamp(-lim, lim);
  ctx->save_for_backward({out});
  return out;
}
// Backward pass: d tanh(x)/dx = 1 - tanh(x)^2, evaluated at the clamped
// saved output. The second returned entry is an undefined tensor because the
// `eps` argument is not differentiable.
torch::autograd::tensor_list SafeTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto out = saved[0];
  auto go = grad_outputs[0];
  auto grad = go * (1 - out * out);
  return {grad, torch::Tensor()};  // no grad w.r.t. eps
}
// Forward pass: atanh applied to the input clamped into
// [-(1 - eps), 1 - eps], avoiding the +/-inf that atanh produces at +/-1.
// The clamped intermediate is saved on `ctx` for the backward pass.
torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
                                   torch::Tensor input, float eps) {
  auto lim = 1.0 - eps;
  auto intermediate = input.clamp(-lim, lim);
  ctx->save_for_backward({intermediate});
  auto out = torch::atanh(intermediate);
  return out;
}
torch::autograd::tensor_list SafeInvTanh::backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs) {
torch::autograd::tensor_list SafeInvTanh::backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto go = grad_outputs[0];
Expand Down
14 changes: 10 additions & 4 deletions torchrl/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@ torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);

// Custom autograd function: tanh with outputs clamped away from +/-1 by
// `eps`. Invoke via SafeTanh::apply(input, eps) (see the safetanh() helper).
class SafeTanh : public torch::autograd::Function<SafeTanh> {
 public:
  // Returns tanh(input) clamped to [-(1 - eps), 1 - eps]; saves the clamped
  // output on `ctx` for backward.
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input, float eps);
  // Gradient of the clamped tanh; the slot for `eps` is an undefined tensor
  // (eps is not differentiable).
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs);
};

// Custom autograd function: atanh of the input clamped away from +/-1 by
// `eps`, preventing infinities at the domain boundary. Invoke via
// SafeInvTanh::apply(input, eps) (see the safeatanh() helper).
class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
 public:
  // Returns atanh(clamp(input, -(1 - eps), 1 - eps)); saves the clamped
  // intermediate on `ctx` for backward.
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input, float eps);
  // Gradient of the clamped atanh; the slot for `eps` is an undefined tensor
  // (eps is not differentiable).
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs);
};

0 comments on commit c21ff4d

Please sign in to comment.