Skip to content

Commit

Permalink
[Inductor] Generalize new introduced device-bias code. (pytorch#126261)
Browse files Browse the repository at this point in the history
We found some Inductor test case failures when enabling Inductor UT for Intel GPU. The root cause is newly introduced device-biased Inductor code from recent community PRs, which causes different behaviors between Intel GPU and CUDA. This PR generalizes that code to align the behaviors.

Pull Request resolved: pytorch#126261
Approved by: https://github.com/EikanWang, https://github.com/peterbell10
  • Loading branch information
etaf authored and ZelboK committed May 19, 2024
1 parent 68c29aa commit cd60801
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
)

from . import config, inductor_prims
from .utils import needs_fallback_due_to_atomic_add_limitations, use_scatter_fallback
from .utils import (
is_gpu,
needs_fallback_due_to_atomic_add_limitations,
use_scatter_fallback,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten
Expand Down Expand Up @@ -167,7 +171,7 @@ def convolution_backward(
groups,
output_mask,
):
if not output_mask[2] or grad_output.device.type != "cuda":
if not output_mask[2] or not is_gpu(grad_output.device.type):
return NotImplemented
grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
grad_inp, grad_weight, _ = aten.convolution_backward(
Expand Down Expand Up @@ -593,7 +597,7 @@ def select_decomp_table():

@register_decomposition(aten.masked_scatter)
def masked_scatter(self, mask, source):
if self.device.type == "cuda":
if is_gpu(self.device.type):
# This two-step algorithm is the same as eager CUDA, for eager CPU we
# use a 1-shot serial iteration.
self, mask = aten.broadcast_tensors([self, mask])
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def should_assume_input_aligned(example_input: torch.Tensor):
# See Note: [Input Alignment handling in Inductor]

# right now, we only care about alignment for cuda tensors.
if example_input.device.type != "cuda":
if not is_gpu(example_input.device.type):
return False
return config.assume_aligned_inputs or tensor_is_aligned(example_input)

Expand Down

0 comments on commit cd60801

Please sign in to comment.