diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 6a60a29cf0..65d2e90f13 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -206,6 +206,8 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N """ assert model_path is not None, 'model_path must be specified' for name, m in self.bound_model.named_modules(): + if name == "": + continue mask = self.mask_dict.get(name) if mask is not None: mask_sum = mask.sum().item()