Disable TF32 in some linalg functions (pytorch#73460)
Summary:
Disable TF32 in some linalg functions by adding at::NoTF32Guard to linalg.matrix_power and to the autograd backward/forward-AD formulas for Cholesky, SVD, eig, QR, lstsq, triangular solve, cholesky_solve, and LU solve/factorization. TF32 trades precision for speed, which makes these numerically sensitive routines inaccurate.

See also pytorch#67948, pytorch#50453, pytorch#44240

Pull Request resolved: pytorch#73460

Reviewed By: albanD

Differential Revision: D34493487

Pulled By: ngimel

fbshipit-source-id: 958cd968ea09df3b5a4d2b4a26aaf0dfddc53981
(cherry picked from commit cd75ec6)
xwang233 authored and pytorchmergebot committed Feb 28, 2022
1 parent 78914b3 commit 89b4cfb
Showing 2 changed files with 23 additions and 0 deletions.
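Every added line below is the same one-liner: an at::NoTF32Guard placed at the top of a function so that cuBLAS/cuDNN work issued from that scope runs in full FP32 rather than TF32. The following is a minimal sketch of the RAII idea under stated assumptions, not ATen's actual implementation; the thread-local flag tf32_overridden is a hypothetical stand-in for the internal state that NoTF32Guard toggles.

// Sketch only: illustrates the scoped "disable TF32" pattern the commit relies on.
#include <iostream>

namespace sketch {

thread_local bool tf32_overridden = false;  // hypothetical thread-local flag

struct NoTF32Guard {
  NoTF32Guard() : changed_(!tf32_overridden) {
    if (changed_) {
      tf32_overridden = true;  // force TF32 off for this thread
    }
  }
  ~NoTF32Guard() {
    if (changed_) {
      tf32_overridden = false;  // restore the previous state on scope exit
    }
  }
  NoTF32Guard(const NoTF32Guard&) = delete;
  NoTF32Guard& operator=(const NoTF32Guard&) = delete;

 private:
  bool changed_;  // true if this guard is the outermost one on this thread
};

}  // namespace sketch

int main() {
  std::cout << sketch::tf32_overridden << "\n";    // 0: TF32 allowed
  {
    sketch::NoTF32Guard disable_tf32;              // mirrors the lines added below
    std::cout << sketch::tf32_overridden << "\n";  // 1: TF32 forced off
  }
  std::cout << sketch::tf32_overridden << "\n";    // 0: restored
  return 0;
}

Because the guard is scoped and re-entrant in spirit, nested calls between the patched functions compose: the outermost guard flips the flag and the setting is restored exactly once when it goes out of scope.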
1 change: 1 addition & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -416,6 +416,7 @@ Tensor linalg_matrix_power_impl(
const Tensor& self,
int64_t n,
c10::optional<Tensor> _out) {
+  NoTF32Guard disable_tf32;
auto out = _out.value_or(Tensor());

squareCheckInputs(self, "linalg.matrix_power");
22 changes: 22 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -514,6 +514,7 @@ Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tenso
}

Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
+  at::NoTF32Guard disable_tf32;
Tensor grad_self = solve_backward_self(grad, self, A);
if (self.ndimension() == 2 && A.ndimension() == 2) {
return -at::mm(grad_self, solution.mH());
@@ -1133,6 +1134,7 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra
}

Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) {
+  at::NoTF32Guard disable_tf32;
// Differentiation of the Cholesky decomposition, Iain Murray
// https://arxiv.org/abs/1602.07527
// equation 8
@@ -1147,6 +1149,7 @@ Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) {
}

Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
+  at::NoTF32Guard disable_tf32;
// cf. Iain Murray (2016); arXiv 1602.07527
// This gradient is symmetric, and not triangular.
// Cholesky additionally assumes that the input is symmetric, which is a subspace of
@@ -1170,6 +1173,7 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
}

Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
+  at::NoTF32Guard disable_tf32;
Tensor grad_L;
if (grad.defined()) {
Tensor common_term = grad + grad.mT();
@@ -2414,6 +2418,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(const Tensor& dA,
const Tensor& S,
const Tensor& Vh_,
const bool full_matrices) {
+  at::NoTF32Guard disable_tf32;
// See svd_backward for the derivation
// With sym(X) = X + X^H, we implement
// dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
@@ -2517,6 +2522,7 @@ Tensor svd_backward(const Tensor& gU,
const Tensor& U,
const Tensor& S,
const Tensor& Vh) {
+  at::NoTF32Guard disable_tf32;
// Throughout both the real and complex case we assume A has distinct singular values.
// Furthermore, if A is rectangular or complex, we assume it's full-rank.
//
@@ -2726,6 +2732,7 @@ Tensor svd_backward(const Tensor& gU,
// See the details below.
Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool is_eigvec_tensor_nonempty, const Tensor& eigenvalues, const Tensor& eigenvectors) {
+  at::NoTF32Guard disable_tf32;
TORCH_CHECK(is_eigvec_tensor_nonempty,
"eig_backward: torch.eig(eigenvalues=False) is not differentiable. ",
"Please use torch.linalg.eigvals");
@@ -2865,6 +2872,7 @@ Tensor linalg_eig_backward(const Tensor& gL,
const Tensor& V,
const bool is_hermitian,
const bool symeig_eigenvectors) {
+  at::NoTF32Guard disable_tf32;
// https://arxiv.org/pdf/1701.00392.pdf Eq 4.77
// For A = VLV^{-1}, denoting the gradients gA, gV and gL, we have
// gA = V^{-H}(diag_embed(gL) + (V^H gV -V^HV diag(real(V^H gV))) / E*)V^H
@@ -2947,6 +2955,7 @@ std::tuple<Tensor, Tensor> linalg_eig_jvp(const Tensor& dA,
const Tensor& L,
const Tensor& V,
const bool is_hermitian) {
+  at::NoTF32Guard disable_tf32;
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// see also https://arxiv.org/pdf/1701.00392.pdf Eqs. (4.60) and (4.63)
// Note that neither of the formulas in these pdfs are correct, as they do not assume that
@@ -2994,6 +3003,7 @@ Tensor linalg_lstsq_jvp(
const Tensor& dA,
const Tensor& dB
) {
+  at::NoTF32Guard disable_tf32;
auto pinvA = at::linalg_pinv(A);
auto dpinvA = pinv_jvp(A, pinvA, dA);
auto dX = dpinvA.matmul(B) + pinvA.matmul(dB);
@@ -3008,6 +3018,7 @@ std::tuple<Tensor, Tensor> linalg_lstsq_backward(
const c10::optional<c10::string_view> driver,
const std::array<bool, 2>& grad_input_mask
) {
+  at::NoTF32Guard disable_tf32;
Tensor A_grad, B_grad;
if (!grad.defined()) {
return std::make_tuple(A_grad, B_grad);
@@ -3041,6 +3052,7 @@ std::tuple<Tensor, Tensor> linalg_qr_jvp(
const Tensor& Q,
const Tensor& R
) {
+  at::NoTF32Guard disable_tf32;
auto m = dA.size(-2);
auto n = dA.size(-1);
auto k = std::min(m, n);
@@ -3092,6 +3104,7 @@ Tensor linalg_qr_jvp_R(

Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
c10::string_view mode, const Tensor& q, const Tensor& r){
+  at::NoTF32Guard disable_tf32;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool compute_q, reduced;
std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
@@ -3482,6 +3495,7 @@ std::tuple<Tensor, Tensor> triangular_solve_backward(
const Tensor & b, const Tensor & a, const Tensor & x,
const bool upper, const bool transpose, const bool unitriangular,
std::array<bool, 2> output_mask) {
+  at::NoTF32Guard disable_tf32;
Tensor grad_b, grad_a;
if (grad_x.defined() || grad_m.defined()) {
if (grad_x.defined()) {
@@ -3531,6 +3545,7 @@ Tensor linalg_solve_triangular_forward_AD(
const bool upper,
const bool left,
const bool unitriangular) {
+  at::NoTF32Guard disable_tf32;
// The forward AD formula (for left = true) is A^{-1}(B_t - A_tX)
// For the derivation see:
// [Note: Forward / Backward AD solve_triangular]
@@ -3548,6 +3563,7 @@ std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
const bool left,
const bool unitriangular,
std::array<bool, 2> output_mask) {
+  at::NoTF32Guard disable_tf32;
const bool A_requires_grad = output_mask[0];
const bool B_requires_grad = output_mask[1];
// [Note: Forward / Backward AD solve_triangular]
@@ -3598,6 +3614,7 @@ std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
std::tuple<Tensor, Tensor> cholesky_solve_backward(
const Tensor& grad_x, const Tensor& self,
const Tensor& input2, const Tensor& result, const bool upper) {
+  at::NoTF32Guard disable_tf32;
Tensor grad_self, grad_input2;
if (grad_x.defined()) {
grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);
@@ -3621,6 +3638,7 @@ Tensor cholesky_solve_jvp(
const Tensor& dB,
const bool upper
) {
+  at::NoTF32Guard disable_tf32;
auto dK = upper ? dU.mH().matmul(U)
: dU.matmul(U.mH());
auto dA = dK + dK.mH();
@@ -4547,6 +4565,7 @@ std::tuple<Tensor, Tensor> lu_solve_backward(
const Tensor& LU_data,
const Tensor& LU_pivots,
const std::array<bool, 2>& grad_input_mask) {
+  at::NoTF32Guard disable_tf32;
const bool B_requires_grad = grad_input_mask[0];
const bool LU_data_requires_grad = grad_input_mask[1];
if (!grad.defined() || (!B_requires_grad && !LU_data_requires_grad)) {
@@ -4614,6 +4633,7 @@ Tensor lu_solve_jvp(
const Tensor& dB,
const Tensor& LU_pivots
) {
+  at::NoTF32Guard disable_tf32;
Tensor L, U, dL, dU;
std::tie(std::ignore, L, U) = at::lu_unpack(LU_data, LU_pivots, /*unpack_data=*/true, /*unpack_pivots=*/false);
dL = dLU_data.tril(-1);
@@ -5030,6 +5050,7 @@ Tensor plu_backward_base(
const Tensor& P,
const Tensor& L,
const Tensor& U) {
+  at::NoTF32Guard disable_tf32;
auto L_grad = grads[0];
auto U_grad = grads[1];

@@ -5118,6 +5139,7 @@ Tensor lu_factor_ex_jvp(
const Tensor& LU,
const Tensor& pivs
) {
+  at::NoTF32Guard disable_tf32;
// This function is based on the forward AD derivations outlined
// in the description to the plu_backward_base function.

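The per-scope guard is orthogonal to the global, user-facing TF32 switch. The sketch below shows that relationship under assumptions: it presumes the at::globalContext() accessors setAllowTF32CuBLAS()/allowTF32CuBLAS() that back torch.backends.cuda.matmul.allow_tf32, and the exact header path is also an assumption; compile against libtorch.

// Hedged sketch: a scoped at::NoTF32Guard overrides the global TF32 setting only
// for work issued inside its scope, e.g. the patched backward formulas above.
#include <ATen/Context.h>
#include <iostream>

int main() {
  // Even if TF32 is enabled globally for matmuls...
  at::globalContext().setAllowTF32CuBLAS(true);
  std::cout << "global allow_tf32: "
            << at::globalContext().allowTF32CuBLAS() << "\n";

  {
    // ...the guard added throughout this commit forces full FP32 for
    // linear-algebra kernels launched from this scope.
    at::NoTF32Guard disable_tf32;
    // cuBLAS-backed work here runs without TF32
  }

  // Outside the guard, the global setting applies again.
  return 0;
}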
