diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 6e2d8d6e9cc7..d5396dab1e67 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -126,5 +126,12 @@
commons-io
2.1
+
+
+ org.mockito
+ mockito-all
+ 1.10.19
+ test
+
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index fc791d5cd9a3..19fb6fe5cee5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -45,7 +45,7 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
- private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed {
+ private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
@@ -59,14 +59,15 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
- private var disposed = false
- protected def isDisposed = disposed
-
- def dispose(): Unit = {
- if (!disposed) {
- outputs.foreach(_.dispose())
- _LIB.mxExecutorFree(handle)
- disposed = true
+ override def nativeAddress: CPtrAddress = handle
+ override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
+ // cannot determine the off-heap size of this object
+ override val bytesAllocated: Long = 0
+ override val ref: NativeResourceRef = super.register()
+ override def dispose(): Unit = {
+ if (!super.isDisposed) {
+ super.dispose()
+ outputs.foreach(o => o.dispose())
}
}
@@ -305,4 +306,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
checkCall(_LIB.mxExecutorPrint(handle, str))
str.value
}
+
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
index 8e89ce76b877..45189a13aefc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
@@ -52,22 +52,17 @@ object KVStore {
}
}
-class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed {
+class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource {
private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
private var updaterFunc: MXKVStoreUpdater = null
- private var disposed = false
- protected def isDisposed = disposed
- /**
- * Release the native memory.
- * The object shall never be used after it is disposed.
- */
- def dispose(): Unit = {
- if (!disposed) {
- _LIB.mxKVStoreFree(handle)
- disposed = true
- }
- }
+ override def nativeAddress: CPtrAddress = handle
+
+ override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree
+
+ override val ref: NativeResourceRef = super.register()
+
+ override val bytesAllocated: Long = 0L
/**
* Initialize a single or a sequence of key-value pairs into the store.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
index 4bb9cdd331a6..b835c4964dd0 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
@@ -259,7 +259,9 @@ object Model {
workLoadList: Seq[Float] = Nil,
monitor: Option[Monitor] = None,
symGen: SymbolGenerator = null): Unit = {
- val executorManager = new DataParallelExecutorManager(
+ ResourceScope.using() {
+
+ val executorManager = new DataParallelExecutorManager(
symbol = symbol,
symGen = symGen,
ctx = ctx,
@@ -269,17 +271,17 @@ object Model {
auxNames = auxNames,
workLoadList = workLoadList)
- monitor.foreach(executorManager.installMonitor)
- executorManager.setParams(argParams, auxParams)
+ monitor.foreach(executorManager.installMonitor)
+ executorManager.setParams(argParams, auxParams)
- // updater for updateOnKVStore = false
- val updaterLocal = Optimizer.getUpdater(optimizer)
+ // updater for updateOnKVStore = false
+ val updaterLocal = Optimizer.getUpdater(optimizer)
- kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
- argParams, executorManager.paramNames, updateOnKVStore))
- if (updateOnKVStore) {
- kvStore.foreach(_.setOptimizer(optimizer))
- }
+ kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
+ argParams, executorManager.paramNames, updateOnKVStore))
+ if (updateOnKVStore) {
+ kvStore.foreach(_.setOptimizer(optimizer))
+ }
// Now start training
for (epoch <- beginEpoch until endEpoch) {
@@ -290,45 +292,46 @@ object Model {
var epochDone = false
// Iterate over training data.
trainData.reset()
- while (!epochDone) {
- var doReset = true
- while (doReset && trainData.hasNext) {
- val dataBatch = trainData.next()
- executorManager.loadDataBatch(dataBatch)
- monitor.foreach(_.tic())
- executorManager.forward(isTrain = true)
- executorManager.backward()
- if (updateOnKVStore) {
- updateParamsOnKVStore(executorManager.paramArrays,
- executorManager.gradArrays,
- kvStore, executorManager.paramNames)
- } else {
- updateParams(executorManager.paramArrays,
- executorManager.gradArrays,
- updaterLocal, ctx.length,
- executorManager.paramNames,
- kvStore)
- }
- monitor.foreach(_.tocPrint())
- // evaluate at end, so out_cpu_array can lazy copy
- executorManager.updateMetric(evalMetric, dataBatch.label)
+ ResourceScope.using() {
+ while (!epochDone) {
+ var doReset = true
+ while (doReset && trainData.hasNext) {
+ val dataBatch = trainData.next()
+ executorManager.loadDataBatch(dataBatch)
+ monitor.foreach(_.tic())
+ executorManager.forward(isTrain = true)
+ executorManager.backward()
+ if (updateOnKVStore) {
+ updateParamsOnKVStore(executorManager.paramArrays,
+ executorManager.gradArrays,
+ kvStore, executorManager.paramNames)
+ } else {
+ updateParams(executorManager.paramArrays,
+ executorManager.gradArrays,
+ updaterLocal, ctx.length,
+ executorManager.paramNames,
+ kvStore)
+ }
+ monitor.foreach(_.tocPrint())
+ // evaluate at end, so out_cpu_array can lazy copy
+ executorManager.updateMetric(evalMetric, dataBatch.label)
- nBatch += 1
- batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
+ nBatch += 1
+ batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
- // this epoch is done possibly earlier
- if (epochSize != -1 && nBatch >= epochSize) {
- doReset = false
+ // this epoch is done possibly earlier
+ if (epochSize != -1 && nBatch >= epochSize) {
+ doReset = false
+ }
+ }
+ if (doReset) {
+ trainData.reset()
}
- }
- if (doReset) {
- trainData.reset()
- }
- // this epoch is done
- epochDone = (epochSize == -1 || nBatch >= epochSize)
+ // this epoch is done
+ epochDone = (epochSize == -1 || nBatch >= epochSize)
+ }
}
-
val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-$n=$v")
@@ -336,20 +339,22 @@ object Model {
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
- evalData.foreach { evalDataIter =>
- evalMetric.reset()
- evalDataIter.reset()
- // TODO: make DataIter implement Iterator
- while (evalDataIter.hasNext) {
- val evalBatch = evalDataIter.next()
- executorManager.loadDataBatch(evalBatch)
- executorManager.forward(isTrain = false)
- executorManager.updateMetric(evalMetric, evalBatch.label)
- }
+ ResourceScope.using() {
+ evalData.foreach { evalDataIter =>
+ evalMetric.reset()
+ evalDataIter.reset()
+ // TODO: make DataIter implement Iterator
+ while (evalDataIter.hasNext) {
+ val evalBatch = evalDataIter.next()
+ executorManager.loadDataBatch(evalBatch)
+ executorManager.forward(isTrain = false)
+ executorManager.updateMetric(evalMetric, evalBatch.label)
+ }
- val (name, value) = evalMetric.get
- name.zip(value).foreach { case (n, v) =>
- logger.info(s"Epoch[$epoch] Train-$n=$v")
+ val (name, value) = evalMetric.get
+ name.zip(value).foreach { case (n, v) =>
+ logger.info(s"Epoch[$epoch] Validation-$n=$v")
+ }
}
}
@@ -359,8 +364,7 @@ object Model {
epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
}
- updaterLocal.dispose()
- executorManager.dispose()
+ }
}
// scalastyle:on parameterNum
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 9b6a7dc66540..f2a7603caa85 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -562,16 +562,20 @@ object NDArray extends NDArrayBase {
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val writable: Boolean = true,
- addToCollector: Boolean = true) extends WarnIfNotDisposed {
+ addToCollector: Boolean = true) extends NativeResource {
if (addToCollector) {
NDArrayCollector.collect(this)
}
+ override def nativeAddress: CPtrAddress = handle
+ override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
+ override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product
+
+ override val ref: NativeResourceRef = super.register()
+
// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
- @volatile private var disposed = false
- def isDisposed: Boolean = disposed
def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
@@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* The NDArrays it depends on will NOT be disposed.
* The object shall never be used after it is disposed.
*/
- def dispose(): Unit = {
- if (!disposed) {
- _LIB.mxNDArrayFree(handle)
+ override def dispose(): Unit = {
+ if (!super.isDisposed) {
+ super.dispose()
dependencies.clear()
- disposed = true
}
}
@@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// TODO: naive implementation
shape.hashCode + toArray.hashCode
}
+
}
private[mxnet] object NDArrayConversions {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
new file mode 100644
index 000000000000..48d4b0c193b1
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base.CPtrAddress
+import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference}
+import java.util.concurrent._
+
+import org.apache.mxnet.Base.checkCall
+import java.util.concurrent.atomic.AtomicLong
+
+
+/**
+ * NativeResource trait is used to manage MXNet Objects
+ * such as NDArray, Symbol, Executor, etc.,
+ * The MXNet Object calls NativeResource.register
+ * and assign the returned NativeResourceRef to PhantomReference
+ * NativeResource also implements AutoCloseable so MXNetObjects
+ * can be used like Resources in try-with-resources paradigm
+ */
+private[mxnet] trait NativeResource
+ extends AutoCloseable with WarnIfNotDisposed {
+
+ /**
+ * native Address associated with this object
+ */
+ def nativeAddress: CPtrAddress
+
+ /**
+ * Function Pointer to the NativeDeAllocator of nativeAddress
+ */
+ def nativeDeAllocator: (CPtrAddress => Int)
+
+ /** Call NativeResource.register to get the reference
+ */
+ val ref: NativeResourceRef
+
+ /**
+ * Off-Heap Bytes Allocated for this object
+ */
+ // intentionally making it a val, so it gets evaluated when defined
+ val bytesAllocated: Long
+
+ private[mxnet] var scope: Option[ResourceScope] = None
+
+ @volatile private var disposed = false
+
+ override def isDisposed: Boolean = disposed || isDeAllocated
+
+ /**
+ * Register this object for PhantomReference tracking and in
+ * ResourceScope if used inside ResourceScope.
+ * @return NativeResourceRef that tracks reachability of this object
+ * using PhantomReference
+ */
+ def register(): NativeResourceRef = {
+ scope = ResourceScope.getCurrentScope()
+ if (scope.isDefined) scope.get.add(this)
+
+ NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
+ // register with PhantomRef tracking to release incase the objects go
+ // out of reference within scope but are held for long time
+ NativeResourceRef.register(this, nativeDeAllocator)
+ }
+
+ // Implements [[@link AutoCloseable.close]]
+ override def close(): Unit = {
+ dispose()
+ }
+
+ // Implements [[@link WarnIfNotDisposed.dispose]]
+ def dispose(): Unit = dispose(true)
+
+ /**
+ * This method deAllocates nativeResource and deRegisters
+ * from PhantomRef and removes from Scope if
+ * removeFromScope is set to true.
+ * @param removeFromScope remove from the currentScope if true
+ */
+ // the parameter here controls whether to remove from current scope.
+ // [[ResourceScope.close]] calls NativeResource.dispose
+ // if we remove from the ResourceScope ie., from the container in ResourceScope.
+ // while iterating on the container, calling iterator.next is undefined and not safe.
+ // Note that ResourceScope automatically disposes all the resources within.
+ private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = {
+ if (!disposed) {
+ checkCall(nativeDeAllocator(this.nativeAddress))
+ NativeResourceRef.deRegister(ref) // removes from PhantomRef tracking
+ if (removeFromScope && scope.isDefined) scope.get.remove(this)
+ NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated)
+ disposed = true
+ }
+ }
+
+ /*
+ this is used by the WarnIfNotDisposed finalizer,
+ the object could be disposed by the GC without the need for explicit disposal
+ but the finalizer might not have run, then the WarnIfNotDisposed throws a warning
+ */
+ private[mxnet] def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref)
+
+}
+
+private[mxnet] object NativeResource {
+ var totalBytesAllocated : AtomicLong = new AtomicLong(0)
+}
+
+// Do not make [[NativeResource.resource]] a member of the class,
+// this will hold reference and GC will not clear the object.
+private[mxnet] class NativeResourceRef(resource: NativeResource,
+ val resourceDeAllocator: CPtrAddress => Int)
+ extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {}
+
+private[mxnet] object NativeResourceRef {
+
+ private[mxnet] val refQ: ReferenceQueue[NativeResource]
+ = new ReferenceQueue[NativeResource]
+
+ private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()
+
+ private[mxnet] val cleaner = new ResourceCleanupThread()
+
+ cleaner.start()
+
+ def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)):
+ NativeResourceRef = {
+ val ref = new NativeResourceRef(resource, nativeDeAllocator)
+ refMap.put(ref, resource.nativeAddress)
+ ref
+ }
+
+ // remove from PhantomRef tracking
+ def deRegister(ref: NativeResourceRef): Unit = refMap.remove(ref)
+
+ /**
+ * This method will check if the cleaner ran and deAllocated the object
+ * As a part of GC, when the object is unreachable GC inserts a phantomRef
+ * to the ReferenceQueue which the cleaner thread will deallocate, however
+ * the finalizer runs much later depending on the GC.
+ * @param resource resource to verify if it has been deAllocated
+ * @return true if already deAllocated
+ */
+ def isDeAllocated(ref: NativeResourceRef): Boolean = {
+ !refMap.containsKey(ref)
+ }
+
+ def cleanup: Unit = {
+ // remove is a blocking call
+ val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef]
+ // phantomRef will be removed from the map when NativeResource.close is called.
+ val resource = refMap.get(ref)
+ if (resource != 0L) { // since CPtrAddress is Scala a Long, it cannot be null
+ ref.resourceDeAllocator(resource)
+ refMap.remove(ref)
+ }
+ }
+
+ protected class ResourceCleanupThread extends Thread {
+ setPriority(Thread.MAX_PRIORITY)
+ setName("NativeResourceDeAllocatorThread")
+ setDaemon(true)
+
+ override def run(): Unit = {
+ while (true) {
+ try {
+ NativeResourceRef.cleanup
+ }
+ catch {
+ case _: InterruptedException => Thread.currentThread().interrupt()
+ }
+ }
+ }
+ }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
index 758cbc829618..c3f8aaec6d60 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
@@ -19,6 +19,8 @@ package org.apache.mxnet
import java.io._
+import org.apache.mxnet.Base.CPtrAddress
+
import scala.collection.mutable
import scala.util.Either
@@ -38,8 +40,10 @@ object Optimizer {
}
override def dispose(): Unit = {
- states.values.foreach(optimizer.disposeState)
- states.clear()
+ if (!super.isDisposed) {
+ states.values.foreach(optimizer.disposeState)
+ states.clear()
+ }
}
override def serializeState(): Array[Byte] = {
@@ -285,7 +289,8 @@ abstract class Optimizer extends Serializable {
}
}
-trait MXKVStoreUpdater {
+trait MXKVStoreUpdater extends
+ NativeResource {
/**
* user-defined updater for the kvstore
* It's this updater's responsibility to delete recv and local
@@ -294,9 +299,14 @@ trait MXKVStoreUpdater {
* @param local the value stored on local on this key
*/
def update(key: Int, recv: NDArray, local: NDArray): Unit
- def dispose(): Unit
- // def serializeState(): Array[Byte]
- // def deserializeState(bytes: Array[Byte]): Unit
+
+ // This is a hack to make Optimizers work with ResourceScope
+ // otherwise the user has to manage calling dispose on this object.
+ override def nativeAddress: CPtrAddress = hashCode()
+ override def nativeDeAllocator: CPtrAddress => Int = doNothingDeAllocator
+ private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0
+ override val ref: NativeResourceRef = super.register()
+ override val bytesAllocated: Long = 0L
}
trait MXKVStoreCachedStates {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
new file mode 100644
index 000000000000..1c5782d873a9
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.util.HashSet
+
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Try
+import scala.util.control.{ControlThrowable, NonFatal}
+
+/**
+ * This class manages automatically releasing of [[NativeResource]]s
+ */
+class ResourceScope extends AutoCloseable {
+
+ // HashSet does not take a custom comparator
+ private[mxnet] val resourceQ = new mutable.TreeSet[NativeResource]()(nativeAddressOrdering)
+
+ private object nativeAddressOrdering extends Ordering[NativeResource] {
+ def compare(a: NativeResource, b: NativeResource): Int = {
+ a.nativeAddress compare b.nativeAddress
+ }
+ }
+
+ ResourceScope.addToThreadLocal(this)
+
+ /**
+ * Releases all the [[NativeResource]] by calling
+ * the associated [[NativeResource.close()]] method
+ */
+ override def close(): Unit = {
+ ResourceScope.removeFromThreadLocal(this)
+ resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
+ resourceQ.clear()
+ }
+
+ /**
+ * Add a NativeResource to the scope
+ * @param resource
+ */
+ def add(resource: NativeResource): Unit = {
+ resourceQ.+=(resource)
+ }
+
+ /**
+ * Remove NativeResource from the Scope, this uses
+ * object equality to find the resource in the stack.
+ * @param resource
+ */
+ def remove(resource: NativeResource): Unit = {
+ resourceQ.-=(resource)
+ }
+}
+
+object ResourceScope {
+
+ private val logger = LoggerFactory.getLogger(classOf[ResourceScope])
+
+ /**
+ * Captures all Native Resources created using the ResourceScope and
+ * at the end of the body, de allocates all the Native resources by calling close on them.
+ * This method will not deAllocate NativeResources returned from the block.
+ * @param scope (Optional). Scope in which to capture the native resources
+ * @param body block of code to execute in this scope
+ * @tparam A return type
+ * @return result of the operation, if the result is of type NativeResource, it is not
+ * de allocated so the user can use it and then de allocate manually by calling
+ * close or enclose in another resourceScope.
+ */
+ // inspired from slide 21 of https://www.slideshare.net/Odersky/fosdem-2009-1013261
+ // and https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala
+ // TODO: we should move to the Scala util's Using method when we move to Scala 2.13
+ def using[A](scope: ResourceScope = null)(body: => A): A = {
+
+ val curScope = if (scope != null) scope else new ResourceScope()
+
+ val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
+
+ @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
+ g.foreach( n =>
+ n match {
+ case nRes: NativeResource => {
+ removeAndAddToPrevScope(nRes)
+ }
+ case kv: scala.Tuple2[_, _] => {
+ if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+ kv._1.asInstanceOf[NativeResource])
+ if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+ kv._2.asInstanceOf[NativeResource])
+ }
+ }
+ )
+ }
+
+ @inline def removeAndAddToPrevScope(r: NativeResource) = {
+ curScope.remove(r)
+ if (prevScope.isDefined) {
+ prevScope.get.add(r)
+ r.scope = prevScope
+ }
+ }
+
+ @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
+ if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
+ }
+
+ var retThrowable: Throwable = null
+
+ try {
+ val ret = body
+ ret match {
+ // don't de-allocate if returning any collection that contains NativeResource.
+ case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
+ case nRes: NativeResource => removeAndAddToPrevScope(nRes)
+ case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) )
+ case _ => // do nothing
+ }
+ ret
+ } catch {
+ case t: Throwable =>
+ retThrowable = t
+ null.asInstanceOf[A] // we'll throw in finally
+ } finally {
+ var toThrow: Throwable = retThrowable
+ if (retThrowable eq null) curScope.close()
+ else {
+ try {
+ curScope.close
+ } catch {
+ case closeThrowable: Throwable =>
+ if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable
+ else safeAddSuppressed(retThrowable, closeThrowable)
+ } finally {
+ throw toThrow
+ }
+ }
+ }
+ }
+
+ // thread local Scopes
+ private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
+ override def initialValue(): ArrayBuffer[ResourceScope] =
+ new ArrayBuffer[ResourceScope]()
+ }
+
+ /**
+ * Add resource to current ThreadLocal DataStructure
+ * @param r ResourceScope to add.
+ */
+ private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = {
+ threadLocalScopes.get() += r
+ }
+
+ /**
+ * Remove resource from current ThreadLocal DataStructure
+ * @param r ResourceScope to remove
+ */
+ private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
+ threadLocalScopes.get() -= r
+ }
+
+ /**
+ * Get the latest Scope in the stack
+ * @return
+ */
+ private[mxnet] def getCurrentScope(): Option[ResourceScope] = {
+ Try(Some(threadLocalScopes.get().last)).getOrElse(None)
+ }
+
+ /**
+ * Get the Last but one Scope from threadLocal Scopes.
+ * @return n-1th scope or None when not found
+ */
+ private[mxnet] def getPrevScope(): Option[ResourceScope] = {
+ val scopes = threadLocalScopes.get()
+ Try(Some(scopes(scopes.size - 2))).getOrElse(None)
+ }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index b1a3e392f41e..a009e7e343f2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -29,21 +29,15 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* WARNING: it is your responsibility to clear this object through dispose().
*
*/
-class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotDisposed {
+class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource {
private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol])
- private var disposed = false
- protected def isDisposed = disposed
- /**
- * Release the native memory.
- * The object shall never be used after it is disposed.
- */
- def dispose(): Unit = {
- if (!disposed) {
- _LIB.mxSymbolFree(handle)
- disposed = true
- }
- }
+ // unable to get the byteAllocated for Symbol
+ override val bytesAllocated: Long = 0L
+ override def nativeAddress: CPtrAddress = handle
+ override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree
+ override val ref: NativeResourceRef = super.register()
+
def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other))
def +[@specialized(Int, Float, Double) V](other: V): Symbol = {
@@ -793,7 +787,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
}
val execHandle = new ExecutorHandleRef
- val sharedHadle = if (sharedExec != null) sharedExec.handle else 0L
+ val sharedHandle = if (sharedExec != null) sharedExec.handle else 0L
checkCall(_LIB.mxExecutorBindEX(handle,
ctx.deviceTypeid,
ctx.deviceId,
@@ -806,7 +800,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
argsGradHandle,
reqsArray,
auxArgsHandle,
- sharedHadle,
+ sharedHandle,
execHandle))
val executor = new Executor(execHandle.value, this.clone())
executor.argArrays = argsNDArray
@@ -832,6 +826,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
checkCall(_LIB.mxSymbolSaveToJSON(handle, jsonStr))
jsonStr.value
}
+
}
/**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index f7f858deb82d..998017750db2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -33,7 +33,7 @@ import scala.collection.mutable.ListBuffer
private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
dataName: String = "data",
labelName: String = "label")
- extends DataIter with WarnIfNotDisposed {
+ extends DataIter with NativeResource {
private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
@@ -67,20 +67,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
}
}
+ override def nativeAddress: CPtrAddress = handle
- private var disposed = false
- protected def isDisposed = disposed
+ override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree
- /**
- * Release the native memory.
- * The object shall never be used after it is disposed.
- */
- def dispose(): Unit = {
- if (!disposed) {
- _LIB.mxDataIterFree(handle)
- disposed = true
- }
- }
+ override val ref: NativeResourceRef = super.register()
+
+ override val bytesAllocated: Long = 0L
/**
* reset the iterator
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index e20b433ed1ed..d349feac3e93 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -17,7 +17,7 @@
package org.apache.mxnet.optimizer
-import org.apache.mxnet.{Optimizer, LRScheduler, NDArray}
+import org.apache.mxnet._
import org.apache.mxnet.NDArrayConversions._
/**
@@ -92,7 +92,13 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
if (momentum == 0.0f) {
null
} else {
- NDArray.zeros(weight.shape, weight.context)
+ val s = NDArray.zeros(weight.shape, weight.context)
+ // this is created on the fly and shared between runs,
+ // we don't want it to be dispose from the scope
+ // and should be handled by the dispose
+ val scope = ResourceScope.getCurrentScope()
+ if (scope.isDefined) scope.get.remove(s)
+ s
}
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala
new file mode 100644
index 000000000000..81a9f605a887
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation}
+import org.mockito.Mockito._
+
+@TagAnnotation("resource")
+class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+ object TestRef {
+ def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ}
+ def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress]
+ = {NativeResourceRef.refMap}
+ def getCleaner: Thread = { NativeResourceRef.cleaner }
+ }
+
+ class TestRef(resource: NativeResource,
+ resourceDeAllocator: CPtrAddress => Int)
+ extends NativeResourceRef(resource, resourceDeAllocator) {
+ }
+
+ test(testName = "test native resource setup/teardown") {
+ val a = spy(NDArray.ones(Shape(2, 3)))
+ val aRef = a.ref
+ val spyRef = spy(aRef)
+
+ assert(TestRef.getRefMap.containsKey(aRef) == true)
+ a.close()
+ verify(a).dispose()
+ verify(a).nativeDeAllocator
+ // resourceDeAllocator does not get called when explicitly closing
+ verify(spyRef, times(0)).resourceDeAllocator
+
+ assert(TestRef.getRefMap.containsKey(aRef) == false)
+ assert(a.isDisposed == true, "isDisposed should be set to true after calling close")
+ }
+
+ test(testName = "test dispose") {
+ val a: NDArray = spy(NDArray.ones(Shape(3, 4)))
+ val aRef = a.ref
+ val spyRef = spy(aRef)
+ a.dispose()
+ verify(a).nativeDeAllocator
+ assert(TestRef.getRefMap.containsKey(aRef) == false)
+ assert(a.isDisposed == true, "isDisposed should be set to true after calling close")
+ }
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
new file mode 100644
index 000000000000..41dfa7d0ead2
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.apache.mxnet.ResourceScope.logger
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.mockito.Mockito._
+import scala.collection.mutable.HashMap
+
+class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+ class TestNativeResource extends NativeResource {
+ /**
+ * native Address associated with this object
+ */
+ override def nativeAddress: CPtrAddress = hashCode()
+
+ /**
+ * Function Pointer to the NativeDeAllocator of nativeAddress
+ */
+ override def nativeDeAllocator: CPtrAddress => Int = TestNativeResource.deAllocator
+
+ /** Call NativeResource.register to get the reference
+ */
+ override val ref: NativeResourceRef = super.register()
+ /**
+ * Off-Heap Bytes Allocated for this object
+ */
+ override val bytesAllocated: Long = 0
+ }
+ object TestNativeResource {
+ def deAllocator(handle: CPtrAddress): Int = 0
+ }
+
+ object TestPhantomRef {
+ def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ}
+ def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress]
+ = {NativeResourceRef.refMap}
+ def getCleaner: Thread = { NativeResourceRef.cleaner }
+
+ }
+
+ class TestPhantomRef(resource: NativeResource,
+ resourceDeAllocator: CPtrAddress => Int)
+ extends NativeResourceRef(resource, resourceDeAllocator) {
+ }
+
+ test(testName = "test NDArray Auto Release") {
+ var a: NDArray = null
+ var aRef: NativeResourceRef = null
+ var b: NDArray = null
+
+ ResourceScope.using() {
+ b = ResourceScope.using() {
+ a = NDArray.ones(Shape(3, 4))
+ aRef = a.ref
+ val x = NDArray.ones(Shape(3, 4))
+ x
+ }
+ val bRef: NativeResourceRef = b.ref
+ assert(a.isDisposed == true,
+ "objects created within scope should have isDisposed set to true")
+ assert(b.isDisposed == false,
+ "returned NativeResource should not be released")
+ assert(TestPhantomRef.getRefMap.containsKey(aRef) == false,
+ "reference of resource in Scope should be removed refMap")
+ assert(TestPhantomRef.getRefMap.containsKey(bRef) == true,
+ "reference of resource outside scope should be not removed refMap")
+ }
+ assert(b.isDisposed, "resource returned from inner scope should be released in outer scope")
+ }
+
+ test("test return object release from outer scope") {
+ var a: TestNativeResource = null
+ ResourceScope.using() {
+ a = ResourceScope.using() {
+ new TestNativeResource()
+ }
+ assert(a.isDisposed == false, "returned object should not be disposed within Using")
+ }
+ assert(a.isDisposed == true, "returned object should be disposed in the outer scope")
+ }
+
+ test(testName = "test NativeResources in returned Lists are not disposed") {
+ var ndListRet: IndexedSeq[TestNativeResource] = null
+ ResourceScope.using() {
+ ndListRet = ResourceScope.using() {
+ val ndList: IndexedSeq[TestNativeResource] =
+ IndexedSeq(new TestNativeResource(), new TestNativeResource())
+ ndList
+ }
+ ndListRet.foreach(nd => assert(nd.isDisposed == false,
+ "NativeResources within a returned collection should not be disposed"))
+ }
+ ndListRet.foreach(nd => assert(nd.isDisposed == true,
+ "NativeResources returned from inner scope should be disposed in outer scope"))
+ }
+
+ test("test native resource inside a map") {
+ var nRInKeyOfMap: HashMap[TestNativeResource, String] = null
+ var nRInValOfMap: HashMap[String, TestNativeResource] = HashMap[String, TestNativeResource]()
+
+ ResourceScope.using() {
+ nRInKeyOfMap = ResourceScope.using() {
+ val ret = HashMap[TestNativeResource, String]()
+ ret.put(new TestNativeResource, "hello")
+ ret
+ }
+ assert(!nRInKeyOfMap.isEmpty)
+
+ nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed == false,
+ "NativeResources returned in Traversable should not be disposed"))
+ }
+
+ nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed))
+
+ ResourceScope.using() {
+
+ nRInValOfMap = ResourceScope.using() {
+ val ret = HashMap[String, TestNativeResource]()
+ ret.put("world!", new TestNativeResource)
+ ret
+ }
+ assert(!nRInValOfMap.isEmpty)
+ nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed == false,
+ "NativeResources returned in Collection should not be disposed"))
+ }
+ nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed))
+ }
+
+}
diff --git a/scala-package/examples/scripts/run_train_mnist.sh b/scala-package/examples/scripts/run_train_mnist.sh
new file mode 100755
index 000000000000..ea53c1ade66f
--- /dev/null
+++ b/scala-package/examples/scripts/run_train_mnist.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd)
+echo $MXNET_ROOT
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+
+# model dir
+DATA_PATH=$2
+
+java -XX:+PrintGC -Xms256M -Xmx512M -Dmxnet.traceLeakedObjects=false -cp $CLASS_PATH \
+ org.apache.mxnetexamples.imclassification.TrainMnist \
+ --data-dir /home/ubuntu/mxnet_scala/scala-package/examples/mnist/ \
+ --num-epochs 10000000 \
+ --batch-size 1024
\ No newline at end of file