[functorch] updated some decompositions and cleaned some stuff
Chillee authored and bigfootjon committed Jul 21, 2022
1 parent 151b1f4 commit 9eb10d0
Showing 5 changed files with 38 additions and 23 deletions.
1 change: 1 addition & 0 deletions functorch/.gitignore
```diff
@@ -12,6 +12,7 @@ docs/build
 docs/src
 docs/source/generated
 .DS_Store
+op_analysis/*.txt
 
 # Editor temporaries
 *.swn
```
33 changes: 25 additions & 8 deletions functorch/functorch/_src/decompositions.py
```diff
@@ -239,16 +239,25 @@ def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
     # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
 
 
+def apply_loss_reduction(loss: Tensor, reduction: int):
+    if reduction == Reduction.MEAN.value:
+        return torch.mean(loss)
+    elif reduction == Reduction.SUM.value:
+        return torch.sum(loss)
+    else:
+        return loss
+
+
 @register_decomposition(aten.l1_loss)
 def l1_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
-    if reduction != Reduction.NONE.value:
-        loss = (self - target).abs()
-        if reduction == Reduction.MEAN.value:
-            return torch.mean(loss)
-        else:
-            return torch.sum(loss)
-    else:
-        return (self - target).abs()
+    loss = (self - target).abs()
+    return apply_loss_reduction(loss, reduction)
 
 
+@register_decomposition(aten.mse_loss)
+def mse_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
+    loss = (self - target) ** 2
+    return apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(aten.mse_loss_backward)
```
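Editor's note: both losses now funnel through the shared `apply_loss_reduction` helper instead of open-coding the reduction branches. As a sanity check (not part of the commit; string reduction modes stand in for the `Reduction` enum used in the file), the decomposed formulas can be compared against the eager kernels:

```python
import torch
import torch.nn.functional as F

def reduce_loss(loss, reduction):
    # mirrors apply_loss_reduction, with string modes in place of the enum
    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    return loss

x, y = torch.randn(4, 5), torch.randn(4, 5)
for reduction in ("none", "mean", "sum"):
    # l1_loss decomposes to |x - y|, mse_loss to (x - y) ** 2
    assert torch.allclose(reduce_loss((x - y).abs(), reduction), F.l1_loss(x, y, reduction=reduction))
    assert torch.allclose(reduce_loss((x - y) ** 2, reduction), F.mse_loss(x, y, reduction=reduction))
```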
```diff
@@ -257,6 +266,14 @@ def mse_loss_backward(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
     return norm * (input - target) * grad_output
 
 
+@register_decomposition(aten.huber_loss)
+def huber_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value, delta: float = 1.0) -> Tensor:
+    assert delta > 0, "huber_loss does not support non-positive values for delta."
+    z = (self - target).abs()
+    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
+    return apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(aten.huber_loss_backward)
 def huber_loss_backward(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
     norm = 1. / self.numel() if reduction == Reduction.MEAN.value else 1.
```
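Editor's note: the new forward decomposition implements the standard piecewise Huber definition — quadratic where `|self - target| < delta`, linear beyond it — followed by the shared reduction. A quick spot check against the eager kernel (illustrative only; `F.huber_loss` has been available since PyTorch 1.9):

```python
import torch
import torch.nn.functional as F

x, y = torch.randn(10), torch.randn(10)
for delta in (0.5, 1.0, 2.0):
    z = (x - y).abs()
    # 0.5 * z^2 inside the delta band, delta * (z - 0.5 * delta) outside
    decomposed = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)).mean()
    assert torch.allclose(decomposed, F.huber_loss(x, y, reduction="mean", delta=delta))
```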
File renamed without changes.
17 changes: 12 additions & 5 deletions functorch/op_analysis/gen_data.py
```diff
@@ -84,7 +84,7 @@ def gen_data(special_op_lists, analysis_name):
 
 ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
 
-annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops.txt')))}
+annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
 from collections import defaultdict
 
 uniq_ops = []
@@ -160,6 +160,7 @@ def annotate_ops(ops, is_unique):
 
 annotate_ops(ops, is_unique=False)
 with open(f"{analysis_name}", 'w') as f:
+    # import pdb; pdb.set_trace()
     for op in ops:
         info = [
             op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)
@@ -176,12 +177,18 @@ def full_name_check(lst):
 
 
 # Generates batching rule data
-gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap')
+gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
 
 
+def remove_suffix(input_string, suffix):
+    if suffix and input_string.endswith(suffix):
+        return input_string[:-len(suffix)]
+    return input_string
+
+
 if True:
     with open('run_ops.txt', 'r') as f:
-        opinfo_ops = [i.strip() for i in f.readlines()]
+        opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
     with open('run_decompositions.txt', 'r') as f:
-        decomposed_ops = [i.strip() for i in f.readlines()]
-        gen_data([name_check(opinfo_ops), name_check(decomposed_ops)], 'decompositions')
+        decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
+        gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops)], 'decompositions.txt')
```
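Editor's note: the new `remove_suffix` helper strips the `.default` overload suffix from operator names before matching; it reimplements `str.removesuffix`, which requires Python 3.9+. A small illustration of the behavior it relies on (hypothetical inputs, not from the commit):

```python
def remove_suffix(input_string, suffix):
    # same logic as the helper above; the empty-suffix guard matters,
    # since input_string[:-0] would return an empty string
    if suffix and input_string.endswith(suffix):
        return input_string[:-len(suffix)]
    return input_string

assert remove_suffix("add.default", ".default") == "add"
assert remove_suffix("add.out", ".default") == "add.out"  # unrelated suffix is untouched
assert remove_suffix("add", "") == "add"                  # empty suffix is a no-op
```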
10 changes: 0 additions & 10 deletions functorch/test/test_ops.py
```diff
@@ -1292,27 +1292,17 @@ class TestDecompositionOpInfo(TestCase):
         xfail('linalg.tensorinv'),
         xfail('to_sparse'),
         skip('tensor_split'),
-        skip('mvlgamma', 'mvlgamma_p_1'),
-        skip('mvlgamma', 'mvlgamma_p_3'),
-        skip('mvlgamma', 'mvlgamma_p_5'),
         skip('eig'),
         skip('nn.functional.dropout'),
-        skip('_masked.softmin'),
-        skip('_masked.log_softmax'),
         skip('stft'),
-        skip('_masked.softmax'),
-        skip('_masked.normalize'),
         xfail('linalg.lu_factor', ''),
         # Some weird matmul stuff with int64 matmuls
         # inplace op
         skip('resize_'),
         # Weird conj errors
         xfail('fft.hfft2', dtypes=(torch.float32, torch.float64)),
         xfail('fft.hfft', dtypes=(torch.float32, torch.float64)),
         xfail('fft.hfftn', dtypes=(torch.float32, torch.float64)),
-        skip('nn.functional.binary_cross_entropy', ''),
-        skip('nn.functional.binary_cross_entropy_with_logits', '',),
-        skip('nn.functional.huber_loss'),
     })
     def test_decomposition(self, device, dtype, op):
         # dtype is too confusing of a name for how we're using it
```
