Commit

fix bug in set_float32_precision
fix bug in UT
Nicorgi committed Jan 16, 2025
1 parent 28c9ec2 commit 969c351
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 1 addition & 2 deletions tests/torchtune/training/test_precision.py
@@ -57,8 +57,7 @@ def test_error_bf16_unsupported(self, mock_verify):
             get_dtype(torch.bfloat16)
 
     @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
-    @mock.patch("torchtune.training.precision.is_npu_available", return_value=True)
-    def test_set_float32_precision(self, mock_npu_available) -> None:
+    def test_set_float32_precision(self) -> None:
         setattr(  # noqa: B010
             torch.backends, "__allow_nonbracketed_mutation_flag", True
         )
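
As the diff below shows, precision.py reads is_npu_available as a plain value rather than calling it, so patching it with mock.patch(..., return_value=True) replaces it with a truthy MagicMock object; the return_value never comes into play, and the guard in _set_float32_precision always saw a truthy flag under the mock. A minimal, self-contained sketch of that mock behavior (the Demo class and its flag are hypothetical stand-ins, not torchtune code):

    from unittest import mock

    class Demo:
        # hypothetical stand-in for a module-level availability flag
        is_npu_available = False

    # Patching with return_value=True swaps the attribute for a MagicMock.
    # The mock object itself is truthy, independent of its return_value,
    # so code that reads the attribute as a value sees it as true either way.
    with mock.patch.object(Demo, "is_npu_available", return_value=True):
        assert bool(Demo.is_npu_available) is True

    # Outside the patch, the real boolean is visible again.
    assert Demo.is_npu_available is False
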
2 changes: 1 addition & 1 deletion torchtune/training/precision.py
@@ -34,7 +34,7 @@ def _set_float32_precision(precision: str = "high") -> None:
         precision (str): The setting to determine which datatypes to use for matrix multiplication and convolution operations.
     """
     # Not relevant for non-CUDA or non-NPU devices
-    if not torch.cuda.is_available() or not is_npu_available:
+    if not (torch.cuda.is_available() or is_npu_available):
         return
     # set precision for matrix multiplications
     torch.set_float32_matmul_precision(precision)
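
The logic of the fix, sketched with boolean stand-ins (not torchtune code): the old guard returned early whenever either device was missing, so a CUDA-only machine skipped the precision setting; the corrected guard only returns when neither CUDA nor an NPU is available.

    # Minimal sketch of the guard change; cuda/npu are plain booleans here.
    def should_skip_old(cuda: bool, npu: bool) -> bool:
        return not cuda or not npu        # early-returns if *either* device is missing

    def should_skip_fixed(cuda: bool, npu: bool) -> bool:
        return not (cuda or npu)          # early-returns only if *both* are missing

    # CUDA-only machine: the old guard skips setting the precision, the fix does not.
    assert should_skip_old(cuda=True, npu=False) is True
    assert should_skip_fixed(cuda=True, npu=False) is False
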
