[BE][Easy][4/19] enforce style for empty lines in import segments in `functorch/` (pytorch#129755)

See pytorch#129751 (comment). Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: pytorch#129755
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#129752
XuehaiPan authored and pytorchmergebot committed Jul 18, 2024
1 parent a085acd commit 740fb22
Showing 36 changed files with 48 additions and 27 deletions.
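For illustration, here is a minimal sketch of the blank-line style these changes enforce, using a hypothetical module that is not part of the commit: stray blank lines inside an import segment are removed, and exactly two blank lines separate the imports from the first module-level statement.

```python
# Hypothetical before/after sketch of the enforced import style
# (illustrative only; this file and its names are not from the commit).

# Before: a blank line splits the import segment, and only one
# blank line separates the imports from module-level code.
import torch

from functorch import make_fx

device = "cuda"


# After: the import segment is contiguous and is followed by exactly
# two blank lines before the first module-level statement.
import torch
from functorch import make_fx


device = "cuda"
```

Because the diff consists almost entirely of such whitespace-only changes, the `git diff --ignore-all-space --ignore-blank-lines` command above filters them out of review.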
7 changes: 4 additions & 3 deletions functorch/__init__.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from torch._functorch.deprecated import (
combine_state_for_ensemble,
functionalize,
@@ -26,13 +25,15 @@
FunctionalModuleWithBuffers,
)

# Was never documented
from torch._functorch.python_key import make_fx


# Top-level APIs. Please think carefully before adding something to the
# top-level namespace:
# - private helper functions should go into torch._functorch
# - very experimental things should go into functorch.experimental
# - compilation related things should go into functorch.compile

# Was never documented
from torch._functorch.python_key import make_fx

__version__ = torch.__version__
2 changes: 1 addition & 1 deletion functorch/benchmarks/chrome_trace_parser.py
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
import argparse
import logging

import os

import pandas as pd

from torch._functorch.benchmark_utils import compute_utilization


# process the chrome traces output by the pytorch profiler
# require the json input file's name to be in format {model_name}_chrome_trace_*.json
# the runtimes file should have format (model_name, runtime)
2 changes: 0 additions & 2 deletions functorch/benchmarks/cse.py
@@ -1,8 +1,6 @@
import torch
import torch.fx as fx

from functorch import make_fx

from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity

2 changes: 1 addition & 1 deletion functorch/benchmarks/operator_authoring.py
@@ -5,9 +5,9 @@
import pandas as pd

import torch

from functorch.compile import pointwise_operator


WRITE_CSV = False
CUDA = False
SIZES = [1, 512, 8192]
2 changes: 1 addition & 1 deletion functorch/benchmarks/per_sample_grads.py
@@ -6,9 +6,9 @@

import torch
import torch.nn as nn

from functorch import grad, make_functional, vmap


device = "cuda"
batch_size = 128
torch.manual_seed(0)
2 changes: 1 addition & 1 deletion functorch/benchmarks/pointwise_scorecard.py
@@ -4,9 +4,9 @@
import time

import torch

from functorch import pointwise_operator


torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)

1 change: 1 addition & 0 deletions functorch/benchmarks/process_scorecard.py
@@ -1,6 +1,7 @@
import matplotlib.pyplot as plt
import pandas


df = pandas.read_csv("perf.csv")

ops = pandas.unique(df["operator"])
4 changes: 3 additions & 1 deletion functorch/dim/__init__.py
@@ -3,12 +3,13 @@
from typing import Sequence, Union

import functorch._C

import torch
from functorch._C import dim as _C

from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type


_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists

@@ -23,6 +24,7 @@ class DimensionBindError(Exception):

from . import op_properties


# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)

1 change: 1 addition & 0 deletions functorch/dim/batch_tensor.py
@@ -7,6 +7,7 @@

from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers


_enabled = False


2 changes: 1 addition & 1 deletion functorch/dim/dim.py
@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.
import dis
import inspect

from dataclasses import dataclass
from typing import Union

from . import DimList


_vmap_levels = []


1 change: 1 addition & 0 deletions functorch/dim/op_properties.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import torch


# pointwise operators can go through a faster pathway

tensor_magic_methods = ["add", ""]
4 changes: 2 additions & 2 deletions functorch/dim/reference.py
@@ -6,12 +6,13 @@

# reference python implementations for C ops
import torch

from functorch._C import dim as _C

from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map


DimList = _C.DimList
import operator
from functools import reduce
@@ -407,7 +408,6 @@ def t__getitem__(self, input):
# (keep track of whether we have to call super)
# * call super if needed
# * if we have dims to bind, bind them (it will help if we eliminated ... and None before)

# this handles bool indexing handling, as well as some other simple cases.

is_simple = (
1 change: 1 addition & 0 deletions functorch/dim/tree_map.py
@@ -6,6 +6,7 @@

from functorch._C import dim


tree_flatten = dim.tree_flatten


1 change: 1 addition & 0 deletions functorch/dim/wrap_type.py
@@ -14,6 +14,7 @@

from functorch._C import dim as _C


_wrap_method = _C._wrap_method

FUNC_TYPES = (
4 changes: 3 additions & 1 deletion functorch/docs/source/conf.py
@@ -16,6 +16,7 @@

import functorch


# import sys

# source code directory, relative to this file, for sphinx-autobuild
@@ -27,6 +28,7 @@

import pytorch_sphinx_theme


# -- General configuration ------------------------------------------------

# Required version of sphinx is set from docs/requirements.txt
@@ -274,11 +276,11 @@ def setup(app):

# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043

from docutils import nodes
from sphinx import addnodes
from sphinx.util.docfields import TypedField


# Without this, doctest adds any example with a `>>>` as a test
doctest_test_doctest_blocks = ""
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
1 change: 1 addition & 0 deletions functorch/einops/__init__.py
@@ -1,3 +1,4 @@
from .rearrange import rearrange


__all__ = ["rearrange"]
1 change: 1 addition & 0 deletions functorch/einops/_parsing.py
@@ -28,6 +28,7 @@
import warnings
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union


_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated


3 changes: 2 additions & 1 deletion functorch/einops/rearrange.py
@@ -4,8 +4,8 @@
from typing import Callable, Dict, List, Sequence, Tuple, Union

import torch

from functorch._C import dim as _C

from ._parsing import (
_ellipsis,
AnonymousAxis,
@@ -14,6 +14,7 @@
validate_rearrange_expressions,
)


__all__ = ["rearrange"]

dims = _C.dims
2 changes: 1 addition & 1 deletion functorch/examples/compilation/eager_fusion.py
@@ -2,9 +2,9 @@

import torch
import torch.utils

from functorch.compile import aot_function, tvm_compile


a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4)

1 change: 0 additions & 1 deletion functorch/examples/compilation/fuse_module.py
@@ -2,7 +2,6 @@

import torch
import torch.nn as nn

from functorch.compile import compiled_module, tvm_compile


2 changes: 1 addition & 1 deletion functorch/examples/compilation/linear_train.py
@@ -8,10 +8,10 @@

import torch
import torch.nn as nn

from functorch import make_functional
from functorch.compile import nnc_jit


torch._C._jit_override_can_fuse_on_cpu(True)


1 change: 0 additions & 1 deletion functorch/examples/compilation/simple_function.py
@@ -7,7 +7,6 @@
import time

import torch

from functorch import grad, make_fx
from functorch.compile import nnc_jit

2 changes: 1 addition & 1 deletion functorch/examples/dp_cifar10/cifar10_transforms.py
@@ -21,9 +21,9 @@
import torch.nn as nn
import torch.optim as optim
import torch.utils.data

from torch.func import functional_call, grad_and_value, vmap


logging.basicConfig(
format="%(asctime)s:%(levelname)s:%(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
1 change: 1 addition & 0 deletions functorch/examples/ensembling/parallel_train.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap


# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#
1 change: 1 addition & 0 deletions functorch/examples/lennard_jones/lennard_jones.py
@@ -7,6 +7,7 @@
from torch.func import jacrev, vmap
from torch.nn.functional import mse_loss


sigma = 0.5
epsilon = 4.0

2 changes: 1 addition & 1 deletion functorch/examples/maml_omniglot/maml-omniglot-higher.py
@@ -34,7 +34,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
from support.omniglot_loaders import OmniglotNShot

@@ -43,6 +42,7 @@
import torch.optim as optim
from torch import nn


mpl.use("Agg")
plt.style.use("bmh")

3 changes: 1 addition & 2 deletions functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
@@ -33,17 +33,16 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
from support.omniglot_loaders import OmniglotNShot

import torch
import torch.nn.functional as F
import torch.optim as optim

from functorch import make_functional_with_buffers
from torch import nn


mpl.use("Agg")
plt.style.use("bmh")

@@ -34,7 +34,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
from support.omniglot_loaders import OmniglotNShot

@@ -44,6 +43,7 @@
from torch import nn
from torch.func import functional_call, grad, vmap


mpl.use("Agg")
plt.style.use("bmh")

1 change: 1 addition & 0 deletions functorch/examples/maml_regression/evjang.py
@@ -11,6 +11,7 @@
import torch
from torch.nn import functional as F


mpl.use("Agg")


1 change: 1 addition & 0 deletions functorch/examples/maml_regression/evjang_transforms.py
@@ -12,6 +12,7 @@
from torch.func import grad, vmap
from torch.nn import functional as F


mpl.use("Agg")


@@ -9,11 +9,11 @@
import numpy as np

import torch

from functorch import grad, make_functional, vmap
from torch import nn
from torch.nn import functional as F


mpl.use("Agg")


1 change: 0 additions & 1 deletion functorch/experimental/control_flow.py
@@ -1,6 +1,5 @@
from torch import cond # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401

from torch._higher_order_ops.map import ( # noqa: F401
_stack_pytree,
_unstack_pytree,
