Skip to content

Commit

Permalink
Ground work for making mleap scala 2.13 compatible
Browse files Browse the repository at this point in the history
1. Using sbt 1.9
2. Using most recent versions of the scala libraries where possible
3. Replaced scala-arm with more idiomatic scala 2.13+ way (scala-arm was
   not updated for a long time)
4. Updated scalatest.
5. Updated ClassloaderUtils to use java 9 compatible code.
  • Loading branch information
dotbg committed Aug 18, 2023
1 parent d1d4afd commit c79afde
Show file tree
Hide file tree
Showing 226 changed files with 1,083 additions and 1,002 deletions.
16 changes: 8 additions & 8 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Use container-based infrastructure
os: linux
dist: focal
dist: jammy

# Set default python env
# because the xgboost-spark library when running training code, it will
Expand All @@ -22,7 +22,7 @@ services:

language: scala
scala:
- 2.12.13
- 2.12.18
jdk:
- openjdk11

Expand All @@ -43,27 +43,27 @@ jobs:

- name: "Python 3.7 tests"
language: python
python: 3.7.9
python: 3.7.15
install:
- pip install tox
before_script:
- >
curl
--create-dirs -L -o /home/travis/.sbt/launchers/1.4.9/sbt-launch.jar
https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/1.4.9/sbt-launch-1.4.9.jar
--create-dirs -L -o /home/travis/.sbt/launchers/1.9.3/sbt-launch.jar
https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/1.9.3/sbt-launch-1.9.3.jar
script:
- make test_python37

- name: "Python 3.8 tests"
language: python
python: 3.8.15
python: 3.8.16
install:
- pip install tox
before_script:
- >
curl
--create-dirs -L -o /home/travis/.sbt/launchers/1.4.9/sbt-launch.jar
https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/1.4.9/sbt-launch-1.4.9.jar
--create-dirs -L -o /home/travis/.sbt/launchers/1.9.3/sbt-launch.jar
https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/1.9.3/sbt-launch-1.9.3.jar
script:
- make test_python38

Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.feature.{Binarizer, StringIndexer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import resource._
import scala.util.Using

val datasetName = "./examples/spark-demo.csv"

Expand All @@ -143,7 +143,7 @@ import resource._

// then serialize pipeline
val sbc = SparkBundleContext().withDataset(pipeline.transform(dataframe))
for(bf <- managed(BundleFile("jar:file:/tmp/simple-spark-pipeline.zip"))) {
Using(BundleFile("jar:file:/tmp/simple-spark-pipeline.zip")) { bf =>
pipeline.writeBundle.save(bf)(sbc).get
}
```
Expand Down Expand Up @@ -215,9 +215,9 @@ Because we export Spark and Scikit-learn pipelines to a standard format, we can
```scala
import ml.combust.bundle.BundleFile
import ml.combust.mleap.runtime.MleapSupport._
import resource._
import scala.util.Using
// load the Spark pipeline we saved in the previous section
val bundle = (for(bundleFile <- managed(BundleFile("jar:file:/tmp/simple-spark-pipeline.zip"))) yield {
val bundle = Using(BundleFile("jar:file:/tmp/simple-spark-pipeline.zip"))) { bundleFile =>
bundleFile.loadMleapBundle().get
}).opt.get

Expand Down Expand Up @@ -271,7 +271,7 @@ For more documentation, please see our [documentation](https://combust.github.io

## Building

Please ensure you have sbt 1.4.9, java 11, scala 2.12.13
Please ensure you have sbt 1.9.3, java 11, scala 2.12.13

1. Initialize the git submodules `git submodule update --init --recursive`
2. Run `sbt compile`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import java.nio.file.{Files, Paths}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.scalatest.FunSpec
import org.scalatest.funspec.AnyFunSpec

class HadoopBundleFileSystemSpec extends FunSpec {
class HadoopBundleFileSystemSpec extends org.scalatest.funspec.AnyFunSpec {
private val fs = FileSystem.get(new Configuration())
private val bundleFs = new HadoopBundleFileSystem(fs)

Expand Down
26 changes: 17 additions & 9 deletions bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@ import java.io.{Closeable, File}
import java.net.URI
import java.nio.file.{FileSystem, FileSystems, Files, Path}
import java.util.stream.Collectors

import ml.combust.bundle.dsl.{Bundle, BundleInfo}
import ml.combust.bundle.fs.BundleFileSystem
import ml.combust.bundle.json.JsonSupport.bundleBundleInfoFormat
import ml.combust.bundle.serializer.BundleSerializer
import ml.combust.bundle.json.JsonSupport._
import spray.json._
import resource._

import scala.collection.JavaConverters._
import scala.jdk.CollectionConverters._
import scala.language.implicitConversions
import scala.util.Try
import scala.util.{Try, Using}

/**
* Created by hollinwilkins on 12/24/16.
Expand All @@ -40,6 +38,14 @@ object BundleFile {
apply(new URI(unbackslash(uri)))
}

implicit def apply(path: Path): BundleFile = {
if(path.getFileName.toString.endsWith("zip")) {
apply(s"jar:${path.toUri}")
} else {
apply(path.toUri)
}
}

implicit def apply(file: File): BundleFile = {
val uri: String = if (file.getPath.endsWith(".zip")) {
s"jar:${file.toURI.toString}"
Expand Down Expand Up @@ -101,10 +107,12 @@ case class BundleFile(fs: FileSystem,

def writeNote(name: String, note: String): Try[String] = {
Files.createDirectories(fs.getPath(path.toString, "notes"))
(for(out <- managed(Files.newOutputStream(fs.getPath(path.toString, "notes", name)))) yield {
out.write(note.getBytes)
note
}).tried
Using(Files.newOutputStream(fs.getPath(path.toString, "notes", name))) {
out => {
out.write(note.getBytes)
note
}
}
}

def readNote(name: String): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ml.combust.bundle.fs.BundleFileSystem
import ml.combust.bundle.op.{OpModel, OpNode}
import ml.combust.mleap.ClassLoaderUtil

import scala.collection.JavaConverters._
import scala.jdk.CollectionConverters._

/** Trait for classes that contain a bundle registry.
*
Expand Down Expand Up @@ -39,7 +39,7 @@ object BundleRegistry {

val br = ops.foldLeft(Map[String, OpNode[_, _, _]]()) {
(m, opClass) =>
val opNode = cl.loadClass(opClass).newInstance().asInstanceOf[OpNode[_, _, _]]
val opNode = cl.loadClass(opClass).getDeclaredConstructor().newInstance().asInstanceOf[OpNode[_, _, _]]
m + (opNode.Model.opName -> opNode)
}.values.foldLeft(BundleRegistry(cl)) {
(br, opNode) => br.register(opNode)
Expand Down
15 changes: 6 additions & 9 deletions bundle-ml/src/main/scala/ml/combust/bundle/BundleWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package ml.combust.bundle

import java.net.URI
import java.nio.file.{Files, Paths}

import ml.combust.bundle.dsl.Bundle
import ml.combust.bundle.fs.BundleFileSystem
import ml.combust.bundle.serializer.{BundleSerializer, SerializationFormat}
import ml.combust.bundle.serializer.{BundleSerializer, SerializationFormat}

import scala.util.Try
import resource._
import scala.util.{Try, Using}

/**
* Created by hollinwilkins on 12/24/16.
Expand Down Expand Up @@ -37,16 +34,16 @@ Transformer <: AnyRef](root: Transformer,
def save(uri: URI)
(implicit context: Context): Try[Bundle[Transformer]] = uri.getScheme match {
case "jar" | "file" =>
(for (bf <- managed(BundleFile(uri))) yield {
Using(BundleFile(uri)) { bf =>
save(bf).get
}).tried
}
case _ =>
val tmpDir = Files.createTempDirectory("bundle")
val tmp = Paths.get(tmpDir.toString, "tmp.zip")

(for (bf <- managed(BundleFile(tmp.toFile))) yield {
Using(BundleFile(tmp.toFile)) { bf =>
save(bf).get
}).tried.map {
}.map {
r =>
context.bundleRegistry.fileSystemForUri(uri).save(uri, tmp.toFile)
r
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package ml.combust.bundle.serializer

import java.io.Closeable
import java.nio.file.Files

import ml.combust.bundle.{BundleContext, BundleFile, HasBundleRegistry}
import ml.combust.bundle.dsl.Bundle
import ml.combust.bundle.json.JsonSupport._
import spray.json._
import resource._

import scala.util.Try
import scala.util.{Try, Using}

/** Class for serializing/deserializing Bundle.ML [[ml.combust.bundle.dsl.Bundle]] objects.
*
Expand All @@ -28,18 +26,17 @@ case class BundleSerializer[Context](context: Context,
*/
def write[Transformer <: AnyRef](bundle: Bundle[Transformer]): Try[Bundle[Transformer]] = Try {
val bundleContext = bundle.bundleContext(context, hr.bundleRegistry, file.fs, file.path)
implicit val format = bundleContext.format

Files.createDirectories(file.path)
NodeSerializer(bundleContext.bundleContext("root")).write(bundle.root).flatMap {
_ =>
(for (out <- managed(Files.newOutputStream(bundleContext.file(Bundle.bundleJson)))) yield {
val json = bundle.info.asBundle.toJson.prettyPrint.getBytes("UTF-8")
out.write(json)
bundle
}).tried
Using(Files.newOutputStream(bundleContext.file(Bundle.bundleJson))) {
out =>
val json = bundle.info.asBundle.toJson.prettyPrint.getBytes("UTF-8")
out.write(json)
bundle
}
}
}.flatMap(identity)
}.flatten

/** Read a bundle from the path.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,76 +1,40 @@
package ml.combust.bundle.serializer

import java.io.{File, FileInputStream, FileOutputStream}
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}

import resource._
import java.io.File
import java.util.zip.{ZipInputStream, ZipOutputStream}

/**
* Created by hollinwilkins on 9/11/16.
*/
@deprecated("Prefer ml.combust.bundle.util.FileUtil object.")
case class FileUtil() {
import ml.combust.bundle.util.{FileUtil => FileUtils}
@deprecated("use FileUtil.rmRF(Path).")
def rmRF(path: File): Array[(String, Boolean)] = {
Option(path.listFiles).map(_.flatMap(f => rmRF(f))).getOrElse(Array()) :+ (path.getPath -> path.delete)
FileUtils.rmRF(path.toPath)
}

@deprecated("use extract(Path, Path).")
def extract(source: File, dest: File): Unit = {
dest.mkdirs()
for(in <- managed(new ZipInputStream(new FileInputStream(source)))) {
extract(in, dest)
}
FileUtils.extract(source.toPath, dest.toPath)
}

@deprecated("use extract(ZipInputStream, Path).")
def extract(in: ZipInputStream, dest: File): Unit = {
dest.mkdirs()
val buffer = new Array[Byte](1024 * 1024)

var entry = in.getNextEntry
while(entry != null) {
if(entry.isDirectory) {
new File(dest, entry.getName).mkdirs()
} else {
val filePath = new File(dest, entry.getName)
for(out <- managed(new FileOutputStream(filePath))) {
var len = in.read(buffer)
while(len > 0) {
out.write(buffer, 0, len)
len = in.read(buffer)
}
}
}
entry = in.getNextEntry
}
FileUtils.extract(in, dest.toPath)
}

@deprecated("use FileUtil.extract(Path, Path).")
def zip(source: File, dest: File): Unit = {
for(out <- managed(new ZipOutputStream(new FileOutputStream(dest)))) {
zip(source, out)
}
FileUtils.zip(source.toPath, dest.toPath)
}

def zip(source: File, dest: ZipOutputStream): Unit = zip(source, source, dest)
@deprecated("use FileUtil.extract(Path, ZipOutputStream).")
def zip(source: File, dest: ZipOutputStream): Unit = FileUtils.zip(source.toPath, source.toPath, dest)

@deprecated("use FileUtil.extract(Path, Path, ZipOutputStream).")
def zip(base: File, source: File, dest: ZipOutputStream): Unit = {
val buffer = new Array[Byte](1024 * 1024)

for(files <- Option(source.listFiles);
file <- files) {
val name = file.toString.substring(base.toString.length + 1)

if(file.isDirectory) {
dest.putNextEntry(new ZipEntry(s"$name/"))
zip(base, file, dest)
} else {
dest.putNextEntry(new ZipEntry(name))

for (in <- managed(new FileInputStream(file))) {
var read = in.read(buffer)
while (read > 0) {
dest.write(buffer, 0, read)
read = in.read(buffer)
}
}
}
}
FileUtils.zip(base.toPath, source.toPath, dest)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ case class NodeSerializer[Context](bundleContext: BundleContext[Context]) {
val shape = op.shape(obj)(bundleContext)
Node(name = name, shape = shape)
}
}.flatMap(identity).flatMap {
}.flatten.flatMap {
node => Try(FormatNodeSerializer.serializer.write(bundleContext.file(Bundle.nodeFile), node))
}

Expand Down
Loading

0 comments on commit c79afde

Please sign in to comment.