# #0: Correcting bad dim check in CCL tests (#17392)
When setting up the output shard specs in the pytest scripts, the dim check hardcoded the assumption that the tensor is 4D (`dim == 3`), causing incorrect results when the tensor was not 4D. The check now compares `dim` against the index of the tensor's last dimension.
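A minimal sketch of the corrected logic (the helper name and the concrete shapes below are illustrative, not taken from the tests themselves): gathering along the innermost dimension should scale the shard width, gathering along any other dimension should scale the shard height, and `dim == 3` only identifies the innermost dimension for 4D tensors.

```python
# Illustrative sketch of the corrected output-shard-shape logic; the
# real tests inline this with their own parameter names.
def gathered_shard_shape(input_shape, input_shard_shape, dim, num_devices):
    output_shard_shape = list(input_shard_shape)
    if dim == len(input_shape) - 1:
        # Gathering along the innermost dimension widens each shard.
        output_shard_shape[1] *= num_devices
    else:
        # Gathering along any outer dimension grows the shard height.
        output_shard_shape[0] *= num_devices
    return output_shard_shape

# 4D tensor: dim 3 is the last dim, so the old and new checks agree.
assert gathered_shard_shape([1, 1, 32, 64], [32, 64], 3, 4) == [32, 256]

# 3D tensor: dim 2 is the last dim. The old hardcoded `dim == 3` check
# would have scaled the shard height here; the corrected check scales
# the width, matching the gather along the innermost dimension.
assert gathered_shard_shape([1, 32, 64], [32, 64], 2, 4) == [32, 256]
```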

### Checklist
- [ ] Post commit CI passes -- unable to run due to CI problems; testing was done on a local t3k machine instead.
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
jvegaTT authored and yieldthought committed Jan 31, 2025 (commit 443780d, 1 parent: 4b90c10)
Showing 5 changed files with 5 additions and 5 deletions.
tests/ttnn/unit_tests/operations/ccl/test_all_gather.py (1 addition, 1 deletion)
```diff
@@ -1187,7 +1187,7 @@ def run_all_gather_sharded(
     )
     input_mem_config = ttnn.MemoryConfig(tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=input_shard_spec)
     output_shard_shape = list(input_shard_shape)
-    if dim == 3:
+    if dim == len(input_shape) - 1:
         output_shard_shape[1] *= num_devices
     else:
         output_shard_shape[0] *= num_devices
```
```diff
@@ -203,7 +203,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
     output_shard_spec = None
     if input_shard_spec is not None:
         output_shard_shape = list(input_shard_spec.shape)
-        if dim == 3:
+        if dim == len(per_chip_output_shape) - 1:
             output_shard_shape[1] *= num_devices_per_line
         else:
             output_shard_shape[0] *= num_devices_per_line
```
```diff
@@ -55,7 +55,7 @@ def sharded_impl(
     if is_known_failure:
         pytest.skip(f"Skipping unsupported case {message}.")
     output_shard_shape = list(input_shard_shape)
-    if dim == 3:
+    if dim == len(input_shape) - 1:
         output_shard_shape[1] *= num_devices
     else:
         output_shard_shape[0] *= num_devices
```
```diff
@@ -189,7 +189,7 @@ def run_all_gather_impl(
         tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=input_shard_spec
     )
     output_shard_shape = list(input_shard_shape)
-    if dim == 3:
+    if dim == len(output_shape) - 1:
         output_shard_shape[1] *= num_devices
     else:
         output_shard_shape[0] *= num_devices
```
```diff
@@ -128,7 +128,7 @@ def run_reduce_scatter_test(
         tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=input_shard_spec
     )
     output_shard_shape = list(input_shard_shape)
-    if dim == 3:
+    if dim == len(per_chip_output_shape) - 1:
         output_shard_shape[1] *= num_devices
     else:
         output_shard_shape[0] *= num_devices
```
