Skip to content

Commit

Permalink
Revert D34808842: Reland "[pytorch][PR] Support dataclasses in TorchS…
Browse files Browse the repository at this point in the history
…cript"

Test Plan: revert-hammer

Differential Revision:
D34808842 (pytorch@b57cc9c)

Original commit changeset: 02f807cff1ea

Original Phabricator Diff: D34808842 (pytorch@b57cc9c)

fbshipit-source-id: bd7c47493b598677e77634d06d7dc3e3a457b92d
(cherry picked from commit e1853d7)
  • Loading branch information
b0noI authored and pytorchmergebot committed Mar 25, 2022
1 parent 7fe0b6a commit 3b3bdfd
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 394 deletions.
161 changes: 0 additions & 161 deletions test/jit/test_dataclasses.py

This file was deleted.

19 changes: 19 additions & 0 deletions test/jit/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.testing import FileCheck
from torch import jit
from jit.test_module_interface import TestModuleInterface # noqa: F401
import unittest
import os
import sys
import torch
Expand Down Expand Up @@ -46,6 +47,24 @@ def func(x):
self.assertEqual(out, out_script)
self.assertEqual(captured, captured_script)

@unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7")
def test_dataclass_error(self):
from dataclasses import dataclass

@dataclass
class NormalizationInfo(object):
mean: float = 0.0

def compute(self, total_rows):
return self.mean

def fn():
return NormalizationInfo(1, 2, 3, 4, 5)

with self.assertRaisesRegex(OSError, "could not get source code"):
torch.jit.script(fn)


def test_kwarg_support(self):
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
class M(torch.nn.Module):
Expand Down
1 change: 0 additions & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
from jit.test_dce import TestDCE # noqa: F401
from jit.test_sparse import TestSparse # noqa: F401
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
from jit.test_dataclasses import TestDataclasses # noqa: F401

# Torch
from torch import Tensor
Expand Down
30 changes: 3 additions & 27 deletions torch/_jit_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,6 @@
boolean_dispatched: 'weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]' = weakref.WeakKeyDictionary() # noqa: T484


FAKE_FILENAME_PREFIX = '__torch_jit_dataclass'


class SourceLoader:

def __init__(self):
self.content = {}

def cache(self, fn, source):
self.content[fn] = source

def get_source(self, fn):
return self.content.get(fn)


loader = SourceLoader()


def createResolutionCallbackFromEnv(lookup_base):
"""
Creates a resolution callback that will look up qualified names in an
Expand Down Expand Up @@ -341,14 +323,6 @@ def get_type_hint_captures(fn):
A Dict[str, Any] containing a mapping from the literal annotations used on
fn to the Python objects they refer to.
"""
# First, try to get the source of the function. We'll need to parse it to find the actual string names
# that were used to annotate the types, since inspect.signature() will only return the class object that
# the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
# This may happen in cases where the function is synthesized dynamically at runtime.
src = loader.get_source(fn)
if src is None:
src = inspect.getsource(fn)

# Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
# types are strings. These are only understood by TorchScript in the context of a type annotation
# that refers to a class in its own definition, but trying to include a mapping for this in the result
Expand All @@ -364,6 +338,8 @@ def get_type_hint_captures(fn):
# Then, get the literal type annotations from the function declaration
# by source inspection. This accounts for the case in which aliases are used
# to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
src = inspect.getsource(fn)

# frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
a = ast.parse(dedent(src))
if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
Expand Down Expand Up @@ -929,7 +905,7 @@ def is_optional_as_optional(ann):

def is_union_as_optional(ann):
ann_args = ann.__args__
return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
return len(ann_args) == 2 and None in ann_args

return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))

Expand Down
9 changes: 1 addition & 8 deletions torch/_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,13 @@ def remove_prefix(text, prefix):
return text[text.startswith(prefix) and len(prefix):]

# Find the line and line number containing the function definition
idx = None
for i, l in enumerate(sourcelines):
if l.lstrip().startswith("def"):
idx = i
break

# This will happen when the function is a lambda- we won't find "def" anywhere in the source
# lines in that case. Currently trying to JIT compile a lambda will throw an error up in
# `parse_def()`, but we might want to handle this case in the future.
if idx is None:
return sourcelines
fn_def = sourcelines[idx]

# Get a string representing the amount of leading whitespace
fn_def = sourcelines[idx]
whitespace = fn_def.split("def")[0]

# Add this leading whitespace to all lines before and after the `def`
Expand Down
Loading

0 comments on commit 3b3bdfd

Please sign in to comment.