This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

NativeResource Management in Scala #12647

Merged 21 commits, Oct 19, 2018
Changes from 18 commits
Commits
db87a2b
add Generic MXNetHandle trait and MXNetHandlePhantomRef class that wi…
nswamy Aug 27, 2018
cba8a43
use nswamy@ personal repo for mac testing
nswamy Aug 27, 2018
34106a4
Generic Handle with AutoCloseable
nswamy Aug 28, 2018
373ac78
add NativeResource and NativeResourceManager with Periodic GC calling
nswamy Aug 30, 2018
e0016d7
use NativeResource trait in NDArray, Symbol and Executor
nswamy Aug 30, 2018
5cd3cd3
add run train mnist script
nswamy Aug 30, 2018
ef4bfe8
create a Generic ResourceScope that can collect all NativeResources t…
nswamy Sep 4, 2018
c04e4f0
modify NativeResource and ResourceScope, extend NativeResource in NDA…
nswamy Sep 7, 2018
bb36934
remove GCExecutor
nswamy Sep 7, 2018
9d92dc1
deRegister PhantomReferences by when calling dispose()
nswamy Sep 7, 2018
1465717
add Finalizer(temporary) to NativeResource
nswamy Sep 7, 2018
e9b4b70
refactor NativeResource.dispose() method
nswamy Sep 7, 2018
dd294f0
update NativeResource/add Unit Test for NativeResource
nswamy Sep 21, 2018
980db5a
updates to NativeResource/NativeResourceRef and unit tests to NativeR…
nswamy Sep 24, 2018
2b0b073
remove redundant code added because of the object equality that was n…
nswamy Oct 12, 2018
18b1175
add ResourceScope
nswamy Oct 12, 2018
e2d5c99
Fix NativeResource to not remove from Scope, add Unit Tests to Resour…
nswamy Oct 14, 2018
21140cf
cleanup log/print debug statements
nswamy Oct 14, 2018
362ae18
use TreeSet inplace of ArrayBuffer to speedup removal of resources fr…
nswamy Oct 15, 2018
d78f571
fix segfault that was happening because of NDArray creation on the fl…
nswamy Oct 18, 2018
f0e873b
Add comments for dispose(param:Boolean)
nswamy Oct 18, 2018
7 changes: 7 additions & 0 deletions scala-package/core/pom.xml
@@ -123,5 +123,12 @@
<artifactId>commons-io</artifactId>
<version>2.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
Member

Is this module for the unit tests?

Member Author

Yes

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
@@ -45,7 +45,7 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
-                              private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed {
+                              private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
@@ -59,16 +59,11 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

-  private var disposed = false
-  protected def isDisposed = disposed
-
-  def dispose(): Unit = {
-    if (!disposed) {
-      outputs.foreach(_.dispose())
-      _LIB.mxExecutorFree(handle)
-      disposed = true
-    }
-  }
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
+  // cannot determine the off-heap size of this object
+  override val bytesAllocated: Long = 0
+  override val ref: NativeResourceRef = super.register()

/**
* Return a new executor with the same symbol and shared memory,
@@ -305,4 +300,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
checkCall(_LIB.mxExecutorPrint(handle, str))
str.value
}

}
@@ -562,16 +562,20 @@ object NDArray extends NDArrayBase {
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val writable: Boolean = true,
-                             addToCollector: Boolean = true) extends WarnIfNotDisposed {
+                             addToCollector: Boolean = true) extends NativeResource {
Member

Why extend this? Is it possible to use composition instead (for example, create a NativeResource object inside NDArray)?

Member Author

This is the right pattern. NDArray is a NativeResource, and NDArray gets all of the NativeResource functionality through inheritance.
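
To make the "is a" relationship discussed in this thread concrete, here is a minimal sketch of the inheritance approach. `DemoResource` and `DemoArray` are hypothetical stand-ins invented for illustration; they are not the classes in this PR and leave out registration, scopes, and byte accounting.

```scala
// Hypothetical stand-ins for illustration only; not the PR's classes.
trait DemoResource extends AutoCloseable {
  // Address of the native object owned by this wrapper.
  def nativeAddress: Long
  // Function that frees the native object and returns a status code.
  def nativeDeAllocator: Long => Int

  @volatile private var disposed = false
  def isDisposed: Boolean = disposed

  // Free the native object at most once.
  def dispose(): Unit = {
    if (!disposed) {
      nativeDeAllocator(nativeAddress)
      disposed = true
    }
  }

  override def close(): Unit = dispose()
}

// The wrapper only supplies the address and the free function;
// dispose(), close() and isDisposed come from the trait.
class DemoArray(handle: Long, free: Long => Int) extends DemoResource {
  override def nativeAddress: Long = handle
  override def nativeDeAllocator: Long => Int = free
}
```

With composition, `DemoArray` would instead hold a `DemoResource` field and forward `dispose()` and `close()` to it, which is the alternative the reviewer is asking about.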

if (addToCollector) {
NDArrayCollector.collect(this)
}

+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
+  override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product
+
+  override val ref: NativeResourceRef = super.register()

// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
-  @volatile private var disposed = false
-  def isDisposed: Boolean = disposed

def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
@@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* The NDArrays it depends on will NOT be disposed. <br />
* The object shall never be used after it is disposed.
*/
-  def dispose(): Unit = {
+  override def dispose(): Unit = {
if (!disposed) {
Contributor

Should this check isDisposed instead of disposed?

Member Author

Good point! I'll change this everywhere and make the variable private.

-      _LIB.mxNDArrayFree(handle)
+      super.dispose()
       dependencies.clear()
-      disposed = true
}
}

@@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// TODO: naive implementation
shape.hashCode + toArray.hashCode
}

}

private[mxnet] object NDArrayConversions {
@@ -0,0 +1,186 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

import org.apache.mxnet.Base.CPtrAddress
import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference}
import java.util.concurrent._

import org.apache.mxnet.Base.checkCall
import java.util.concurrent.atomic.AtomicLong


/**
* NativeResource trait is used to manage MXNet Objects
* such as NDArray, Symbol, Executor, etc.,
* The MXNet Object calls NativeResource.register
* and assign the returned NativeResourceRef to PhantomReference
* NativeResource also implements AutoCloseable so MXNetObjects
* can be used like Resources in try-with-resources paradigm
*/
private[mxnet] trait NativeResource
extends AutoCloseable with WarnIfNotDisposed {

/**
* native Address associated with this object
*/
def nativeAddress: CPtrAddress

/**
* Function Pointer to the NativeDeAllocator of nativeAddress
*/
def nativeDeAllocator: (CPtrAddress => Int)

/** Call NativeResource.register to get the reference
*/
val ref: NativeResourceRef

/**
* Off-Heap Bytes Allocated for this object
*/
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long

private[mxnet] var scope: ResourceScope = null

@volatile var disposed = false

override def isDisposed: Boolean = disposed || isDeAllocated

/**
* Register this object for PhantomReference tracking and in
* ResourceScope if used inside ResourceScope.
* @return NativeResourceRef that tracks reachability of this object
* using PhantomReference
*/
def register(): NativeResourceRef = {
scope = ResourceScope.getCurrentScope()
if (scope != null) scope.add(this)

NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
// register with PhantomRef tracking to release incase the objects go
// out of reference within scope but are held for long time
NativeResourceRef.register(this, nativeDeAllocator)
}

// Removes this object from PhantomRef tracking and from ResourceScope
private def deRegister(removeFromScope: Boolean): Unit = {
NativeResourceRef.deRegister(ref)
if (scope != null && removeFromScope) scope.remove(this)
}

// Implements [[@link AutoCloseable.close]]
override def close(): Unit = {
dispose()
}

// Implements [[@link WarnIfNotDisposed.dispose]]
def dispose(): Unit = dispose(true)

private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = {
Contributor

Please add a comment describing why we need the bool here (in what scenario would we want this to be false?).

Contributor

Nitpick: Also, given that we've got a mapping from dispose() to dispose(true) above, I don't think we need the default of true here. I'm not completely sure why the call on line 89 isn't considered ambiguous, since it could be mapped to either line 93 or line 95, but apparently the compiler resolves it.

Member Author

I am guessing the compiler runs a closure and removes the unnecessary method. Without the no-argument dispose(), it fails to recognize that the dispose method required by WarnIfNotDisposed is implemented.
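
For readers wondering about the ambiguity raised in this thread: as far as I understand the overloading-resolution rules in the Scala language specification, when several alternatives are applicable the compiler first keeps only those that do not need a default argument for the call. The sketch below, with a hypothetical `OverloadDemo` class that is not part of the PR, records which overload runs.

```scala
// Hypothetical class for illustration; it mirrors only the shape of the two
// dispose overloads, not their behavior in the PR.
class OverloadDemo {
  var log: List[String] = Nil

  // No-argument overload, analogous to the public dispose().
  def dispose(): Unit = {
    log = "dispose()" :: log
    dispose(true)
  }

  // Overload with a default argument, analogous to dispose(removeFromScope).
  def dispose(removeFromScope: Boolean = true): Unit = {
    log = s"dispose($removeFromScope)" :: log
  }
}
```

Calling `new OverloadDemo().dispose()` should run the no-argument overload first, which then calls the Boolean overload explicitly, so the default value is never used for that call.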

if (!disposed) {
Contributor

Should we be checking isDisposed here?

checkCall(nativeDeAllocator(this.nativeAddress))
deRegister(removeFromScope)
NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated)
disposed = true
}
}

/*
this is used by the WarnIfNotDisposed finalizer,
the object could be disposed by the GC without the need for explicit disposal
but the finalizer might not have run, then the WarnIfNotDisposed throws a warning
*/
private[mxnet] def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref)

}

private[mxnet] object NativeResource {
var totalBytesAllocated : AtomicLong = new AtomicLong(0)
}

// Do not make [[NativeResource.resource]] a member of the class,
// this will hold reference and GC will not clear the object.
private[mxnet] class NativeResourceRef(resource: NativeResource,
val resourceDeAllocator: CPtrAddress => Int)
extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {}

private[mxnet] object NativeResourceRef {

private[mxnet] val refQ: ReferenceQueue[NativeResource]
= new ReferenceQueue[NativeResource]

private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()
Member

Do we really need this Map? Can we store resource.nativeAddress (not the whole NativeResource.resource) as a member of NativeResourceRef?
My major concern is that when two thread-safe containers cooperate with each other, a small oversight can result in a concurrency disaster. isDeAllocated is not really necessary. If you're worried about WarnIfNotDisposed, can we simply remove it, since everything is now guaranteed to be disposed eventually? (We invented it because we had removed dispose() from finalize at that time.)

Member Author

This map is used to keep a strong reference to each NativeResourceRef; without it, the GC would simply reclaim the NativeResourceRef itself.


private[mxnet] val cleaner = new ResourceCleanupThread()

cleaner.start()

def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)):
NativeResourceRef = {
val ref = new NativeResourceRef(resource, nativeDeAllocator)
refMap.put(ref, resource.nativeAddress)
ref
}

def deRegister(ref: NativeResourceRef): Unit = {
if (refMap.containsKey(ref)) {
Contributor

Nitpick: ConcurrentHashMap does a safe remove. We don't technically have to check to see whether or not the key is there before removing. If we never expect this to happen we could gain a tiny performance increase by not having the contains check.
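
A small sketch of the point in this nitpick, using a hypothetical `RemoveDemo` object with `String` keys in place of `NativeResourceRef`: `ConcurrentHashMap.remove` returns the previous value, or `null` when the key was absent, so the result alone tells the caller whether anything was removed.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical object for illustration; String keys stand in for NativeResourceRef.
object RemoveDemo {
  val refMap = new ConcurrentHashMap[String, java.lang.Long]()

  // remove returns the previous value, or null when the key was absent,
  // so no containsKey check is needed beforehand.
  def deRegister(key: String): Boolean = refMap.remove(key) != null
}
```

Dropping the separate `containsKey` call also avoids a check-then-act pair of operations on a shared map.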

refMap.remove(ref)
}
}

/**
* This method will check if the cleaner ran and deAllocated the object
* As a part of GC, when the object is unreachable GC inserts a phantomRef
* to the ReferenceQueue which the cleaner thread will deallocate, however
* the finalizer runs much later depending on the GC.
* @param resource resource to verify if it has been deAllocated
* @return true if already deAllocated
*/
def isDeAllocated(ref: NativeResourceRef): Boolean = {
!refMap.containsKey(ref)
}

def cleanup: Unit = {
// remove is a blocking call
val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef]
Member

Why is it a NativeResourceRef when refQ is a ReferenceQueue[NativeResource], which should return a NativeResource? Did I miss something?

Member Author (nswamy, Oct 19, 2018)

The way PhantomReference works is that it is inserted into the queue when the object (the NativeResource) is no longer reachable. After that point you cannot access the NativeResource using ref.get(), hence the NativeResourceRef (with the native address and deallocator function) is held in a HashMap.
See this article.

  • from the article
    since Phantom don’t have a link to the actual object, a typical pattern is to derive your own Reference type from Phantom and adding some info useful for the final free, for example filename.
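
The pattern described in this reply can be sketched roughly as follows. All names here (`PhantomDemo`, `TrackedRef`, `track`, `pollFreed`) are hypothetical and chosen for illustration; the sketch stores the address on the reference subclass and uses a concurrent set to keep the reference objects strongly reachable, which is one possible arrangement and not necessarily the one used in this PR.

```scala
import java.lang.ref.{PhantomReference, ReferenceQueue}
import java.util.concurrent.ConcurrentHashMap

// Hypothetical sketch; names are illustrative, not from the PR.
object PhantomDemo {
  val queue = new ReferenceQueue[AnyRef]()

  // The subclass carries what the final free needs (here an address),
  // because get() on a PhantomReference always returns null.
  // The referent is passed to the superclass only and is not stored as a field.
  class TrackedRef(referent: AnyRef, val address: Long)
    extends PhantomReference[AnyRef](referent, queue)

  // Strong references to the TrackedRef objects keep them alive until cleanup.
  val live = ConcurrentHashMap.newKeySet[TrackedRef]()

  def track(referent: AnyRef, address: Long): TrackedRef = {
    val ref = new TrackedRef(referent, address)
    live.add(ref)
    ref
  }

  // Wait up to timeoutMs (must be positive to avoid blocking forever) for one
  // enqueued reference and return the address it carried.
  def pollFreed(timeoutMs: Long): Option[Long] = {
    Option(queue.remove(timeoutMs)).map { r =>
      // The queue hands back the Reference object, hence the cast.
      val tracked = r.asInstanceOf[TrackedRef]
      live.remove(tracked)
      tracked.address
    }
  }
}
```

Once the referent becomes unreachable and the garbage collector notices, `pollFreed` can return the stored address, but when that happens is up to the GC, so the sketch makes no timing promises.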

// phantomRef will be removed from the map when NativeResource.close is called.
val resource = refMap.get(ref)
Contributor

If ref is not in the map, then resource will be null. This passes the next if, and ref.resourceDeAllocator will throw a NullPointerException. This shouldn't ever happen, but we should check for it.

Member Author

There is a check; see the next line:

    // phantomRef will be removed from the map when NativeResource.close is called.
    val resource = refMap.get(ref)
    if (resource != 0L)  { // since CPtrAddress is Scala a Long, it cannot be null

Contributor

Ahh, I misunderstood your comment. You're saying that since refMap.get returns a Long, resource is a Long. In the event of a null return, resource will be 0L because Longs cannot be null.

Disregard then.
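
The behavior this thread settles on can be checked with a small sketch. `UnboxDemo` is a hypothetical object, and the `CPtrAddress` alias is an assumption made to mirror the thread's statement that it is a Scala `Long`.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical object for illustration; String keys stand in for NativeResourceRef.
object UnboxDemo {
  // Assumed alias, mirroring the statement that CPtrAddress is a Scala Long.
  type CPtrAddress = Long

  val refMap = new ConcurrentHashMap[String, CPtrAddress]()

  // For a missing key the Java map returns null; because the declared value
  // type is a Scala Long, the result is unboxed and a null becomes 0L.
  def lookup(key: String): CPtrAddress = refMap.get(key)
}
```

One consequence of relying on this is that a stored address of 0 cannot be told apart from a missing key.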

if (resource != 0L) { // since CPtrAddress is Scala a Long, it cannot be null
ref.resourceDeAllocator(resource)
refMap.remove(ref)
}
}

protected class ResourceCleanupThread extends Thread {
Member

Can we set the frequency of this clean up process?

Member Author

Read the comments: remove is a blocking call.

setPriority(Thread.MAX_PRIORITY)
setName("NativeResourceDeAllocatorThread")
setDaemon(true)

override def run(): Unit = {
while (true) {
Member

Can we do this with a sleep time?

Contributor

I don't like the fact that this is constantly running in a while(true) loop but I don't have a better solution atm.

Member Author

ReferenceQueue's remove is a blocking call. It is not spinning; the thread is blocked.
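
One way to observe the blocking behavior described here is to start a thread that loops on `ReferenceQueue.remove()` over an empty queue and inspect its state. `BlockingDemo` and `stateAfterStart` are hypothetical names for this sketch, and the result depends on the worker having had enough time to reach the blocking call.

```scala
import java.lang.ref.ReferenceQueue

// Hypothetical sketch for illustration; not the PR's cleanup thread.
object BlockingDemo {
  val queue = new ReferenceQueue[AnyRef]()

  // Daemon thread that blocks inside queue.remove() until a reference arrives.
  val worker = new Thread(new Runnable {
    override def run(): Unit = {
      try {
        while (true) queue.remove()
      } catch {
        case _: InterruptedException => () // leave the loop when interrupted
      }
    }
  })
  worker.setDaemon(true)

  // Start the worker (call only once), give it time to block,
  // and report its thread state.
  def stateAfterStart(waitMs: Long): Thread.State = {
    worker.start()
    Thread.sleep(waitMs)
    worker.getState
  }
}
```

With an empty queue the worker is expected to sit in a waiting state and use no CPU until something is enqueued. Leaving the loop on interrupt is a choice made for this sketch only.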

try {
NativeResourceRef.cleanup
}
catch {
case _: InterruptedException => Thread.currentThread().interrupt()
}
}
}
}
}
Contributor

Add a newline here, or scalastyle complains.