Skip to content

Commit

Permalink
Skip some tests for torch 2.4 (#1981)
Browse files · Browse the repository at this point in the history
Signed-off-by: yiliu30 <[email protected]>
• Branch information
yiliu30 authored Aug 15, 2024
1 parent 46d9192 commit f9dfd54
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions neural_compressor/torch/algorithms/pt2e_quant/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2
from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5


def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
Expand Down Expand Up @@ -102,8 +102,8 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86Induct
# set global
global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
quantizer.set_global(global_config)
# need torch >= 2.3.2
if GT_TORCH_VERSION_2_3_2: # pragma: no cover
# need torch >= 2.5
if GT_OR_EQUAL_TORCH_VERSION_2_5: # pragma: no cover
op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
if op_type_config_dict:
for op_type, config in op_type_config_dict.items():
Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/torch/utils/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_torch_version():
return version


GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2")
GT_OR_EQUAL_TORCH_VERSION_2_5 = get_torch_version() >= Version("2.5")


def get_accelerator(device_name="auto"):
Expand Down
6 changes: 3 additions & 3 deletions test/3x/torch/quantization/test_pt2e_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
prepare,
quantize,
)
from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version
from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, TORCH_VERSION_2_2_2, get_torch_version

torch.manual_seed(0)

Expand Down Expand Up @@ -131,7 +131,7 @@ def calib_fn(model):
logger.warning("out shape is %s", out.shape)
assert out is not None

@pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.2")
@pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
float_model_output = model(*example_inputs)
Expand Down Expand Up @@ -243,7 +243,7 @@ def get_node_in_graph(graph_module):
nodes_in_graph[n] = 1
return nodes_in_graph

@pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.0")
@pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
def test_mixed_fp16_and_int8(self, force_not_import_ipex):
model, example_inputs = self.build_model_include_conv_and_linear()
model = export(model, example_inputs=example_inputs)
Expand Down

0 comments on commit f9dfd54

Please sign in to comment.