From d18565643598dfeb419e96d009052e17e21b790e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 8 Nov 2023 04:21:33 -0600 Subject: [PATCH] Fixed failing compiled execs --- test/test_transforms_v2.py | 115 +++-------------------------- torchvision/tv_tensors/__init__.py | 4 + 2 files changed, 14 insertions(+), 105 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4e3c6bf49dc..7f19b923559 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -193,11 +193,10 @@ def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): return functional_compiled = torch.compile(functional) - functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs) + functional_compiled(input, *args, **kwargs) - explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs) - # TODO: Set expected values to 1, 0 once fixed the graph break related to function registration - assert explanation.graph_count in (1, 2) + explanation = torch._dynamo.explain(functional_compiled)(input, *args, **kwargs) + # TODO: Set expected value to 0 once fixed the graph break related to function registration assert explanation.graph_break_count in (0, 1) @@ -1162,22 +1161,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py:1573: in UNPACK_SEQUENCE - # assert len(val) == inst.argval - # E AssertionError: - # E - # E from user code: - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 392, in resume_in_affine - # E return kernel( - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 692, in affine_image - # E return _apply_grid_transform(image, grid, interpolation.value, fill=fill) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform - # E num_channels, input_height, input_width = input_shape[-3:] - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True - check_functional( - F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke - ) + check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1601,12 +1585,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # Error is the same as for TestAffine - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True - check_functional( - F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke - ) + check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2671,38 +2650,8 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # - TestElastic.test_functional[make_bounding_boxes]: - # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:410: in _fn - # return fn(*args, **kwargs) - # torchvision/transforms/v2/functional/_geometry.py:1705: in elastic - # kernel = _get_kernel(elastic, type(inpt)) - # torchvision/transforms/v2/functional/_geometry.py:1706: in resume_in_elastic - # return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) - # torchvision/transforms/v2/functional/_geometry.py:1741: in elastic_image - # raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") - # torchvision/transforms/v2/functional/_geometry.py:1741: in resume_in_elastic_image - # raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") - # E ValueError: Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2]) - # - # - TestElastic.test_functional[make_segmentation_mask]: - # E AssertionError: - # E - # E from user code: - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1706, in resume_in_elastic - # E return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1746, in elastic_image - # E output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform - # E num_channels, input_height, input_width = input_shape[-3:] - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True input = make_input() - check_functional( - F.elastic, - input, - displacement=self._make_displacement(input), - check_torch_compile_smoke=check_torch_compile_smoke, - ) + check_functional(F.elastic, input, displacement=self._make_displacement(input)) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2815,23 +2764,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # E AssertionError: - # E - # E from user code: - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1324, in resume_in_crop - # E return kernel(inpt, top=top, left=left, height=height, width=width) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1343, in crop_image - # E return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill - # E num_channels, height, width = shape[-3:] - check_torch_compile_smoke = False if make_input == make_bounding_boxes else True - check_functional( - F.crop, - make_input(self.INPUT_SIZE), - **self.MINIMAL_CROP_KWARGS, - check_torch_compile_smoke=check_torch_compile_smoke, - ) + check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -3468,18 +3401,7 @@ def test_kernel_inplace(self, old_format, new_format): @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) def test_functional(self, old_format, new_format): - # TODO: Disabled torch.compile check due to the error: - # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # E ValueError: For pure tensor inputs, `old_format` has to be passed. - check_functional( - F.convert_bounding_box_format, - make_bounding_boxes(format=old_format), - new_format=new_format, - check_torch_compile_smoke=False, - ) + check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format) @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) @pytest.mark.parametrize("format_type", ["enum", "str"]) @@ -3753,18 +3675,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # E AssertionError: - # E - # E from user code: - # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1104, in resume_in_pad - # E return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1154, in pad_image - # E return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) - # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill - # E num_channels, height, width = shape[-3:] - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True - check_functional(F.pad, make_input(), padding=[1], check_torch_compile_smoke=check_torch_compile_smoke) + check_functional(F.pad, make_input(), padding=[1]) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -4415,13 +4326,7 @@ def test_kernel(self, format, dtype, device): @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional(self, format): - # TODO: Disabled torch.compile check due to the error: - # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # E ValueError: For pure tensor inputs, `old_format` has to be passed. - check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format), check_torch_compile_smoke=False) + check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format)) def test_errors(self): input_tv_tensor = make_bounding_boxes() diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index d55e10e8620..2e58d9d4c6a 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -8,6 +8,10 @@ from ._video import Video +# TODO: Fix this. We skip this method as it leads to +# RecursionError: maximum recursion depth exceeded while calling a Python object +# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph +@torch.compiler.disable def wrap(wrappee, *, like, **kwargs): """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.