Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix memory leak in FeedForward.scala by making it a native resource a…
Browse files Browse the repository at this point in the history
…nd disposing argparams, auxParams

in dispose() method
  • Loading branch information
nswamy committed Oct 23, 2018
1 parent 8b26de6 commit f1ab955
Showing 1 changed file with 44 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.mxnet

import org.apache.mxnet.Base.CPtrAddress
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.optimizer.SGD
import org.slf4j.{LoggerFactory, Logger}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -55,7 +56,7 @@ class FeedForward private(
argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
private val allowExtraParams: Boolean,
val beginEpoch: Int) {
val beginEpoch: Int) extends NativeResource {

val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
private var argumentChecked = false
Expand Down Expand Up @@ -126,6 +127,8 @@ class FeedForward private(
}

// Initialize weight parameters and auxiliary states
// The NDArrays associated with _argParams and _auxParams are not disposed here; instead
// they are moved to the outer scope if one is available.
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
Expand All @@ -137,16 +140,26 @@ class FeedForward private(
val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
paramNames.contains(name)
}
val argParams = paramNameShapes.map { case (name, shape) =>
(name, NDArray.zeros(shape))
val argParams = paramNameShapes.map { case (name, shape) => {
val param = NDArray.zeros(shape)
val curScope = ResourceScope.getCurrentScope()
if (curScope.isDefined) curScope.get.moveToOuterScope(param)
(name, param)
}
}.toMap
val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
(name, NDArray.zeros(shape))

val auxParams = (auxNames zip auxShapes).map { case (name, shape) => {
val param = NDArray.zeros(shape)
val curScope = ResourceScope.getCurrentScope()
if (curScope.isDefined) curScope.get.moveToOuterScope(param)
(name, param)
}
}.toMap

for ((k, v) <- argParams) {
if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
argParams(k).set(_argParams(k))

} else {
initializer(k, v)
}
Expand Down Expand Up @@ -356,7 +369,11 @@ class FeedForward private(
batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger,
workLoadList: Seq[Float] = null): Unit = {
require(evalMetric != null, "evalMetric cannot be null")
ResourceScope.using() {
// TODO: https://issues.apache.org/jira/browse/MXNET-1171
// This leaks memory: initSymbolParams -> initParams has already been called, which allocated
// the NDArrays in argParams and auxParams, and calling it again here overwrites them.
// PhantomRef should take care of releasing them when GC runs; however, we have to
// wait for that GC call to happen.
val (argNames, paramNames, auxNames) = initSymbolParams(trainData)

// init optimizer
Expand Down Expand Up @@ -395,7 +412,6 @@ class FeedForward private(
workLoadList = workLoadList,
monitor = monitor,
symGen = symGen)
}
}

/**
Expand All @@ -422,9 +438,29 @@ class FeedForward private(
def serialize(): Array[Byte] = {
  // Gather the three inputs in the same order they are passed on, then delegate the
  // actual byte encoding to Model.serialize.
  val network = this.symbol
  val currentArgParams = getArgParams
  val currentAuxParams = getAuxParams
  Model.serialize(network, currentArgParams, currentAuxParams)
}

// Hack to make FeedForward work with ResourceScope and
// automatically release _argParams and _auxParams.
override def nativeAddress: CPtrAddress = {
  // The instance hash code is used as the address value reported for this resource.
  val pseudoAddress: CPtrAddress = this.hashCode()
  pseudoAddress
}

override def nativeDeAllocator: CPtrAddress => Int = {
  // Deallocation is delegated to the companion object's do-nothing deallocator.
  (address: CPtrAddress) => FeedForward.doNothingDeAllocator(address)
}

// Calls the inherited register() while this object is being constructed and keeps the
// reference it returns.
// NOTE(review): presumably this is what lets an enclosing ResourceScope track the model - confirm.
override val ref: NativeResourceRef = super.register()

// Allocation size reported for this resource; fixed at zero.
// NOTE(review): presumably the parameter NDArrays account for their own memory - confirm.
override val bytesAllocated: Long = 0L

/**
 * Disposes every NDArray held in `_argParams` and `_auxParams`.
 *
 * Nothing happens when `isDisposed` already reports true. A parameter map that is still
 * null is skipped rather than dereferenced (`initParams` explicitly checks
 * `_argParams != null`, so a null map is a state this class expects), which avoids a
 * NullPointerException when the model is disposed before its parameters exist.
 *
 * NOTE(review): this override does not call `super.dispose()`; whether the inherited
 * disposed state is updated elsewhere is not visible from here - confirm.
 */
override def dispose(): Unit = {
  if (!super.isDisposed) {
    // Option(...) turns a null map into None so the inner traversal is simply skipped.
    Option(_argParams).foreach(params => params.values.foreach(_.dispose()))
    Option(_auxParams).foreach(params => params.values.foreach(_.dispose()))
  }
}
}

object FeedForward {

private def doNothingDeAllocator(dummy: CPtrAddress): Int = {
  // Nothing is freed: the argument is ignored and 0 is always returned.
  0
}

private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
// Check if name is a data argument.
private def isDataArg(name: String): Boolean = {
Expand Down

0 comments on commit f1ab955

Please sign in to comment.