-
Notifications
You must be signed in to change notification settings - Fork 6.8k
NativeResource Management in Scala #12647
Changes from 18 commits
db87a2b
cba8a43
34106a4
373ac78
e0016d7
5cd3cd3
ef4bfe8
c04e4f0
bb36934
9d92dc1
1465717
e9b4b70
dd294f0
980db5a
2b0b073
18b1175
e2d5c99
21140cf
362ae18
d78f571
f0e873b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -562,16 +562,20 @@ object NDArray extends NDArrayBase { | |
*/ | ||
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, | ||
val writable: Boolean = true, | ||
addToCollector: Boolean = true) extends WarnIfNotDisposed { | ||
addToCollector: Boolean = true) extends NativeResource { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why extend this? Is it possible to use composition instead (for example, by creating a NativeResource object inside NDArray)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the right pattern. NDArray |
||
if (addToCollector) { | ||
NDArrayCollector.collect(this) | ||
} | ||
|
||
override def nativeAddress: CPtrAddress = handle | ||
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree | ||
override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product | ||
|
||
override val ref: NativeResourceRef = super.register() | ||
|
||
// record arrays who construct this array instance | ||
// we use weak reference to prevent gc blocking | ||
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] | ||
@volatile private var disposed = false | ||
def isDisposed: Boolean = disposed | ||
|
||
def serialize(): Array[Byte] = { | ||
val buf = ArrayBuffer.empty[Byte] | ||
|
@@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, | |
* The NDArrays it depends on will NOT be disposed. <br /> | ||
* The object shall never be used after it is disposed. | ||
*/ | ||
def dispose(): Unit = { | ||
override def dispose(): Unit = { | ||
if (!disposed) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this be checking isDisposed instead of disposed? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point! I'll change it everywhere and make this variable private. |
||
_LIB.mxNDArrayFree(handle) | ||
super.dispose() | ||
dependencies.clear() | ||
disposed = true | ||
} | ||
} | ||
|
||
|
@@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, | |
// TODO: naive implementation | ||
shape.hashCode + toArray.hashCode | ||
} | ||
|
||
} | ||
|
||
private[mxnet] object NDArrayConversions { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mxnet | ||
|
||
import org.apache.mxnet.Base.CPtrAddress | ||
import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference} | ||
import java.util.concurrent._ | ||
|
||
import org.apache.mxnet.Base.checkCall | ||
import java.util.concurrent.atomic.AtomicLong | ||
|
||
|
||
/**
  * Base trait for MXNet objects that own native memory,
  * such as NDArray, Symbol, Executor, etc.
  *
  * An implementing class supplies [[nativeAddress]], [[nativeDeAllocator]] and
  * [[bytesAllocated]], and initialises [[ref]] with the value returned by
  * [[register]] so the object is tracked through a PhantomReference.
  *
  * The trait also implements AutoCloseable, so implementations can be used
  * in the try-with-resources paradigm.
  */
private[mxnet] trait NativeResource
  extends AutoCloseable with WarnIfNotDisposed {

  /**
    * Native address associated with this object.
    */
  def nativeAddress: CPtrAddress

  /**
    * Function that releases the native memory found at [[nativeAddress]].
    */
  def nativeDeAllocator: (CPtrAddress => Int)

  /**
    * Tracking reference for this object; obtain it by calling [[register]].
    */
  val ref: NativeResourceRef

  /**
    * Off-heap bytes allocated for this object.
    */
  // deliberately a val, so that it is evaluated at the point where it is defined
  val bytesAllocated: Long

  // Scope captured by register(); stays null when no scope was current at that time.
  private[mxnet] var scope: ResourceScope = null

  // Flipped to true at the end of the first successful dispose.
  @volatile var disposed: Boolean = false

  /**
    * @return true when this object was disposed explicitly, or when its
    *         tracking reference is no longer held by NativeResourceRef
    */
  override def isDisposed: Boolean = disposed || isDeAllocated()

  /**
    * Registers this object for PhantomReference tracking and, when a
    * ResourceScope is current, adds the object to that scope.
    * Also adds [[bytesAllocated]] to the global allocation counter.
    *
    * @return NativeResourceRef that tracks reachability of this object
    *         using PhantomReference
    */
  def register(): NativeResourceRef = {
    val currentScope = ResourceScope.getCurrentScope()
    scope = currentScope
    if (currentScope != null) {
      currentScope.add(this)
    }

    NativeResource.totalBytesAllocated.addAndGet(bytesAllocated)
    // Track through a PhantomReference as well, so that objects which become
    // unreachable inside a long-lived scope can still be released.
    NativeResourceRef.register(this, nativeDeAllocator)
  }

  // Stops PhantomReference tracking and, on request, detaches this object from its scope.
  private def deRegister(removeFromScope: Boolean): Unit = {
    NativeResourceRef.deRegister(ref)
    if (removeFromScope && scope != null) {
      scope.remove(this)
    }
  }

  /** Implements AutoCloseable.close by delegating to [[dispose()]]. */
  override def close(): Unit = dispose()

  /** Implements WarnIfNotDisposed.dispose; also removes the object from its scope. */
  def dispose(): Unit = dispose(removeFromScope = true)

  /**
    * Releases the native memory once; later calls are no-ops.
    *
    * @param removeFromScope whether the object should also be removed from the
    *                        ResourceScope captured by [[register]].
    *                        NOTE(review): nothing visible here passes false;
    *                        presumably a ResourceScope does so while it is
    *                        disposing its own members - confirm against
    *                        ResourceScope.
    */
  private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = {
    if (disposed) {
      // already released; nothing left to do
    } else {
      val release = nativeDeAllocator
      val status = release(nativeAddress)
      checkCall(status)
      deRegister(removeFromScope)
      NativeResource.totalBytesAllocated.addAndGet(-bytesAllocated)
      disposed = true
    }
  }

  /*
    Consulted through isDisposed by the WarnIfNotDisposed finalizer:
    the native memory may already have been released by the cleanup thread
    without an explicit dispose, in which case no warning is warranted.
   */
  private[mxnet] def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref)

}
|
||
/**
  * Companion holding the running total of off-heap bytes accounted for by
  * registered NativeResource instances.
  */
private[mxnet] object NativeResource {
  // NOTE(review): this could be a `val`, since only the AtomicLong's value is
  // updated by the code visible here; left as `var` to keep the interface as is.
  var totalBytesAllocated: AtomicLong = new AtomicLong(0L)
}
|
||
/**
  * PhantomReference to a NativeResource, enqueued on NativeResourceRef.refQ.
  *
  * `resource` must remain a plain constructor parameter: turning it into a
  * member (val/var) would keep a strong reference to the resource, and the GC
  * would then never clear the object.
  *
  * @param resource            the object whose reachability is tracked
  * @param resourceDeAllocator function that frees the native memory at a given address
  */
private[mxnet] class NativeResourceRef(
    resource: NativeResource,
    val resourceDeAllocator: CPtrAddress => Int)
  extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ)
|
||
/**
  * Registry and background cleaner for [[NativeResourceRef]] instances.
  *
  * Every registered reference is stored in [[refMap]] together with the native
  * address it guards. A daemon thread blocks on [[refQ]] and frees the native
  * memory of references that show up there and are still present in the map.
  */
private[mxnet] object NativeResourceRef {

  // Queue that every NativeResourceRef is constructed against.
  private[mxnet] val refQ: ReferenceQueue[NativeResource] =
    new ReferenceQueue[NativeResource]

  // Holds the NativeResourceRef itself (so it stays reachable) and the native
  // address to free. The address has to be stored here because the
  // NativeResource can no longer be obtained from a PhantomReference once
  // that reference has been enqueued.
  private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()

  private[mxnet] val cleaner = new ResourceCleanupThread()

  cleaner.start()

  /**
    * Creates a phantom reference for `resource` and starts tracking it.
    *
    * @param resource          object to track
    * @param nativeDeAllocator function used to free the resource's native address
    * @return the newly created, tracked reference
    */
  def register(
      resource: NativeResource,
      nativeDeAllocator: (CPtrAddress => Int)): NativeResourceRef = {
    val ref = new NativeResourceRef(resource, nativeDeAllocator)
    refMap.put(ref, resource.nativeAddress)
    ref
  }

  /**
    * Stops tracking `ref`. Does nothing when `ref` is not tracked.
    *
    * @param ref reference to remove from [[refMap]]
    */
  def deRegister(ref: NativeResourceRef): Unit = {
    // ConcurrentHashMap.remove is a safe no-op for absent keys, so a
    // preceding containsKey check is unnecessary (and was not atomic).
    refMap.remove(ref)
  }

  /**
    * Tells whether `ref` is no longer tracked, i.e. it was either
    * de-registered explicitly or already handled by [[cleanup]].
    *
    * @param ref reference to look up
    * @return true if `ref` is not present in [[refMap]]
    */
  def isDeAllocated(ref: NativeResourceRef): Boolean = {
    !refMap.containsKey(ref)
  }

  /**
    * Waits for the next reference to arrive on [[refQ]] and frees its native
    * memory if the reference is still tracked.
    *
    * This blocks until a reference is available; it does not poll.
    */
  def cleanup: Unit = {
    // remove() is a blocking call
    refQ.remove() match {
      case ref: NativeResourceRef =>
        // Fetch and un-track in a single atomic step so the address can be
        // claimed (and therefore freed) at most once. A reference that was
        // de-registered earlier is simply absent from the map.
        val address = refMap.remove(ref)
        // CPtrAddress is a Scala Long and cannot be null; an absent key
        // surfaces as 0L here.
        if (address != 0L) {
          ref.resourceDeAllocator(address)
        }
      case _ =>
        // only NativeResourceRef instances are constructed against refQ here
    }
  }

  /**
    * Daemon thread that repeatedly runs [[cleanup]] until it is interrupted.
    */
  protected class ResourceCleanupThread extends Thread {
    setPriority(Thread.MAX_PRIORITY)
    setName("NativeResourceDeAllocatorThread")
    setDaemon(true)

    override def run(): Unit = {
      // Leave the loop once interrupted. Looping unconditionally after
      // restoring the interrupt flag would make every following blocking
      // remove() fail immediately, turning this thread into a busy spin.
      while (!Thread.currentThread().isInterrupted) {
        try {
          NativeResourceRef.cleanup
        } catch {
          case _: InterruptedException => Thread.currentThread().interrupt()
        }
      }
    }
  }
}
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. newline here or scala style complains |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this module for the Unittest?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes