Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Aug 1, 2024
1 parent 99332f5 commit 05ecfd4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 45 deletions.
41 changes: 41 additions & 0 deletions torchrl/csrc/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// utils.cpp -- definitions for the safe tanh/atanh autograd functions
// declared in utils.h.
#include "utils.h"
#include <iostream>  // NOTE(review): not used in this file — candidate for removal.
// Numerically safe tanh: forwards to the SafeTanh autograd function, whose
// output is clamped into [-(1 - eps), 1 - eps].
torch::Tensor safetanh(torch::Tensor input, float eps) {
  auto result = SafeTanh::apply(input, eps);
  return result;
}
// Numerically safe atanh: forwards to the SafeInvTanh autograd function,
// which clamps the input away from +/-1 before inverting.
torch::Tensor safeatanh(torch::Tensor input, float eps) {
  auto result = SafeInvTanh::apply(input, eps);
  return result;
}
// Forward pass: tanh(input), then clamp the result into [-(1 - eps), 1 - eps]
// so a downstream atanh cannot receive exactly +/-1. The clamped output is
// saved for the backward pass.
torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps) {
  const auto bound = 1.0 - eps;
  auto result = torch::tanh(input).clamp(-bound, bound);
  ctx->save_for_backward({result});
  return result;
}
// Backward pass: d/dx tanh(x) = 1 - tanh(x)^2, evaluated at the (clamped)
// output saved by forward. The second slot corresponds to `eps`, which is
// non-differentiable, so it receives an undefined Tensor.
torch::autograd::tensor_list SafeTanh::backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs) {
  const auto saved = ctx->get_saved_variables();
  const auto& out = saved[0];
  const auto& upstream = grad_outputs[0];
  auto grad_input = upstream * (1 - out * out);
  return {grad_input, torch::Tensor()};
}
// Forward pass: clamp the input into [-(1 - eps), 1 - eps] so atanh stays
// finite, then apply atanh. The clamped input (not the raw one) is saved
// for the backward pass.
torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps) {
  const auto bound = 1.0 - eps;
  auto clamped = input.clamp(-bound, bound);
  ctx->save_for_backward({clamped});
  return torch::atanh(clamped);
}
// Backward pass: d/dx atanh(x) = 1 / (1 - x^2), evaluated at the clamped
// input saved by forward (so the denominator never reaches zero). `eps` is
// non-differentiable and gets an undefined Tensor in its slot.
torch::autograd::tensor_list SafeInvTanh::backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs) {
  const auto saved = ctx->get_saved_variables();
  const auto& clamped = saved[0];
  const auto& upstream = grad_outputs[0];
  auto grad_input = upstream / (1 - clamped * clamped);
  return {grad_input, torch::Tensor()};
}
56 changes: 11 additions & 45 deletions torchrl/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,24 @@
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// utils.h

#pragma once

#include <torch/extension.h>
#include <torch/torch.h>

#include <iostream>

// NOTE(review): `using namespace` at header scope leaks torch::autograd into
// every translation unit that includes this header; the declarations below
// are already fully qualified, so this looks removable — confirm no includer
// relies on it before deleting.
using namespace torch::autograd;
// Numerically safe tanh / inverse tanh: results are kept eps away from +/-1
// (implemented in utils.cpp via custom autograd functions).
torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6);
torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);

class SafeTanh : public Function<SafeTanh> {
class SafeTanh : public torch::autograd::Function<SafeTanh> {
public:
static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
float eps = 1e-6) {
auto out = torch::tanh(input);
auto lim = 1.0 - eps;
out = out.clamp(-lim, lim);
ctx->save_for_backward({out});
return out;
}

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto out = saved[0];
auto go = grad_outputs[0];
auto grad = go * (1 - out * out);
return {grad, torch::Tensor()};
}
static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps);
static torch::autograd::tensor_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs);
};

torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6) {
return SafeTanh::apply(input, eps);
}

class SafeInvTanh : public Function<SafeInvTanh> {
class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
public:
static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
float eps = 1e-6) {
auto lim = 1.0 - eps;
auto intermediate = input.clamp(-lim, lim);
ctx->save_for_backward({intermediate});
auto out = torch::atanh(intermediate);
return out;
}

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto go = grad_outputs[0];
auto grad = go / (1 - input * input);
return {grad, torch::Tensor()};
}
static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input, float eps);
static torch::autograd::tensor_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs);
};

torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6) {
return SafeInvTanh::apply(input, eps);
}

0 comments on commit 05ecfd4

Please sign in to comment.