diff --git a/functorch/__init__.py b/functorch/__init__.py
index aff35e592d80e5..0aef38c8a9bb84 100644
--- a/functorch/__init__.py
+++ b/functorch/__init__.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import torch
-
 from torch._functorch.deprecated import (
     combine_state_for_ensemble,
     functionalize,
@@ -26,13 +25,15 @@
     FunctionalModuleWithBuffers,
 )
 
+# Was never documented
+from torch._functorch.python_key import make_fx
+
+
 # Top-level APIs. Please think carefully before adding something to the
 # top-level namespace:
 # - private helper functions should go into torch._functorch
 # - very experimental things should go into functorch.experimental
 # - compilation related things should go into functorch.compile
-# Was never documented
-from torch._functorch.python_key import make_fx
 
 __version__ = torch.__version__
diff --git a/functorch/benchmarks/chrome_trace_parser.py b/functorch/benchmarks/chrome_trace_parser.py
index ceb6ea58fbb915..826d53c990da89 100755
--- a/functorch/benchmarks/chrome_trace_parser.py
+++ b/functorch/benchmarks/chrome_trace_parser.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 import argparse
 import logging
-
 import os
 
 import pandas as pd
 
 from torch._functorch.benchmark_utils import compute_utilization
 
+
 # process the chrome traces output by the pytorch profiler
 # require the json input file's name to be in format {model_name}_chrome_trace_*.json
 # the runtimes file should have format (model_name, runtime)
diff --git a/functorch/benchmarks/cse.py b/functorch/benchmarks/cse.py
index 3bf5ab11514c55..bd5c89c968da03 100644
--- a/functorch/benchmarks/cse.py
+++ b/functorch/benchmarks/cse.py
@@ -1,8 +1,6 @@
 import torch
 import torch.fx as fx
-
 from functorch import make_fx
-
 from torch._functorch.compile_utils import fx_graph_cse
 from torch.profiler import profile, ProfilerActivity
diff --git a/functorch/benchmarks/operator_authoring.py b/functorch/benchmarks/operator_authoring.py
index 975311a67c5680..614a438f938106 100644
--- a/functorch/benchmarks/operator_authoring.py
+++ b/functorch/benchmarks/operator_authoring.py
@@ -5,9 +5,9 @@
 import pandas as pd
 import torch
-
 from functorch.compile import pointwise_operator
 
+
 WRITE_CSV = False
 CUDA = False
 SIZES = [1, 512, 8192]
diff --git a/functorch/benchmarks/per_sample_grads.py b/functorch/benchmarks/per_sample_grads.py
index 7eb99096ec7287..95b76252244b57 100644
--- a/functorch/benchmarks/per_sample_grads.py
+++ b/functorch/benchmarks/per_sample_grads.py
@@ -6,9 +6,9 @@
 import torch
 import torch.nn as nn
-
 from functorch import grad, make_functional, vmap
 
+
 device = "cuda"
 batch_size = 128
 torch.manual_seed(0)
diff --git a/functorch/benchmarks/pointwise_scorecard.py b/functorch/benchmarks/pointwise_scorecard.py
index 6b3250cc9ec46a..5f46c0a74fc5d4 100644
--- a/functorch/benchmarks/pointwise_scorecard.py
+++ b/functorch/benchmarks/pointwise_scorecard.py
@@ -4,9 +4,9 @@
 import time
 
 import torch
-
 from functorch import pointwise_operator
 
+
 torch.set_num_threads(1)
 torch._C._debug_set_fusion_group_inlining(False)
diff --git a/functorch/benchmarks/process_scorecard.py b/functorch/benchmarks/process_scorecard.py
index e535dcb5b5aa27..5649b5e1302696 100644
--- a/functorch/benchmarks/process_scorecard.py
+++ b/functorch/benchmarks/process_scorecard.py
@@ -1,6 +1,7 @@
 import matplotlib.pyplot as plt
 import pandas
 
+
 df = pandas.read_csv("perf.csv")
 ops = pandas.unique(df["operator"])
diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py
index 2413ed331fb2bc..a6de6ad59e9501 100644
--- a/functorch/dim/__init__.py
+++ b/functorch/dim/__init__.py
@@ -3,12 +3,13 @@
 from typing import Sequence, Union
 
 import functorch._C
-
 import torch
 from functorch._C import dim as _C
+
 from .tree_map import tree_flatten, tree_map
 from .wrap_type import wrap_type
 
+
 _C._patch_tensor_class()
 dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
@@ -23,6 +24,7 @@ class DimensionBindError(Exception):
 from . import op_properties
 
+
 # use dict to avoid writing C++ bindings for set
 pointwise = dict.fromkeys(op_properties.pointwise, True)
diff --git a/functorch/dim/batch_tensor.py b/functorch/dim/batch_tensor.py
index 0fc17f2492d534..dae9b270896e98 100644
--- a/functorch/dim/batch_tensor.py
+++ b/functorch/dim/batch_tensor.py
@@ -7,6 +7,7 @@
 from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
 
+
 _enabled = False
diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py
index f8e34af96225f3..cbafce2f0ee0c4 100644
--- a/functorch/dim/dim.py
+++ b/functorch/dim/dim.py
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 import dis
 import inspect
-
 from dataclasses import dataclass
 from typing import Union
 
 from . import DimList
 
+
 _vmap_levels = []
diff --git a/functorch/dim/op_properties.py b/functorch/dim/op_properties.py
index 3760f2cb0ea79c..01313f71f030d5 100644
--- a/functorch/dim/op_properties.py
+++ b/functorch/dim/op_properties.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
+
 # pointwise operators can go through a faster pathway
 tensor_magic_methods = ["add", ""]
diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py
index 2e5f9f50901929..a7cec06aa65973 100644
--- a/functorch/dim/reference.py
+++ b/functorch/dim/reference.py
@@ -6,12 +6,13 @@
 # reference python implementations for C ops
 import torch
-
 from functorch._C import dim as _C
+
 from . import op_properties
 from .batch_tensor import _enable_layers
 from .tree_map import tree_flatten, tree_map
 
+
 DimList = _C.DimList
 import operator
 from functools import reduce
@@ -407,7 +408,6 @@ def t__getitem__(self, input):
     # (keep track of whether we have to call super)
     # * call super if needed
     # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
-
     # this handles bool indexing handling, as well as some other simple cases.
     is_simple = (
diff --git a/functorch/dim/tree_map.py b/functorch/dim/tree_map.py
index 1f02f02656f288..3d2eae0582c856 100644
--- a/functorch/dim/tree_map.py
+++ b/functorch/dim/tree_map.py
@@ -6,6 +6,7 @@
 from functorch._C import dim
 
+
 tree_flatten = dim.tree_flatten
diff --git a/functorch/dim/wrap_type.py b/functorch/dim/wrap_type.py
index e2146c4a21a144..aae543b91a896e 100644
--- a/functorch/dim/wrap_type.py
+++ b/functorch/dim/wrap_type.py
@@ -14,6 +14,7 @@
 from functorch._C import dim as _C
 
+
 _wrap_method = _C._wrap_method
 
 FUNC_TYPES = (
diff --git a/functorch/docs/source/conf.py b/functorch/docs/source/conf.py
index 8a1bf182ddaa4e..a53fb86d716937 100644
--- a/functorch/docs/source/conf.py
+++ b/functorch/docs/source/conf.py
@@ -16,6 +16,7 @@
 import functorch
 
+
 # import sys
 
 # source code directory, relative to this file, for sphinx-autobuild
@@ -27,6 +28,7 @@
 import pytorch_sphinx_theme
 
+
 # -- General configuration ------------------------------------------------
 
 # Required version of sphinx is set from docs/requirements.txt
@@ -274,11 +276,11 @@ def setup(app):
 # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
 # See http://stackoverflow.com/a/41184353/3343043
-
 from docutils import nodes
 from sphinx import addnodes
 from sphinx.util.docfields import TypedField
 
+
 # Without this, doctest adds any example with a `>>>` as a test
 doctest_test_doctest_blocks = ""
 doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
diff --git a/functorch/einops/__init__.py b/functorch/einops/__init__.py
index b32751d6e2493a..d7ac34f7a37220 100644
--- a/functorch/einops/__init__.py
+++ b/functorch/einops/__init__.py
@@ -1,3 +1,4 @@
 from .rearrange import rearrange
 
+
 __all__ = ["rearrange"]
diff --git a/functorch/einops/_parsing.py b/functorch/einops/_parsing.py
index 25f86ec6feee30..ffb1fc00a20ee9 100644
--- a/functorch/einops/_parsing.py
+++ b/functorch/einops/_parsing.py
@@ -28,6 +28,7 @@
 import warnings
 from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
 
+
 _ellipsis: str = "\u2026"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
diff --git a/functorch/einops/rearrange.py b/functorch/einops/rearrange.py
index 0449bb7ed2c72e..1cd3cd8b3cf64f 100644
--- a/functorch/einops/rearrange.py
+++ b/functorch/einops/rearrange.py
@@ -4,8 +4,8 @@
 from typing import Callable, Dict, List, Sequence, Tuple, Union
 
 import torch
-
 from functorch._C import dim as _C
+
 from ._parsing import (
     _ellipsis,
     AnonymousAxis,
@@ -14,6 +14,7 @@
     validate_rearrange_expressions,
 )
 
+
 __all__ = ["rearrange"]
 
 dims = _C.dims
diff --git a/functorch/examples/compilation/eager_fusion.py b/functorch/examples/compilation/eager_fusion.py
index 3f89dec347e0ee..c5d13a4abe9b70 100644
--- a/functorch/examples/compilation/eager_fusion.py
+++ b/functorch/examples/compilation/eager_fusion.py
@@ -2,9 +2,9 @@
 import torch
 import torch.utils
-
 from functorch.compile import aot_function, tvm_compile
 
+
 a = torch.randn(2000, 1, 4, requires_grad=True)
 b = torch.randn(1, 2000, 4)
diff --git a/functorch/examples/compilation/fuse_module.py b/functorch/examples/compilation/fuse_module.py
index a241efa7cf8874..a0eb60347714b3 100644
--- a/functorch/examples/compilation/fuse_module.py
+++ b/functorch/examples/compilation/fuse_module.py
@@ -2,7 +2,6 @@
 import torch
 import torch.nn as nn
-
 from functorch.compile import compiled_module, tvm_compile
diff --git a/functorch/examples/compilation/linear_train.py b/functorch/examples/compilation/linear_train.py
index e538c54cb31e72..9bc686414609fc 100644
--- a/functorch/examples/compilation/linear_train.py
+++ b/functorch/examples/compilation/linear_train.py
@@ -8,10 +8,10 @@
 import torch
 import torch.nn as nn
-
 from functorch import make_functional
 from functorch.compile import nnc_jit
 
+
 torch._C._jit_override_can_fuse_on_cpu(True)
diff --git a/functorch/examples/compilation/simple_function.py b/functorch/examples/compilation/simple_function.py
index 9091de556c72fa..d916cc5b6ee492 100644
--- a/functorch/examples/compilation/simple_function.py
+++ b/functorch/examples/compilation/simple_function.py
@@ -7,7 +7,6 @@
 import time
 
 import torch
-
 from functorch import grad, make_fx
 from functorch.compile import nnc_jit
diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py b/functorch/examples/dp_cifar10/cifar10_transforms.py
index 786031bce9b8d6..ed1da15ee691f9 100644
--- a/functorch/examples/dp_cifar10/cifar10_transforms.py
+++ b/functorch/examples/dp_cifar10/cifar10_transforms.py
@@ -21,9 +21,9 @@
 import torch.nn as nn
 import torch.optim as optim
 import torch.utils.data
-
 from torch.func import functional_call, grad_and_value, vmap
 
+
 logging.basicConfig(
     format="%(asctime)s:%(levelname)s:%(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
diff --git a/functorch/examples/ensembling/parallel_train.py b/functorch/examples/ensembling/parallel_train.py
index 8fb0dae48e205a..a674a24c738dc3 100644
--- a/functorch/examples/ensembling/parallel_train.py
+++ b/functorch/examples/ensembling/parallel_train.py
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 from torch.func import functional_call, grad_and_value, stack_module_state, vmap
 
+
 # Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
 # tutorial on Model Ensembling with JAX by Will Whitney.
 #
diff --git a/functorch/examples/lennard_jones/lennard_jones.py b/functorch/examples/lennard_jones/lennard_jones.py
index 1b3d248ede11dc..30a50c14a7f794 100644
--- a/functorch/examples/lennard_jones/lennard_jones.py
+++ b/functorch/examples/lennard_jones/lennard_jones.py
@@ -7,6 +7,7 @@
 from torch.func import jacrev, vmap
 from torch.nn.functional import mse_loss
 
+
 sigma = 0.5
 epsilon = 4.0
diff --git a/functorch/examples/maml_omniglot/maml-omniglot-higher.py b/functorch/examples/maml_omniglot/maml-omniglot-higher.py
index db058e1f621181..82e33581124ebd 100755
--- a/functorch/examples/maml_omniglot/maml-omniglot-higher.py
+++ b/functorch/examples/maml_omniglot/maml-omniglot-higher.py
@@ -34,7 +34,6 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-
 import pandas as pd
 from support.omniglot_loaders import OmniglotNShot
@@ -43,6 +42,7 @@
 import torch.optim as optim
 from torch import nn
 
+
 mpl.use("Agg")
 plt.style.use("bmh")
diff --git a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
index 436d11d159129b..9067c7b75bcc61 100755
--- a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
+++ b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
@@ -33,17 +33,16 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-
 import pandas as pd
 from support.omniglot_loaders import OmniglotNShot
 
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-
 from functorch import make_functional_with_buffers
 from torch import nn
 
+
 mpl.use("Agg")
 plt.style.use("bmh")
diff --git a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py
index 2a7c1b95fb0cd1..cbc28ac1ee5778 100755
--- a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py
+++ b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py
@@ -34,7 +34,6 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-
 import pandas as pd
 from support.omniglot_loaders import OmniglotNShot
@@ -44,6 +43,7 @@
 from torch import nn
 from torch.func import functional_call, grad, vmap
 
+
 mpl.use("Agg")
 plt.style.use("bmh")
diff --git a/functorch/examples/maml_regression/evjang.py b/functorch/examples/maml_regression/evjang.py
index 423aaab1e16989..76e6d9c1c871fa 100644
--- a/functorch/examples/maml_regression/evjang.py
+++ b/functorch/examples/maml_regression/evjang.py
@@ -11,6 +11,7 @@
 import torch
 from torch.nn import functional as F
 
+
 mpl.use("Agg")
diff --git a/functorch/examples/maml_regression/evjang_transforms.py b/functorch/examples/maml_regression/evjang_transforms.py
index 92a8fbcdbf55b4..31f0f72d110040 100644
--- a/functorch/examples/maml_regression/evjang_transforms.py
+++ b/functorch/examples/maml_regression/evjang_transforms.py
@@ -12,6 +12,7 @@
 from torch.func import grad, vmap
 from torch.nn import functional as F
 
+
 mpl.use("Agg")
diff --git a/functorch/examples/maml_regression/evjang_transforms_module.py b/functorch/examples/maml_regression/evjang_transforms_module.py
index e2b8c548282e45..b5fac551122980 100644
--- a/functorch/examples/maml_regression/evjang_transforms_module.py
+++ b/functorch/examples/maml_regression/evjang_transforms_module.py
@@ -9,11 +9,11 @@
 import numpy as np
 
 import torch
-
 from functorch import grad, make_functional, vmap
 from torch import nn
 from torch.nn import functional as F
 
+
 mpl.use("Agg")
diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py
index e24fc614282001..cbfd76d184cc22 100644
--- a/functorch/experimental/control_flow.py
+++ b/functorch/experimental/control_flow.py
@@ -1,6 +1,5 @@
 from torch import cond  # noqa: F401
 from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
-
 from torch._higher_order_ops.map import (  # noqa: F401
     _stack_pytree,
     _unstack_pytree,
diff --git a/functorch/notebooks/_src/plot_ensembling.py b/functorch/notebooks/_src/plot_ensembling.py
index f3627e425031b5..55554a1985b434 100644
--- a/functorch/notebooks/_src/plot_ensembling.py
+++ b/functorch/notebooks/_src/plot_ensembling.py
@@ -20,6 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+
 torch.manual_seed(0)
@@ -85,6 +86,7 @@ def forward(self, x):
 # stateless version of the model (fmodel) and stacked parameters and buffers.
 from functorch import combine_state_for_ensemble
 
+
 fmodel, params, buffers = combine_state_for_ensemble(models)
 [p.requires_grad_() for p in params]
@@ -97,6 +99,7 @@ def forward(self, x):
 assert minibatches.shape == (num_models, 64, 1, 28, 28)
 from functorch import vmap
 
+
 predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
 assert torch.allclose(
     predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
diff --git a/functorch/notebooks/_src/plot_jacobians_and_hessians.py b/functorch/notebooks/_src/plot_jacobians_and_hessians.py
index ca6e160bad25b3..295810675ea02c 100644
--- a/functorch/notebooks/_src/plot_jacobians_and_hessians.py
+++ b/functorch/notebooks/_src/plot_jacobians_and_hessians.py
@@ -13,6 +13,7 @@
 import torch
 import torch.nn.functional as F
 
+
 torch.manual_seed(0)
@@ -54,6 +55,7 @@ def compute_jac(xp):
 # to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
 from functorch import vjp, vmap
 
+
 _, vjp_fn = vjp(partial(predict, weight, bias), x)
 (ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
 assert torch.allclose(ft_jacobian, jacobian)
@@ -69,6 +71,7 @@ def compute_jac(xp):
 # respect to.
 from functorch import jacrev
 
+
 ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
 assert torch.allclose(ft_jacobian, jacobian)
@@ -78,6 +81,7 @@ def compute_jac(xp):
 # eliminate overhead and give better utilization of your hardware.
 from torch.utils.benchmark import Timer
 
+
 without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
 with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
 print(without_vmap.timeit(500))
@@ -108,6 +112,7 @@ def compute_jac(xp):
 # it column-by-column. The Jacobian matrix has M rows and N columns.
 from functorch import jacfwd, jacrev
 
+
 # Benchmark with more inputs than outputs
 Din = 32
 Dout = 2048
@@ -144,6 +149,7 @@ def compute_jac(xp):
 # ``jacrev(jacrev(f))`` instead to compute hessians.
 from functorch import hessian
 
+
 # # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
 # hess0 = hessian(predict, argnums=2)(weight, bias, x)
 # hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
diff --git a/functorch/notebooks/_src/plot_per_sample_gradients.py b/functorch/notebooks/_src/plot_per_sample_gradients.py
index b0d10bcf484c6b..98e850e5ce0022 100644
--- a/functorch/notebooks/_src/plot_per_sample_gradients.py
+++ b/functorch/notebooks/_src/plot_per_sample_gradients.py
@@ -13,6 +13,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+
 torch.manual_seed(0)
@@ -94,6 +95,7 @@ def compute_sample_grads(data, targets):
 # ``functorch.make_functional_with_buffers``.
 from functorch import grad, make_functional_with_buffers, vmap
 
+
 fmodel, params, buffers = make_functional_with_buffers(model)
diff --git a/tools/linter/adapters/ufmt_linter.py b/tools/linter/adapters/ufmt_linter.py
index 9d67784b63a3f9..d48d2c6b3c027f 100644
--- a/tools/linter/adapters/ufmt_linter.py
+++ b/tools/linter/adapters/ufmt_linter.py
@@ -33,7 +33,6 @@
 # .github/**
 # benchmarks/**
 # functorch/**
-"functorch/**",
 # tools/**
 # torchgen/**
 # test/**