Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update NativeResource/add Unit Test for NativeResource
Browse files Browse the repository at this point in the history
  • Loading branch information
nswamy committed Sep 24, 2018
1 parent aea6c81 commit 0bb098d
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 80 deletions.
7 changes: 7 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,12 @@
<artifactId>commons-io</artifactId>
<version>2.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

override def nativeAddress: CPtrAddress = handle
override def nativeResource: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override def bytesAllocated: Long = 0
override val bytesAllocated: Long = 0
override val phantomRef: NativeResourceRef = super.register()

/**
Expand Down
15 changes: 12 additions & 3 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ class DataBatch(val data: IndexedSeq[NDArray],
// use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
private val providedDataDesc: IndexedSeq[DataDesc],
private val providedLabelDesc: IndexedSeq[DataDesc]) {
private val providedLabelDesc: IndexedSeq[DataDesc]) extends
NativeResource {
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
// However, since the data and label can be accessed publicly (no getter and setter)
// the change on this will break BC
Expand All @@ -162,17 +163,26 @@ class DataBatch(val data: IndexedSeq[NDArray],
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel))
}

// overriding here so DataBatch gets added to Scope and can be disposed
override def nativeResource: CPtrAddress = 0
override def nativeDeAllocator: CPtrAddress => MXUint = doNothingDeAllocator
def doNothingDeAllocator(x: CPtrAddress): MXUint = {0}
override val phantomRef: NativeResourceRef = super.register()
override val bytesAllocated: DataIterCreator = 0

/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
override def dispose(): Unit = {
if (data != null) {
data.foreach(arr => if (arr != null) arr.dispose())
}
if (label != null) {
label.foreach(arr => if (arr != null) arr.dispose())
}
super.dispose()
}

// The name and shape of data
Expand All @@ -198,7 +208,6 @@ class DataBatch(val data: IndexedSeq[NDArray],
def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc

def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc

}

object DataBatch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,9 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArrayCollector.collect(this)
}

override def nativeAddress: CPtrAddress = handle
override def nativeResource: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
override def bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product
override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product

override val phantomRef: NativeResourceRef = super.register()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference}
import java.util.concurrent._

import org.apache.mxnet.Base.checkCall
import java.lang.{AutoCloseable, ThreadLocal}
import java.util.concurrent.atomic.AtomicLong

import org.apache.mxnet.NativeResourceRef.phantomRefMap
import org.slf4j.{Logger, LoggerFactory}

/**
* NativeResource trait is used to manage MXNet Objects
* such as NDArray, Symbol, Executor, etc.,
* The MXNet Object calls {@link NativeResource.register}
* and assign the returned NativeResourceRef to {@link phantomRef}
* and assign the returned NativeResourceRef to {@link PhantomReference}
* NativeResource also implements AutoCloseable so MXNetObjects
* can be used like Resources in try-with-resources paradigm
*/
Expand All @@ -40,13 +41,11 @@ private[mxnet] trait NativeResource

/**
* native Address associated with this object
* @return
*/
def nativeAddress: CPtrAddress
def nativeResource: CPtrAddress

/**
* Function Pointer to the NativeDeAllocator of {@link nativeAddress}
* @return
*/
def nativeDeAllocator: (CPtrAddress => Int)

Expand All @@ -56,9 +55,9 @@ private[mxnet] trait NativeResource

/**
* Off-Heap Bytes Allocated for this object
* @return
*/
def bytesAllocated: Long
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long

private var scope: ResourceScope = null

Expand All @@ -72,11 +71,10 @@ private[mxnet] trait NativeResource
* using PhantomReference
*/
def register(): NativeResourceRef = {

scope = ResourceScope.getScope()
if (scope != null) {
scope.register(this)
}
if (scope != null) scope.register(this)

NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
// register with PhantomRef tracking to release incase the objects go
// out of reference within scope but are held for long time
NativeResourceRef.register(this, nativeDeAllocator)
Expand All @@ -91,93 +89,83 @@ private[mxnet] trait NativeResource
if (scope != null && removeFromScope) scope.deRegister(this)
}

/**
* Implements {@link AutoCloseable.close}
*/
// Implements {@link AutoCloseable.close}
override def close(): Unit = {
dispose()
deRegister(true)
}

/**
* Implements {@link WarnIfNotDisposed.dispose}
*/
// Implements {@link WarnIfNotDisposed.dispose}
def dispose(): Unit = {
dispose(true)
}

def dispose(removeFromScope: Boolean): Unit = {
if (!disposed) {
print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress))
checkCall(nativeDeAllocator(this.nativeAddress))
print("NativeResource: Disposing NativeResource:%x\n".format(nativeResource))
checkCall(nativeDeAllocator(this.nativeResource))
deRegister(removeFromScope)
NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated)
disposed = true
}
}

override protected def finalize(): Unit = {
if (!isDisposed) {
print("LEAK: %x\n".format(this.nativeAddress))
super.finalize()
}
}
}
// scalastyle:on finalize

// do not make nativeRes a member, this will hold reference and GC will not clear the object.
private[mxnet] object NativeResource {
var totalBytesAllocated : AtomicLong = new AtomicLong(0)
}
// do not make resource a member, this will hold reference and GC will not clear the object.
private[mxnet] class NativeResourceRef(resource: NativeResource,
val resourceDeAllocator: CPtrAddress => Int)
extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) {
}
extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) {}

private[mxnet] object NativeResourceRef {

private val referenceQueue: ReferenceQueue[NativeResource] = new ReferenceQueue[NativeResource]
private[mxnet] val referenceQueue: ReferenceQueue[NativeResource]
= new ReferenceQueue[NativeResource]

private val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()
private[mxnet] val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()

private val cleanupThread = new ResourceCleanupThread()
private[mxnet] val cleaner = new ResourceCleanupThread()

cleanupThread.start()
cleaner.start()

def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)):
NativeResourceRef = {
val resourceRef = new NativeResourceRef(resource, nativeDeAllocator)
phantomRefMap.put(resourceRef, resource.nativeAddress)
phantomRefMap.put(resourceRef, resource.nativeResource)
resourceRef
}

def deRegister(resourceRef: NativeResourceRef): Unit = {
val nativeDeAllocator = phantomRefMap.get(resourceRef)
if (nativeDeAllocator != 0L) { // since CPtrAddress is Scala Long, it cannot be null
if (phantomRefMap.containsKey(resourceRef)) {
phantomRefMap.remove(resourceRef)
}
}

def cleanup(): Unit = {
print("NativeResourceRef: cleanup\n")
// remove is a blocking call
val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef]
print("NativeResourceRef: got a reference with deAlloc\n")
// phantomRef will be removed from the map when NativeResource.close is called.
val resource = phantomRefMap.get(ref)

if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null
print("NativeResourceRef: got a reference for resource\n")
ref.resourceDeAllocator(resource)
phantomRefMap.remove(ref)
}
}

private class ResourceCleanupThread extends Thread {
protected class ResourceCleanupThread extends Thread {
setPriority(Thread.MAX_PRIORITY)
setName("NativeResourceDeAllocatorThread")
setDaemon(true)

def deAllocate(): Unit = {
print("NativeResourceRef: cleanup\n")
// remove is a blocking call
val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef]
print("NativeResourceRef: got a reference with deAlloc\n")
// phantomRef will be removed from the map when NativeResource.close is called.
val resource = phantomRefMap.get(ref)
if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null
print("NativeResourceRef: got a reference for resource\n")
ref.resourceDeAllocator(resource)
phantomRefMap.remove(ref)
}
}

override def run(): Unit = {
while (true) {
try {
cleanup()
deAllocate()
}
catch {
case _: InterruptedException => Thread.currentThread().interrupt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,20 @@ class ResourceScope extends AutoCloseable {

override def close(): Unit = {
resourceQ.foreach(resource => if (resource != null) {
logger.info("releasing resource:%x\n".format(resource.nativeAddress))
logger.info("releasing resource:%x\n".format(resource.nativeResource))
resource.dispose(false)
} else {logger.info("found resource which is null")}
)
ResourceScope.resourceScope.get().-=(this)
}

private[mxnet] def register(resource: NativeResource): Unit = {
logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress))
logger.info("ResourceScope: Registering Resource %x".format(resource.nativeResource))
resourceQ.+=(resource)
}

// TODO(@nswamy): this is linear in time, find better data structure
private[mxnet] def deRegister(resource: NativeResource): Unit = {
logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress))
logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeResource))
resourceQ.-=(resource)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol])

// unable to get the byteAllocated for Symbol
override def bytesAllocated: Long = 0L
override val bytesAllocated: Long = 0L

override def nativeAddress: CPtrAddress = handle
override def nativeResource: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree
override val phantomRef: NativeResourceRef = super.register()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import scala.collection.mutable.ListBuffer
private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
dataName: String = "data",
labelName: String = "label")
extends DataIter with WarnIfNotDisposed {
extends DataIter with NativeResource {

private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

Expand Down Expand Up @@ -67,20 +67,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
}
}

override def nativeResource: CPtrAddress = handle

private var disposed = false
protected def isDisposed = disposed
override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree

/**
* Release the native memory.
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
if (!disposed) {
_LIB.mxDataIterFree(handle)
disposed = true
}
}
override val phantomRef: NativeResourceRef = super.register()

override val bytesAllocated: Long = 0L

/**
* reset the iterator
Expand Down
Loading

0 comments on commit 0bb098d

Please sign in to comment.