Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
updates to NativeResource/NativeResourceRef and unit tests to NativeR…
Browse files Browse the repository at this point in the history
…esource
  • Loading branch information
nswamy committed Sep 24, 2018
1 parent 0bb098d commit 3a4b2bc
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

override def nativeResource: CPtrAddress = handle
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val phantomRef: NativeResourceRef = super.register()
override val ref: NativeResourceRef = super.register()

/**
* Return a new executor with the same symbol and shared memory,
Expand Down
15 changes: 3 additions & 12 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
// use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
private val providedDataDesc: IndexedSeq[DataDesc],
private val providedLabelDesc: IndexedSeq[DataDesc]) extends
NativeResource {
private val providedLabelDesc: IndexedSeq[DataDesc]) {
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
// However, since the data and label can be accessed publicly (no getter and setter)
// the change on this will break BC
Expand All @@ -163,26 +162,17 @@ class DataBatch(val data: IndexedSeq[NDArray],
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel))
}

// overriding here so DataBatch gets added to Scope and can be disposed
override def nativeResource: CPtrAddress = 0
override def nativeDeAllocator: CPtrAddress => MXUint = doNothingDeAllocator
def doNothingDeAllocator(x: CPtrAddress): MXUint = {0}
override val phantomRef: NativeResourceRef = super.register()
override val bytesAllocated: DataIterCreator = 0

/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
*/
override def dispose(): Unit = {
def dispose(): Unit = {
if (data != null) {
data.foreach(arr => if (arr != null) arr.dispose())
}
if (label != null) {
label.foreach(arr => if (arr != null) arr.dispose())
}
super.dispose()
}

// The name and shape of data
Expand All @@ -208,6 +198,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc

def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc

}

object DataBatch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,11 +567,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArrayCollector.collect(this)
}

override def nativeResource: CPtrAddress = handle
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product

override val phantomRef: NativeResourceRef = super.register()
override val ref: NativeResourceRef = super.register()

// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,46 +24,44 @@ import java.util.concurrent._
import org.apache.mxnet.Base.checkCall
import java.util.concurrent.atomic.AtomicLong

import org.apache.mxnet.NativeResourceRef.phantomRefMap
import org.slf4j.{Logger, LoggerFactory}

/**
* NativeResource trait is used to manage MXNet Objects
* such as NDArray, Symbol, Executor, etc.,
* The MXNet Object calls {@link NativeResource.register}
* and assign the returned NativeResourceRef to {@link PhantomReference}
* The MXNet Object calls NativeResource.register
* and assign the returned NativeResourceRef to PhantomReference
* NativeResource also implements AutoCloseable so MXNetObjects
* can be used like Resources in try-with-resources paradigm
*/
// scalastyle:off finalize
private[mxnet] trait NativeResource
extends AutoCloseable with WarnIfNotDisposed {

/**
* native Address associated with this object
*/
def nativeResource: CPtrAddress
def nativeAddress: CPtrAddress

/**
* Function Pointer to the NativeDeAllocator of {@link nativeAddress}
* Function Pointer to the NativeDeAllocator of nativeAddress
*/
def nativeDeAllocator: (CPtrAddress => Int)

/** Call {@link NativeResource.register} to get {@link NativeResourceRef}
/** Call NativeResource.register to get the reference
*/
val phantomRef: NativeResourceRef
val ref: NativeResourceRef

/**
* Off-Heap Bytes Allocated for this object
*/
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long

private var scope: ResourceScope = null
private[mxnet] var scope: ResourceScope = null

@volatile var disposed = false

override def isDisposed: Boolean = disposed
override def isDisposed: Boolean = disposed || isDeAllocated

/**
* Register this object for PhantomReference tracking and within
* ResourceScope if used inside ResourceScope.
Expand All @@ -80,12 +78,9 @@ private[mxnet] trait NativeResource
NativeResourceRef.register(this, nativeDeAllocator)
}

/**
* Removes this object from PhantomRef tracking and from ResourceScope
* @param removeFromScope
*/
// Removes this object from PhantomRef tracking and from ResourceScope
private def deRegister(removeFromScope: Boolean = true): Unit = {
NativeResourceRef.deRegister(phantomRef)
NativeResourceRef.deRegister(ref)
if (scope != null && removeFromScope) scope.deRegister(this)
}

Expand All @@ -101,45 +96,81 @@ private[mxnet] trait NativeResource

def dispose(removeFromScope: Boolean): Unit = {
if (!disposed) {
print("NativeResource: Disposing NativeResource:%x\n".format(nativeResource))
checkCall(nativeDeAllocator(this.nativeResource))
print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress))
checkCall(nativeDeAllocator(this.nativeAddress))
deRegister(removeFromScope)
NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated)
disposed = true
}
}

/*
this is used by the WarnIfNotDisposed finalizer,
the object could be disposed by the GC without the need for explicit disposal
but the finalizer might not have run, then the WarnIfNotDisposed throws a warning
*/
private def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref)

}
// scalastyle:on finalize

private[mxnet] object NativeResource {
var totalBytesAllocated : AtomicLong = new AtomicLong(0)
}
// do not make resource a member, this will hold reference and GC will not clear the object.

/* Do not make resource a member of the class,
this will hold reference and GC will not clear the object.
*/
private[mxnet] class NativeResourceRef(resource: NativeResource,
val resourceDeAllocator: CPtrAddress => Int)
extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) {}
extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {}

private[mxnet] object NativeResourceRef {

private[mxnet] val referenceQueue: ReferenceQueue[NativeResource]
private[mxnet] val refQ: ReferenceQueue[NativeResource]
= new ReferenceQueue[NativeResource]

private[mxnet] val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()
private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()

private[mxnet] val cleaner = new ResourceCleanupThread()

cleaner.start()

def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)):
NativeResourceRef = {
val resourceRef = new NativeResourceRef(resource, nativeDeAllocator)
phantomRefMap.put(resourceRef, resource.nativeResource)
resourceRef
val ref = new NativeResourceRef(resource, nativeDeAllocator)
refMap.put(ref, resource.nativeAddress)
ref
}

def deRegister(resourceRef: NativeResourceRef): Unit = {
if (phantomRefMap.containsKey(resourceRef)) {
phantomRefMap.remove(resourceRef)
def deRegister(ref: NativeResourceRef): Unit = {
if (refMap.containsKey(ref)) {
refMap.remove(ref)
}
}

/**
* This method will check if the cleaner ran and deAllocated the object
* As a part of GC, when the object is unreachable GC inserts a phantomRef
* to the ReferenceQueue which the cleaner thread will deallocate, however
* the finalizer runs much later depending on the GC.
* @param resource resource to verify if it has been deAllocated
* @return true if already deAllocated
*/
def isDeAllocated(ref: NativeResourceRef): Boolean = {
!refMap.containsKey(ref)
}

def cleanup: Unit = {
print("NativeResourceRef: cleanup\n")
// remove is a blocking call
val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef]
print("NativeResourceRef: got a reference with deAlloc\n")
// phantomRef will be removed from the map when NativeResource.close is called.
val resource = refMap.get(ref)
if (resource != 0L) { // since CPtrAddress is Scala a Long, it cannot be null
print("NativeResourceRef: got a reference for resource\n")
ref.resourceDeAllocator(resource)
refMap.remove(ref)
}
}

Expand All @@ -148,24 +179,10 @@ private[mxnet] object NativeResourceRef {
setName("NativeResourceDeAllocatorThread")
setDaemon(true)

def deAllocate(): Unit = {
print("NativeResourceRef: cleanup\n")
// remove is a blocking call
val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef]
print("NativeResourceRef: got a reference with deAlloc\n")
// phantomRef will be removed from the map when NativeResource.close is called.
val resource = phantomRefMap.get(ref)
if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null
print("NativeResourceRef: got a reference for resource\n")
ref.resourceDeAllocator(resource)
phantomRefMap.remove(ref)
}
}

override def run(): Unit = {
while (true) {
try {
deAllocate()
NativeResourceRef.cleanup
}
catch {
case _: InterruptedException => Thread.currentThread().interrupt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ class ResourceScope extends AutoCloseable {

override def close(): Unit = {
resourceQ.foreach(resource => if (resource != null) {
logger.info("releasing resource:%x\n".format(resource.nativeResource))
logger.info("releasing resource:%x\n".format(resource.nativeAddress))
resource.dispose(false)
} else {logger.info("found resource which is null")}
)
ResourceScope.resourceScope.get().-=(this)
}

private[mxnet] def register(resource: NativeResource): Unit = {
logger.info("ResourceScope: Registering Resource %x".format(resource.nativeResource))
logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress))
resourceQ.+=(resource)
}

private[mxnet] def deRegister(resource: NativeResource): Unit = {
logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeResource))
logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress))
resourceQ.-=(resource)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso

// unable to get the byteAllocated for Symbol
override val bytesAllocated: Long = 0L

override def nativeResource: CPtrAddress = handle
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree
override val phantomRef: NativeResourceRef = super.register()
override val ref: NativeResourceRef = super.register()


def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
}
}

override def nativeResource: CPtrAddress = handle
override def nativeAddress: CPtrAddress = handle

override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree

override val phantomRef: NativeResourceRef = super.register()
override val ref: NativeResourceRef = super.register()

override val bytesAllocated: Long = 0L

Expand Down

This file was deleted.

Loading

0 comments on commit 3a4b2bc

Please sign in to comment.