diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala index d234ac66bdd8..6136db29d1eb 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -20,5 +20,5 @@ package org.apache.mxnet * typesafe NDArray API: NDArray.api._ * Main code will be generated during compile time through Macros */ -object NDArrayAPI { +object NDArrayAPI extends NDArrayAPIBase { } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index 49de9ae73218..56da4fa64cf4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -22,5 +22,5 @@ package org.apache.mxnet * typesafe Symbol API: Symbol.api._ * Main code will be generated during compile time through Macros */ -object SymbolAPI { +object SymbolAPI extends SymbolAPIBase { } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala index 47b4c100ea55..8a57527f3556 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala @@ -99,7 +99,9 @@ object ImageClassifierExample { batch = ListBuffer[String]() } } - output += batch.toList + if (batch.length > 0) { + output += batch.toList + } output.toList } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala index f55a60f0144b..b5222e662dfb 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala @@ -100,7 +100,9 @@ object SSDClassifierExample { batch = ListBuffer[String]() } } - output += batch.toList + if (batch.length > 0) { + output += batch.toList + } output.toList } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index 6fa313b3e7f2..8d31d1f6b3d6 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -198,7 +198,7 @@ object ImageClassifier { /** * Loads a batch of images from a folder - * @param inputImageDirPath Path to a folder of images + * @param inputImagePaths Path to a folder of images * @return List of buffered images */ def loadInputBatch(inputImagePaths: List[String]): Traversable[BufferedImage] = { diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index 73d90541ba1a..d80f72598750 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -53,6 +53,7 @@ + @@ -70,6 +71,29 @@ org.apache.maven.plugins maven-compiler-plugin + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + apidoc-generation + package + + java + + + + + + ${project.parent.basedir}/init/target/classes + + + ${project.parent.basedir}/core/src/main/scala/org/apache/mxnet/ + + org.apache.mxnet.APIDocGenerator + + org.scalatest scalatest-maven-plugin diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala new file mode 100644 index 000000000000..90fe2604e8b6 --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.init.Base._ +import org.apache.mxnet.utils.CToScalaUtils + +import scala.collection.mutable.ListBuffer + +/** + * This object will generate the Scala documentation of the new Scala API + * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala + * The code will be executed during Macros stage and file live in Core stage + */ +private[mxnet] object APIDocGenerator{ + case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) + case class absClassFunction(name : String, desc : String, + listOfArgs: List[absClassArg], returnType : String) + + + def main(args: Array[String]) : Unit = { + val FILE_PATH = args(0) + absClassGen(FILE_PATH, true) + absClassGen(FILE_PATH, false) + } + + def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = { + // scalastyle:off + val absClassFunctions = getSymbolNDArrayMethods(isSymbol) + // TODO: Add Filter to the same location in case of refactor + val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => { + val scalaDoc = generateAPIDocFromBackend(absClassFunction) + val defBody = generateAPISignature(absClassFunction, isSymbol) + s"$scalaDoc\n$defBody" + }) + val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase" + val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" + val scalaStyle = "// scalastyle:off" + val packageDef = "package org.apache.mxnet" + val absClassDef = s"abstract class $packageName" + val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" + import java.io._ + val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) + pw.write(finalStr) + pw.close() + } + + // Generate ScalaDoc type + def generateAPIDocFromBackend(func : absClassFunction) : String = { + val desc = func.desc.split("\n").map({ currStr => + s" * $currStr" + }) + val params = func.listOfArgs.map({ absClassArg => + val currArgName = absClassArg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => absClassArg.argName + } + s" * @param $currArgName\t\t${absClassArg.argDesc}" + }) + val returnType = s" * @return ${func.returnType}" + s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" + } + + def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = { + var argDef = ListBuffer[String]() + func.listOfArgs.foreach(absClassArg => { + val currArgName = absClassArg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => absClassArg.argName + } + if (absClassArg.isOptional) { + argDef += s"$currArgName : Option[${absClassArg.argType}] = None" + } + else { + argDef += s"$currArgName : ${absClassArg.argType}" + } + }) + var returnType = func.returnType + if (isSymbol) { + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + } else { + returnType = "org.apache.mxnet.NDArrayFuncReturn" + } + s"def ${func.name} (${argDef.mkString(", ")}) : ${returnType}" + } + + + // List and add all the atomic symbol functions to current module. + private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = { + val opNames = ListBuffer.empty[String] + val returnType = if (isSymbol) "Symbol" else "NDArray" + _LIB.mxListAllOpNames(opNames) + // TODO: Add '_linalg_', '_sparse_', '_image_' support + opNames.map(opName => { + val opHandle = new RefLong + _LIB.nnGetOpHandle(opName, opHandle) + makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType) + }).toList + } + + // Create an atomic symbol function by handle and function name. + private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String) + : absClassFunction = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + + val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" + + val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => + val typeAndOption = CToScalaUtils.argumentCleaner(argType, returnType) + new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) + } + new absClassFunction(aliasName, desc.value, argList.toList, returnType) + } +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index c1c3a429b408..ce5b532bc8b8 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -18,7 +18,7 @@ package org.apache.mxnet import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.OperatorBuildUtils +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer @@ -133,8 +133,8 @@ private[mxnet] object NDArrayMacro { impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)" // scalastyle:on // Combine and build the function string - val returnType = "org.apache.mxnet.NDArray" - var finalStr = s"def ${ndarrayfunction.name}New" + val returnType = "org.apache.mxnet.NDArrayFuncReturn" + var finalStr = s"def ${ndarrayfunction.name}" finalStr += s" (${argDef.mkString(",")}) : $returnType" finalStr += s" = {${impl.mkString("\n")}}" c.parse(finalStr).asInstanceOf[DefDef] @@ -175,63 +175,6 @@ private[mxnet] object NDArrayMacro { } - // Convert C++ Types to Scala Types - private def typeConversion(in : String, argType : String = "") : String = { - in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" - case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.NDArray" - case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" - => "Array[org.apache.mxnet.NDArray]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" - case "string" => "String" - case "boolean" | "booleanorNone" => "Boolean" - case "tupleof" | "tupleof" | "ptr" | "" => "Any" - case default => throw new IllegalArgumentException( - s"Invalid type for args: $default, $argType") - } - } - - - /** - * By default, the argType come from the C++ API is a description more than a single word - * For Example: - * , , - * The three field shown above do not usually come at the same time - * This function used the above format to determine if the argument is - * optional, what is it Scala type and possibly pass in a default value - * @param argType Raw arguement Type description - * @return (Scala_Type, isOptional) - */ - private def argumentCleaner(argType : String) : (String, Boolean) = { - val spaceRemoved = argType.replaceAll("\\s+", "") - var commaRemoved : Array[String] = new Array[String](0) - // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} - if (spaceRemoved.charAt(0)== '{') { - val endIdx = spaceRemoved.indexOf('}') - commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") - commaRemoved(0) = "string" - } else { - commaRemoved = spaceRemoved.split(",") - } - // Optional Field - if (commaRemoved.length >= 3) { - // arg: Type, optional, default = Null - require(commaRemoved(1).equals("optional")) - require(commaRemoved(2).startsWith("default=")) - (typeConversion(commaRemoved(0), argType), true) - } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType) - val tempOptional = tempType.equals("org.apache.mxnet.NDArray") - (tempType, tempOptional) - } else { - throw new IllegalArgumentException( - s"Unrecognized arg field: $argType, ${commaRemoved.length}") - } - - } // List and add all the atomic symbol functions to current module. @@ -273,7 +216,7 @@ private[mxnet] object NDArrayMacro { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = argumentCleaner(argType) + val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.NDArray") new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) } new NDArrayFunction(aliasName, argList.toList) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 234a8604cb91..bacbdb2e3075 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.OperatorBuildUtils +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs @@ -178,65 +178,6 @@ private[mxnet] object SymbolImplMacros { result } - // Convert C++ Types to Scala Types - def typeConversion(in : String, argType : String = "") : String = { - in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" - case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" - case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" - => "Array[org.apache.mxnet.Symbol]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" - case "string" => "String" - case "boolean" => "Boolean" - case "tupleof" | "tupleof" | "ptr" | "" => "Any" - case default => throw new IllegalArgumentException( - s"Invalid type for args: $default, $argType") - } - } - - - /** - * By default, the argType come from the C++ API is a description more than a single word - * For Example: - * , , - * The three field shown above do not usually come at the same time - * This function used the above format to determine if the argument is - * optional, what is it Scala type and possibly pass in a default value - * @param argType Raw arguement Type description - * @return (Scala_Type, isOptional) - */ - def argumentCleaner(argType : String) : (String, Boolean) = { - val spaceRemoved = argType.replaceAll("\\s+", "") - var commaRemoved : Array[String] = new Array[String](0) - // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} - if (spaceRemoved.charAt(0)== '{') { - val endIdx = spaceRemoved.indexOf('}') - commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") - commaRemoved(0) = "string" - } else { - commaRemoved = spaceRemoved.split(",") - } - // Optional Field - if (commaRemoved.length >= 3) { - // arg: Type, optional, default = Null - require(commaRemoved(1).equals("optional")) - require(commaRemoved(2).startsWith("default=")) - (typeConversion(commaRemoved(0), argType), true) - } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType) - val tempOptional = tempType.equals("org.apache.mxnet.Symbol") - (tempType, tempOptional) - } else { - throw new IllegalArgumentException( - s"Unrecognized arg field: $argType, ${commaRemoved.length}") - } - - } - - // List and add all the atomic symbol functions to current module. private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] @@ -277,7 +218,7 @@ private[mxnet] object SymbolImplMacros { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = argumentCleaner(argType) + val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.Symbol") new SymbolArg(argName, typeAndOption._1, typeAndOption._2) } new SymbolFunction(aliasName, argList.toList) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala new file mode 100644 index 000000000000..9d51ddcb674a --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet.utils + +private[mxnet] object CToScalaUtils { + + + + // Convert C++ Types to Scala Types + def typeConversion(in : String, argType : String = "", returnType : String) : String = { + in match { + case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" + => s"Array[$returnType]" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" | "doubleorNone" => "Double" + case "string" => "String" + case "boolean" | "booleanorNone" => "Boolean" + case "tupleof" | "tupleof" | "ptr" | "" => "Any" + case default => throw new IllegalArgumentException( + s"Invalid type for args: $default, $argType") + } + } + + + /** + * By default, the argType come from the C++ API is a description more than a single word + * For Example: + * , , + * The three field shown above do not usually come at the same time + * This function used the above format to determine if the argument is + * optional, what is it Scala type and possibly pass in a default value + * @param argType Raw arguement Type description + * @return (Scala_Type, isOptional) + */ + def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = { + val spaceRemoved = argType.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + // Optional Field + if (commaRemoved.length >= 3) { + // arg: Type, optional, default = Null + require(commaRemoved(1).equals("optional")) + require(commaRemoved(2).startsWith("default=")) + (typeConversion(commaRemoved(0), argType, returnType), true) + } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { + val tempType = typeConversion(commaRemoved(0), argType, returnType) + val tempOptional = tempType.equals("org.apache.mxnet.Symbol") + (tempType, tempOptional) + } else { + throw new IllegalArgumentException( + s"Unrecognized arg field: $argType, ${commaRemoved.length}") + } + + } +} diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index bc8be7df5fb1..5883a00c3315 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -17,6 +17,7 @@ package org.apache.mxnet +import org.apache.mxnet.utils.CToScalaUtils import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -42,7 +43,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) for (idx <- input.indices) { - val result = SymbolImplMacros.argumentCleaner(input(idx)) + val result = CToScalaUtils.argumentCleaner(input(idx), "org.apache.mxnet.Symbol") assert(result._1 === output(idx)._1 && result._2 === output(idx)._2) } }