Skip to content

Commit

Permalink
Fix memory leaks in Gluon (apache#18328)
Browse files Browse the repository at this point in the history
Fix leak of ndarray objects in the frontend due to reference cycle.
  • Loading branch information
leezu authored and chinakook committed Jul 24, 2020
1 parent ad7c53b commit 2cc44b4
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 13 deletions.
25 changes: 15 additions & 10 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import threading
import copy
import warnings
import re
import weakref
from collections import OrderedDict, defaultdict

import re
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str
Expand All @@ -47,7 +49,7 @@ class _BlockScope(object):
_current = threading.local()

def __init__(self, block):
self._block = block
self._block = weakref.ref(block) if block is not None else None
self._counter = {}
self._old_scope = None
self._name_scope = None
Expand All @@ -59,7 +61,8 @@ def create(prefix, params, hint):
The profiler scope is to support the GPU memory profiler.
"""
current = getattr(_BlockScope._current, "value", None)
if current is None:
block = current._block() if current is not None else None
if current is None or block is None:
if prefix is None:
if not hasattr(_name.NameManager._current, "value"):
_name.NameManager._current.value = _name.NameManager()
Expand All @@ -78,29 +81,31 @@ def create(prefix, params, hint):
prefix = '%s%d_'%(hint, count)
current._counter[hint] = count + 1
if params is None:
parent = current._block.params
parent = block.params
params = ParameterDict(parent.prefix+prefix, parent._shared)
else:
params = ParameterDict(params.prefix, params)
# replace the trailing underscore with colon
profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \
else prefix) + ":"
return current._block.prefix + prefix, params, \
current._block._profiler_scope_name + profiler_scope_name
return block.prefix + prefix, params, \
block._profiler_scope_name + profiler_scope_name

def __enter__(self):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return self
self._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope = _name.Prefix(block.prefix)
self._name_scope.__enter__()
self._profiler_scope = _profiler.Scope(self._block._profiler_scope_name)
self._profiler_scope = _profiler.Scope(block._profiler_scope_name)
self._profiler_scope.__enter__()
return self

def __exit__(self, ptype, value, trace):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
Expand Down
39 changes: 38 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import os
import tempfile
import gc

import mxnet as mx
from mxnet import gluon
Expand Down Expand Up @@ -3215,3 +3215,40 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

def test_no_memory_leak_in_gluon():
# Collect all other garbage prior to this test. Otherwise the test may fail
# due to unrelated memory leaks.
gc.collect()

gc_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize()
del net
gc.collect()
gc.set_debug(gc_flags) # reset gc flags

# Check for leaked NDArrays
seen = set()
def has_array(element):
try:
if element in seen:
return False
seen.add(element)
except TypeError: # unhashable
pass

if isinstance(element, mx.nd._internal.NDArrayBase):
return True
elif hasattr(element, '__dict__'):
return any(has_array(x) for x in vars(element))
elif isinstance(element, dict):
return any(has_array(x) for x in element.items())
else:
try:
return any(has_array(x) for x in element)
except (TypeError, KeyError):
return False

assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
del gc.garbage[:]
5 changes: 3 additions & 2 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ def __init__(self, prefix):
status = [False]
event = threading.Event()
def f():
with block._BlockScope(dummy_block("spawned_")):
x= NameManager.current.get(None, "hello")
net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
with block._BlockScope(net):
x = NameManager.current.get(None, "hello")
event.wait()
if x == "spawned_hello0":
status[0] = True
Expand Down

0 comments on commit 2cc44b4

Please sign in to comment.