Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-600][Scala] NDArray auto-collector (#11751)
Browse files Browse the repository at this point in the history
* [Scala] NDArrayCollector for automatically disposing NDArrays

* modify doc for NDArrayCollector

* modify the function doc of NDArrayCollector.withScope

* remove trivial changes

* put dispose in finally

* fix jni NDArray signature

* modify doc and private var

* dispose res when test finishes

* add comments, change variable name
  • Loading branch information
yzhliu authored Jul 19, 2018
1 parent 4b8ab63 commit 1031fe1
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
ndHandles.toArray.map(new NDArray(_))
ndHandles.toArray.map(new NDArray(_, addToCollector = false))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Monitor(
override def invoke(name: String, arr: NDArrayHandle): Unit = {
// wrapper for executor callback
if (activated) {
val array = new NDArray(arr, writable = false)
val array = new NDArray(arr, writable = false, addToCollector = false)
val elem = (step, name, statFunc(array))
queue += elem
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,16 @@ object NDArray extends NDArrayBase {
* </b>
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val writable: Boolean = true) extends WarnIfNotDisposed {
val writable: Boolean = true,
addToCollector: Boolean = true) extends WarnIfNotDisposed {
if (addToCollector) {
NDArrayCollector.collect(this)
}

// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
private var disposed = false
@volatile private var disposed = false
def isDisposed: Boolean = disposed

def serialize(): Array[Byte] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

import org.apache.mxnet.Base.CPtrAddress
import org.slf4j.LoggerFactory

import scala.annotation.varargs
import scala.collection.mutable

/**
* A collector to store NDArrays.
* It provides a scope, NDArrays allocated in the scope can either <br />
* - be disposed automatically when the code block finishes, or <br />
* - simply be collected for future usage.
* <br />
* If the return type of scope is <em>NDArray</em> or <em>NDArrayFuncReturn</em>,
* the collector is smart enough NOT to collect or dispose the returned NDArray. <br />
* However in other cases, it is users' responsibility NOT to leak allocated NDArrays outside,
* (e.g., store to a global variable and use later, pass to another thread, etc.) <br />
* Usage Example:
* <pre>
* val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
* val res = NDArrayCollector.auto().withScope {
* (NDArray.relu(a) + a).toArray
* }
* </pre>
* In the case above, the intermediate NDArrays
* (created by <em>NDArray.relu</em> and <em>+</em>) will be disposed automatically. <br />
* User can also decide to dispose the collected NDArrays later: <br />
* <pre>
* val collector = NDArrayCollector.manual()
* val res = collector.withScope {
* (NDArray.relu(a) + a).toArray
* }
* collector.foreach(_.dispose())
* </pre>
* For Java users: <br />
* <pre>
* NDArray a = NDArray.array(new float[]{-1f, 0f, 1f, 2f, 3f, 4f},
* Shape.create(2, 3), Context.cpu(0));
* float[] sliced = NDArrayCollector.auto().withScope(
* new scala.runtime.AbstractFunction0<float[]>() {
* @Override
* public float[] apply() {
* a.slice(0, 1).toArray();
* }
* });
* </pre>
*/
object NDArrayCollector {
private val logger = LoggerFactory.getLogger(classOf[NDArrayCollector])

private val currCollector = new ThreadLocal[NDArrayCollector] {
override def initialValue = new NDArrayCollector(false, false)
}

/**
* Create a collector which will dispose the collected NDArrays automatically.
* @return an auto-disposable collector.
*/
def auto(): NDArrayCollector = new NDArrayCollector(true)

/**
* Create a collector allows users to later dispose the collected NDArray manually.
* @return a manually-disposable collector.
*/
def manual(): NDArrayCollector = new NDArrayCollector(false)

/**
* Collect the NDArrays into the collector of the current thread.
* @param ndArray NDArrays need to be collected.
*/
@varargs def collect(ndArray: NDArray*): Unit = {
currCollector.get().add(ndArray: _*)
}
}

class NDArrayCollector private(private val autoDispose: Boolean = true,
private val doCollect: Boolean = true) {
// native ptr (handle) of the NDArray -> NDArray
// in some rare situation, multiple NDArrays have same native ptr,
// the Map here is to prevent from disposing more than once.
private val arrays = mutable.HashMap.empty[CPtrAddress, NDArray]

private def add(nd: NDArray*): Unit = {
if (doCollect) nd.foreach(arr => arrays.put(arr.handle, arr))
}

/**
* Clear the collector.
*/
def clear(): Unit = {
arrays.clear()
}

/**
* Iterate over the collected NDArrays and apply the user-defined function to each NDArray.
* @param f the function that is applied for its side-effect to every NDArray.
* The result of function <em>f</em> is discarded.
*/
def foreach(f: NDArray => Unit): Unit = {
arrays.values.foreach(f(_))
}

/**
* @return how many unique NDArrays are collected.
*/
def size: Int = arrays.size

/**
* Create a code scope, NDArrays allocated within this scope will be collected.
* The collected NDArrays will be either <br />
* - disposed automatically when the code block finishes (when using <em>auto</em>) or <br />
* - stored for later access (when using <em>manual</em>) <br />
* If the return type of scope is <em>NDArray</em> or <em>NDArrayFuncReturn</em>,
* it is smart enough NOT to collect or dispose the returned NDArray. <br />
* However in other cases, it is users' responsibility NOT to leak allocated NDArrays outside.
* @param codeBlock code block to be executed within the scope.
* @tparam T return type of the function <em>codeBlock</em>.
* @return The result of function <em>codeBlock</em>.
*/
def withScope[T](codeBlock: => T): T = {
val old = NDArrayCollector.currCollector.get()
NDArrayCollector.currCollector.set(this)
try {
val ret = codeBlock
ret match {
case ndRet: NDArray =>
arrays.remove(ndRet.handle)
case ndarrays: NDArrayFuncReturn =>
ndarrays.arr.foreach(nd => arrays.remove(nd.handle))
case _ => // do nothing
}
ret
} finally {
if (autoDispose) {
foreach(_.dispose())
clear()
}
NDArrayCollector.currCollector.set(old)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ abstract class CustomOp {
val tensors = (0 until 5).toArray.map( x => ArrayBuffer[NDArray]() )
for (i <- 0 until numNdarray) {
if (tags(i) == 1 || tags(i) == 4) {
tensors(tags(i)) += new NDArray(ndarraies(i), writable = true)
tensors(tags(i)) += new NDArray(ndarraies(i), writable = true, addToCollector = false)
} else {
tensors(tags(i)) += new NDArray(ndarraies(i), writable = false)
tensors(tags(i)) += new NDArray(ndarraies(i), writable = false, addToCollector = false)
}
}
val reqEnum = Array("null", "write", "inplace", "add")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,15 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
*/
private def _padData(ndArray: NDArray): NDArray = {
val padNum = cursor + dataBatchSize - numData
val newArray = NDArray.zeros(ndArray.slice(0, dataBatchSize).shape)
val batch = ndArray.slice(cursor, numData)
val padding = ndArray.slice(0, padNum)
newArray.slice(0, dataBatchSize - padNum).set(batch).dispose()
newArray.slice(dataBatchSize - padNum, dataBatchSize).set(padding).dispose()
batch.dispose()
padding.dispose()
newArray
val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
val newArray = NDArray.zeros(shape)
NDArrayCollector.auto().withScope {
val batch = ndArray.slice(cursor, numData)
val padding = ndArray.slice(0, padNum)
newArray.slice(0, dataBatchSize - padNum).set(batch)
newArray.slice(dataBatchSize - padNum, dataBatchSize).set(padding)
newArray
}
}

private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class NDArrayCollectorSuite extends FunSuite with BeforeAndAfterAll with Matchers {

test("auto dispose") {
val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
var b, c: NDArray = null

val res = NDArrayCollector.auto().withScope {
b = NDArray.relu(a) // [0, 0, 1, 2, 3, 4]
c = a + b // [-1, 0, 2, 4, 6, 8]
c.slice(0, 1)
}

assert(b.isDisposed)
assert(c.isDisposed)
assert(!res.isDisposed) // smart enough not to dispose the returned NDArray

assert(res.toArray === Array(-1f, 0f, 2f))

res.dispose()
}

test("manually dispose") {
val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
var b, c: NDArray = null

val collector = NDArrayCollector.manual()
val res = collector.withScope {
b = NDArray.relu(a) // [0, 0, 1, 2, 3, 4]
c = a + b // [-1, 0, 2, 4, 6, 8]
c.slice(0, 1)
}

assert(res.toArray === Array(-1f, 0f, 2f))

assert(collector.size === 2) // smart enough not to collect the returned NDArray
assert(!b.isDisposed)
assert(!c.isDisposed)
assert(!res.isDisposed)

collector.foreach(_.dispose())
assert(b.isDisposed)
assert(c.isDisposed)
assert(!res.isDisposed)

collector.clear()
assert(collector.size === 0)

res.dispose()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ extern "C" void KVStoreUpdaterCallbackFunc

// find java NDArray constructor
jclass ndObjClass = env->FindClass("org/apache/mxnet/NDArray");
jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>", "(JZ)V");
jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>", "(JZZ)V");

jobject ndRecv = env->NewObject(ndObjClass, ndObjConstructor,
reinterpret_cast<jlong>(recv), true);
Expand Down

0 comments on commit 1031fe1

Please sign in to comment.