diff --git a/torchrl/csrc/segment_tree.h b/torchrl/csrc/segment_tree.h
index a751c59b2d2..4f96908aaca 100644
--- a/torchrl/csrc/segment_tree.h
+++ b/torchrl/csrc/segment_tree.h
@@ -43,8 +43,7 @@ class SegmentTree {
  public:
   SegmentTree(int64_t size, const T& identity_element)
       : size_(size), identity_element_(identity_element) {
-    for (capacity_ = 1; capacity_ <= size; capacity_ <<= 1)
-      ;
+    for (capacity_ = 1; capacity_ <= size; capacity_ <<= 1);
     values_.assign(2 * capacity_, identity_element_);
   }
 
diff --git a/torchrl/csrc/utils.cpp b/torchrl/csrc/utils.cpp
index 54c05365091..79cd43fdffb 100644
--- a/torchrl/csrc/utils.cpp
+++ b/torchrl/csrc/utils.cpp
@@ -4,6 +4,7 @@
 // LICENSE file in the root directory of this source tree.
 // utils.h
 #include "utils.h"
+#include <torch/torch.h>
 
 torch::Tensor safetanh(torch::Tensor input, float eps) {
   return SafeTanh::apply(input, eps);
@@ -11,28 +12,34 @@ torch::Tensor safetanh(torch::Tensor input, float eps) {
 torch::Tensor safeatanh(torch::Tensor input, float eps) {
   return SafeInvTanh::apply(input, eps);
 }
-torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps) {
+torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
+                                torch::Tensor input, float eps) {
   auto out = torch::tanh(input);
   auto lim = 1.0 - eps;
   out = out.clamp(-lim, lim);
   ctx->save_for_backward({out});
   return out;
 }
-torch::autograd::tensor_list SafeTanh::backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs) {
+torch::autograd::tensor_list SafeTanh::backward(
+    torch::autograd::AutogradContext* ctx,
+    torch::autograd::tensor_list grad_outputs) {
   auto saved = ctx->get_saved_variables();
   auto out = saved[0];
   auto go = grad_outputs[0];
   auto grad = go * (1 - out * out);
   return {grad, torch::Tensor()};
 }
-torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps) {
+torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
+                                   torch::Tensor input, float eps) {
   auto lim = 1.0 - eps;
   auto intermediate = input.clamp(-lim, lim);
   ctx->save_for_backward({intermediate});
   auto out = torch::atanh(intermediate);
   return out;
 }
-torch::autograd::tensor_list SafeInvTanh::backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs) {
+torch::autograd::tensor_list SafeInvTanh::backward(
+    torch::autograd::AutogradContext* ctx,
+    torch::autograd::tensor_list grad_outputs) {
   auto saved = ctx->get_saved_variables();
   auto input = saved[0];
   auto go = grad_outputs[0];
diff --git a/torchrl/csrc/utils.h b/torchrl/csrc/utils.h
index 0b41fcac467..2d93469d82a 100644
--- a/torchrl/csrc/utils.h
+++ b/torchrl/csrc/utils.h
@@ -14,12 +14,18 @@
 torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);
 
 class SafeTanh : public torch::autograd::Function<SafeTanh> {
  public:
-  static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps);
-  static torch::autograd::tensor_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs);
+  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
+                               torch::Tensor input, float eps);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs);
 };
 
 class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
  public:
-  static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps);
-  static torch::autograd::tensor_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs);
+  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
+                               torch::Tensor input, float eps);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs);
 };