diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py index 6735d2b7374..4d494b64c4b 100644 --- a/python/oneflow/nn/graph/graph.py +++ b/python/oneflow/nn/graph/graph.py @@ -364,7 +364,7 @@ def state_dict( assert len(additional_var_names) == len(additional_var_tensors) for i in range(len(additional_var_names)): additional_tensor = additional_var_tensors[i] - if not self._is_global_view: + if not self._is_global_view and additional_tensor.is_global: additional_tensor = additional_tensor.to_local() destination[additional_var_names[i]] = additional_tensor else: