diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 6d9bf04346f7..a4599c8435fa 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1364,7 +1364,7 @@ def save(self, fname, remove_amp_cast=True): else: check_call(_LIB.MXSymbolSaveToFile(self.handle, c_str(fname))) - def tojson(self): + def tojson(self, remove_amp_cast=True): """Saves symbol to a JSON string. See Also @@ -1372,7 +1372,12 @@ def tojson(self): symbol.load_json : Used to load symbol from JSON string. """ json_str = ctypes.c_char_p() - check_call(_LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_str))) + if remove_amp_cast: + handle = SymbolHandle() + check_call(_LIB.MXSymbolRemoveAmpCast(self.handle, ctypes.byref(handle))) + check_call(_LIB.MXSymbolSaveToJSON(handle, ctypes.byref(json_str))) + else: + check_call(_LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_str))) return py_str(json_str.value) @staticmethod