Akshay Utkarsh Sharma edited this page Aug 27, 2016 · 16 revisions

Spark-Transformers

Spark-Transformers is a library for exporting Apache Spark models so that they can be used in the Java ecosystem.

The goals of this library are to:

  • Provide a way to export Apache Spark models and transformations into a custom format that can be imported back into a Java object.
  • Provide a way to make model predictions in the Java ecosystem that match the predictions made in Apache Spark.

Usage

Add the jar to the classpath when launching the Spark shell (see the Spark programming guide):

http://spark.apache.org/docs/latest/programming-guide.html#using-the-shell

./bin/spark-shell --master local --packages "com.flipkart.fdp.ml:adapters-1.6_2.11:0.2.2"

Train and export the model in Spark.

Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 1.6.1
      /_/

Using Scala version 2.10.5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_74)
Type in expressions to have them evaluated.
Type :help for more information.
Spark context available as sc.
SQL context available as sqlContext.

scala> import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.classification.LogisticRegression

scala> // Load training data
scala> val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
training: org.apache.spark.sql.DataFrame = [label: double, features: vector]

scala> val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_997fb6ae63ea

scala> // Fit the model
scala> val lrModel = lr.fit(training)
lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_997fb6ae63ea

scala> // Print the coefficients and intercept for logistic regression
scala> println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
Coefficients: (692,[244,263,272,300,301,328,350,351,378,379,405,406,407,428,433,434,455,456,461,462,483,484,489,490,496,511,512,517,539,540,568],[-7.35398352418818E-5,-9.102738505589477E-5,-1.9467430546904317E-4,-2.0300642473486697E-4,-3.147618331486218E-5,-6.842977602660754E-5,1.588362689823927E-5,1.4023497091371932E-5,3.543204752496866E-4,1.1443272898171088E-4,1.001671238366663E-4,6.014109303795483E-4,2.840248179122761E-4,-1.1541084736508822E-4,3.8599688631290636E-4,6.350195574241066E-4,-1.1506412384575666E-4,-1.5271865864986795E-4,2.8049338089942136E-4,6.070117471191632E-4,-2.0084596632474375E-4,-1.421075579290124E-4,2.7390103411608827E-4,2.773045624496812E-4,-9.838027027269317E-5,-3.808522443517698E-4,-2.5315198008554995E-4,2.7747714770754296E-4,-2.4436197639191937E-4,-0.0015394744687597776,-2.3073328411331252E-4]) Intercept: 0.22456315961250317


scala> import com.flipkart.fdp.ml.export.ModelExporter
import com.flipkart.fdp.ml.export.ModelExporter

scala> //Export the trained model
scala> val exportedModel = ModelExporter.export(lrModel, training)
exportedModel: Array[Byte] = Array(123, 34, 95, 109, 111, 100, 101, 108, 95, 105, 110, 102, 111, 34, 58, 34, 123, 92, 34, 119, 101, 105, 103, 104, 116, 115, 92, 34, 58, 91, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, 48, 46, 48, 44, ...

scala> //save exportedModel somewhere
scala> import java.io._
import java.io._

scala> val bos = new BufferedOutputStream(new FileOutputStream("/tmp/ExportedLRModel"))
bos: java.io.BufferedOutputStream = java.io.BufferedOutputStream@1cbb6037

scala> bos.write(exportedModel)

scala> bos.close()
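
The leading bytes printed for exportedModel in the session above are plain ASCII, so it can help to look at them as text before moving to the Java side. The following is a minimal sketch, not part of the library: the class name ExportedBytesSketch and its helper methods are hypothetical, the byte prefix is copied from the shell output, and UTF-8 encoding of the export is an assumption. It also shows a write and read round trip with java.nio, which is the same way the saved file is read back later.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper class for illustration; not part of Spark-Transformers.
class ExportedBytesSketch {

    // First bytes of the exportedModel array as printed in the shell session.
    static final byte[] PREFIX = {
        123, 34, 95, 109, 111, 100, 101, 108, 95, 105, 110, 102, 111, 34, 58, 34,
        123, 92, 34, 119, 101, 105, 103, 104, 116, 115, 92, 34, 58, 91, 48, 46, 48, 44
    };

    // Decode the bytes as text (assumption: the export is UTF-8 encoded).
    static String decode(byte[] bytes) {
        return new String(bytes, StandardCharsets.UTF_8);
    }

    // Write the bytes to a file and read them back unchanged.
    static byte[] roundTrip(Path file, byte[] bytes) throws IOException {
        Files.write(file, bytes);
        return Files.readAllBytes(file);
    }

    public static void main(String[] args) throws IOException {
        // Show the start of the exported payload as text.
        System.out.println(decode(PREFIX));

        // Save to a temporary file and confirm the length survives the round trip.
        Path tmp = Files.createTempFile("ExportedLRModel", ".bin");
        byte[] back = roundTrip(tmp, PREFIX);
        System.out.println(back.length == PREFIX.length);
        Files.delete(tmp);
    }
}
```

Decoding the prefix suggests that the payload starts as a JSON document with a "_model_info" field. Only the first bytes are shown in the session, so the rest of the layout is not confirmed here.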

Import and predict in Java.

Add the Maven dependency

<dependency>
  <groupId>com.flipkart.fdp.ml</groupId>
  <artifactId>models-info</artifactId>
  <version>0.2.2</version>
  <type>pom</type>
</dependency>
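
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

import org.junit.Test;

// Library imports; package names are assumed from the project layout, verify them against your version
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;

public class ImportTest {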

    @Test
    public void testImport() throws IOException {
        //read back saved model
        Path path = Paths.get("/tmp/ExportedLRModel");
        byte[] exportedModel = Files.readAllBytes(path);

        //Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        //predict
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 124.0, 253.0, 255.0, 63.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0,
                244.0, 251.0, 253.0, 62.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 127.0, 251.0, 251.0,
                253.0, 62.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 68.0, 236.0, 251.0, 211.0, 31.0, 8.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 60.0, 228.0, 251.0, 251.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 155.0, 253.0, 253.0, 189.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0, 253.0, 251.0, 235.0, 66.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.0, 205.0, 253.0, 251.0, 126.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 104.0, 251.0, 253.0, 184.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 80.0, 240.0,
                251.0, 193.0, 23.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.0, 253.0, 253.0, 253.0, 159.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 151.0, 251.0, 251.0, 251.0, 39.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.0, 221.0, 251.0, 251.0, 172.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 234.0, 251.0, 251.0, 196.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 253.0, 251.0, 251.0, 89.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                159.0, 255.0, 253.0, 253.0, 31.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.0, 228.0, 253.0, 247.0,
                140.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 64.0, 251.0, 253.0, 220.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 64.0, 251.0, 253.0, 220.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 193.0, 253.0, 220.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
        transformer.transform(data);
        double predicted = (double) data.get("prediction");
        System.out.println(predicted);

    }
}
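
The test above passes a dense 692-element features array and reads back a prediction. To make that step less opaque, here is a minimal sketch of what a binary logistic regression prediction computes in general, together with a helper for building a dense array from sparse (index, value) pairs such as those in libsvm data. This is not the library's implementation: the class LogisticSketch, its method names, the toy weights, and the 0.5 decision threshold are assumptions made for illustration. Only the intercept value is taken from the shell session.

```java
// Hypothetical illustration; not the Spark-Transformers implementation.
class LogisticSketch {

    // Expand sparse (index, value) pairs into a dense array of the given size.
    static double[] toDense(int size, int[] indices, double[] values) {
        double[] dense = new double[size];
        for (int i = 0; i < indices.length; i++) {
            dense[indices[i]] = values[i];
        }
        return dense;
    }

    // Probability of the positive class: sigmoid(weights . features + intercept).
    static double probability(double[] weights, double intercept, double[] features) {
        double margin = intercept;
        for (int i = 0; i < weights.length; i++) {
            margin += weights[i] * features[i];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    // Label 1.0 if the probability exceeds the threshold, otherwise 0.0.
    static double predict(double[] weights, double intercept, double[] features, double threshold) {
        return probability(weights, intercept, features) > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        // Toy weights for a 4-feature model (made up for this sketch).
        double[] weights = {0.5, -1.0, 0.0, 2.0};
        // Intercept printed by the shell session for the trained model.
        double intercept = 0.22456315961250317;

        // Sparse input with values at indices 1 and 3.
        double[] features = toDense(4, new int[]{1, 3}, new double[]{2.0, 5.0});

        System.out.println(probability(weights, intercept, features));
        System.out.println(predict(weights, intercept, features, 0.5));
    }
}
```

With an all-zero feature vector the margin equals the intercept, so a positive intercept alone pushes the probability above 0.5.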

Getting help

For help regarding usage, drop an email to [email protected]