diff --git a/functorch/.gitignore b/functorch/.gitignore
index 7efc346007230..145ab7d608390 100644
--- a/functorch/.gitignore
+++ b/functorch/.gitignore
@@ -12,6 +12,7 @@ docs/build
 docs/src
 docs/source/generated
 .DS_Store
+op_analysis/*.txt
 
 # Editor temporaries
 *.swn
diff --git a/functorch/functorch/_src/decompositions.py b/functorch/functorch/_src/decompositions.py
index 5c6bd538b7fef..f558640a60ddf 100644
--- a/functorch/functorch/_src/decompositions.py
+++ b/functorch/functorch/_src/decompositions.py
@@ -239,16 +239,25 @@ def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> T
 #     return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
 
 
+def apply_loss_reduction(loss: Tensor, reduction: int):
+    if reduction == Reduction.MEAN.value:
+        return torch.mean(loss)
+    elif reduction == Reduction.SUM.value:
+        return torch.sum(loss)
+    else:
+        return loss
+
+
 @register_decomposition(aten.l1_loss)
 def l1_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
-    if reduction != Reduction.NONE.value:
-        loss = (self - target).abs()
-        if reduction == Reduction.MEAN.value:
-            return torch.mean(loss)
-        else:
-            return torch.sum(loss)
-    else:
-        return (self - target).abs()
+    loss = (self - target).abs()
+    return apply_loss_reduction(loss, reduction)
+
+
+@register_decomposition(aten.mse_loss)
+def mse_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
+    loss = (self - target) ** 2
+    return apply_loss_reduction(loss, reduction)
 
 
 @register_decomposition(aten.mse_loss_backward)
@@ -257,6 +266,14 @@ def mse_loss_backward(grad_output: Tensor, input: Tensor, target: Tensor, reduct
     return norm * (input - target) * grad_output
 
 
+@register_decomposition(aten.huber_loss)
+def huber_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value, delta: float = 1.0) -> Tensor:
+    assert delta > 0, "huber_loss does not support non-positive values for delta."
+    z = (self - target).abs()
+    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
+    return apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(aten.huber_loss_backward)
 def huber_loss_backward(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
     norm = 1. / self.numel() if reduction == Reduction.MEAN.value else 1.
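For readers skimming the diff: the new forward decompositions rewrite l1_loss, mse_loss, and huber_loss as plain tensor ops and route the reduction handling through the shared apply_loss_reduction helper. Below is a minimal standalone sketch of how one might sanity-check the math against the eager kernels; the helper name huber_loss_decomp and the random test tensors are illustrative only and not part of this patch.

# Standalone sanity check (illustrative, not part of the PR): the decomposed
# formulas should agree with the ATen kernels behind torch.nn.functional.
import torch
import torch.nn.functional as F

def huber_loss_decomp(inp, target, delta=1.0):
    # same formula as the huber_loss decomposition above, with reduction='mean'
    z = (inp - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return torch.mean(loss)

x, y = torch.randn(8, 3), torch.randn(8, 3)
assert torch.allclose(huber_loss_decomp(x, y), F.huber_loss(x, y))   # huber
assert torch.allclose(((x - y) ** 2).mean(), F.mse_loss(x, y))       # mse
assert torch.allclose((x - y).abs().mean(), F.l1_loss(x, y))         # l1
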
diff --git a/functorch/op_analysis/annotated_ops.txt b/functorch/op_analysis/annotated_ops
similarity index 100%
rename from functorch/op_analysis/annotated_ops.txt
rename to functorch/op_analysis/annotated_ops
diff --git a/functorch/op_analysis/gen_data.py b/functorch/op_analysis/gen_data.py
index fef6a19e67c8f..6000260750af9 100644
--- a/functorch/op_analysis/gen_data.py
+++ b/functorch/op_analysis/gen_data.py
@@ -84,7 +84,7 @@ def gen_data(special_op_lists, analysis_name):
 
     ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
 
-    annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops.txt')))}
+    annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
     from collections import defaultdict
 
     uniq_ops = []
@@ -160,6 +160,7 @@ def annotate_ops(ops, is_unique):
     annotate_ops(ops, is_unique=False)
 
     with open(f"{analysis_name}", 'w') as f:
+        # import pdb; pdb.set_trace()
         for op in ops:
             info = [
                 op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)
@@ -176,12 +177,18 @@ def full_name_check(lst):
 
 
 # Generates batching rule data
-gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap')
+gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
+
+
+def remove_suffix(input_string, suffix):
+    if suffix and input_string.endswith(suffix):
+        return input_string[:-len(suffix)]
+    return input_string
 
 
 if True:
     with open('run_ops.txt', 'r') as f:
-        opinfo_ops = [i.strip() for i in f.readlines()]
+        opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
     with open('run_decompositions.txt', 'r') as f:
-        decomposed_ops = [i.strip() for i in f.readlines()]
-    gen_data([name_check(opinfo_ops), name_check(decomposed_ops)], 'decompositions')
+        decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
+    gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops)], 'decompositions.txt')
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index ef0799f6d9e8b..663337901274e 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -1292,10 +1292,6 @@ class TestDecompositionOpInfo(TestCase):
         xfail('linalg.tensorinv'),
         xfail('to_sparse'),
         skip('tensor_split'),
-        skip('mvlgamma', 'mvlgamma_p_1'),
-        skip('mvlgamma', 'mvlgamma_p_3'),
-        skip('mvlgamma', 'mvlgamma_p_5'),
-        skip('eig'),
         skip('nn.functional.dropout'),
         skip('_masked.softmin'),
         skip('_masked.log_softmax'),
@@ -1303,16 +1299,10 @@ class TestDecompositionOpInfo(TestCase):
         skip('_masked.softmax'),
         skip('_masked.normalize'),
         xfail('linalg.lu_factor', ''),
-        # Some weird matmul stuff with int64 matmuls
-        # inplace op
-        skip('resize_'),
         # Weird conj errors
         xfail('fft.hfft2', dtypes=(torch.float32, torch.float64)),
         xfail('fft.hfft', dtypes=(torch.float32, torch.float64)),
         xfail('fft.hfftn', dtypes=(torch.float32, torch.float64)),
-        skip('nn.functional.binary_cross_entropy', ''),
-        skip('nn.functional.binary_cross_entropy_with_logits', '',),
-        skip('nn.functional.huber_loss'),
     })
     def test_decomposition(self, device, dtype, op):
         # dtype is too confusing of a name for how we're using it
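Note on the gen_data.py change above: entries logged to run_ops.txt and run_decompositions.txt can end in a ".default" overload tag, and the new remove_suffix helper strips it, presumably so the names line up with the full names checked by full_name_check against native_functions.yaml. A quick illustration of the helper's behavior follows; the op names are made-up examples, not taken from the repo.

# Illustrative only: shows what remove_suffix does to overload-tagged names.
def remove_suffix(input_string, suffix):
    if suffix and input_string.endswith(suffix):
        return input_string[:-len(suffix)]
    return input_string

print(remove_suffix('mse_loss.default', '.default'))  # -> mse_loss
print(remove_suffix('huber_loss.out', '.default'))    # unchanged -> huber_loss.out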