From f1ab955509a1615a17383b9846395bebf8b0d228 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Mon, 22 Oct 2018 18:10:24 -0700 Subject: [PATCH] fix memory leak in FeedForward.scala by making it a native resource and disposing argparams, auxParams in dispose() method --- .../scala/org/apache/mxnet/FeedForward.scala | 52 ++++++++++++++++--- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala index 3d584a9e14dd..2ed9d8cfbb84 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala @@ -17,9 +17,10 @@ package org.apache.mxnet +import org.apache.mxnet.Base.CPtrAddress import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.optimizer.SGD -import org.slf4j.{LoggerFactory, Logger} +import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable.ListBuffer @@ -55,7 +56,7 @@ class FeedForward private( argParams: Map[String, NDArray], auxParams: Map[String, NDArray], private val allowExtraParams: Boolean, - val beginEpoch: Int) { + val beginEpoch: Int) extends NativeResource { val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward]) private var argumentChecked = false @@ -126,6 +127,8 @@ class FeedForward private( } // Initialize weight parameters and auxiliary states + // The NDArrays associated with the _argParms and _auxParams are not disposed instead + // they are passed a outer scope if available. private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false) : (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes) @@ -137,16 +140,26 @@ class FeedForward private( val paramNameShapes = (argNames zip argShapes).filter { case (name, _) => paramNames.contains(name) } - val argParams = paramNameShapes.map { case (name, shape) => - (name, NDArray.zeros(shape)) + val argParams = paramNameShapes.map { case (name, shape) => { + val param = NDArray.zeros(shape) + val curScope = ResourceScope.getCurrentScope() + if (curScope.isDefined) curScope.get.moveToOuterScope(param) + (name, param) + } }.toMap - val auxParams = (auxNames zip auxShapes).map { case (name, shape) => - (name, NDArray.zeros(shape)) + + val auxParams = (auxNames zip auxShapes).map { case (name, shape) => { + val param = NDArray.zeros(shape) + val curScope = ResourceScope.getCurrentScope() + if (curScope.isDefined) curScope.get.moveToOuterScope(param) + (name, param) + } }.toMap for ((k, v) <- argParams) { if (_argParams != null && _argParams.contains(k) && (!overwrite)) { argParams(k).set(_argParams(k)) + } else { initializer(k, v) } @@ -356,7 +369,11 @@ class FeedForward private( batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger, workLoadList: Seq[Float] = null): Unit = { require(evalMetric != null, "evalMetric cannot be null") - ResourceScope.using() { + // TODO: https://issues.apache.org/jira/browse/MXNET-1171 + // this leaks memory, initSymbolParams->initParams is already called which allocates + // NDArray in argParams, auxParams and here we are overwriting it by calling again. + // PhantomRef should take care of releasing this when GC is called, however we have to + // wait for the GC call to happen. val (argNames, paramNames, auxNames) = initSymbolParams(trainData) // init optimizer @@ -395,7 +412,6 @@ class FeedForward private( workLoadList = workLoadList, monitor = monitor, symGen = symGen) - } } /** @@ -422,9 +438,29 @@ class FeedForward private( def serialize(): Array[Byte] = { Model.serialize(this.symbol, getArgParams, getAuxParams) } + + // hack to make the FeedForward.scala work with ResourceScope and + // automatically release _argParms and _auxParms + override def nativeAddress: CPtrAddress = hashCode() + + override def nativeDeAllocator: CPtrAddress => Int = FeedForward.doNothingDeAllocator + + override val ref: NativeResourceRef = super.register() + + override val bytesAllocated: Long = 0L + + override def dispose(): Unit = { + if (!super.isDisposed) { + _argParams.foreach { case (_, param) => param.dispose() } + _auxParams.foreach { case (_, param) => param.dispose() } + } + } } object FeedForward { + + private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0 + private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward]) // Check if name is a data argument. private def isDataArg(name: String): Boolean = {