Merge pull request #7 from lamastex/features/scalable-trend-calculus
Features/scalable trend calculus
lamastex authored Jul 30, 2020
2 parents 423fcdb + 32df72a commit b0d2360
Showing 10 changed files with 313 additions and 39 deletions.
15 changes: 0 additions & 15 deletions src/main/scala/org/lamastex/spark/trendcalculus/DateUtils.scala
@@ -109,19 +109,4 @@ object DateUtils {
}
}

case class FrequencyMillisecond(
frequency: Frequency.Value,
milliseconds: Long
)

case class MonthYear(
quarter: Int,
half: Int
)

object Frequency extends Enumeration with Serializable {
val UNKWOWN, MILLI_SECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, HALF_YEAR, YEAR = Value
}


}
5 changes: 5 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/Point.scala
@@ -4,3 +4,8 @@ case class Point(
x: Long,
y: Double
)

case class TimePoint(
x: java.sql.Timestamp,
y: Double
)
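
TimePoint mirrors Point but keys on a java.sql.Timestamp, which is what the new Dataset-based API added below consumes. A minimal sketch of building a Dataset[TimePoint], with invented data and an illustrative local session setup:

import java.sql.Timestamp
import org.apache.spark.sql.SparkSession
import org.lamastex.spark.trendcalculus._

val spark = SparkSession.builder.master("local[*]").appName("tc-sketch").getOrCreate()
import spark.implicits._

// Hypothetical observations as (epoch milliseconds, value) pairs.
val raw = Seq((1577836800000L, 66.0), (1577923200000L, 67.1), (1578009600000L, 65.9))

// Wrap each observation in the new TimePoint case class.
val ts = raw.map { case (ms, v) => TimePoint(new Timestamp(ms), v) }.toDS()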
10 changes: 0 additions & 10 deletions src/main/scala/org/lamastex/spark/trendcalculus/SeriesUtils.scala
@@ -1,7 +1,5 @@
package org.lamastex.spark.trendcalculus

import org.lamastex.spark.trendcalculus.DateUtils.Frequency

object SeriesUtils {

def movingAverage(timeseries: Array[Point], grouping: Frequency.Value): Unit = {
@@ -127,12 +125,4 @@ object SeriesUtils {
.sorted(Ordering.by((p: Point) => p.x))
}

object FillingStrategy extends Enumeration with Serializable {
val MEAN, LOCF, LINEAR, ZERO = Value
}

object AggregateStrategy extends Enumeration with Serializable {
val MEAN, SUM = Value
}

}
src/main/scala/org/lamastex/spark/trendcalculus/TrendCalculus.scala
@@ -14,14 +14,12 @@

package org.lamastex.spark.trendcalculus

import org.lamastex.spark.trendcalculus.DateUtils.Frequency

import scala.util.Try

class TrendCalculus(timeseries: Array[Point], groupingFrequency: Frequency.Value) extends Serializable {

val map: Map[DateUtils.Frequency.Value, Long] = DateUtils.frequencyMillisecond.map(t => (t.frequency, t.milliseconds)).toMap
val samplingFrequency: DateUtils.Frequency.Value = DateUtils.findFrequency(timeseries.map(_.x))
val map: Map[Frequency.Value, Long] = DateUtils.frequencyMillisecond.map(t => (t.frequency, t.milliseconds)).toMap
val samplingFrequency: Frequency.Value = DateUtils.findFrequency(timeseries.map(_.x))
val groupingMilliseconds: Long = map(groupingFrequency)
val samplingMilliseconds: Long = map(samplingFrequency)

@@ -49,7 +47,7 @@ class TrendCalculus(timeseries: Array[Point], groupingFrequency: Frequency.Value
.sortBy(_._1)

val high = sortedWSeries.last._2.head //earliest high price
val low = sortedWSeries.head._2.last //latest low price

val List(left, right) = List(high, low).sorted(Ordering.by((p: Point) => p.x))
val leftSeries = wSeries.filter(_.x < left.x)
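
To follow the high/low selection above: the window is grouped by price, each group is sorted by time, and the groups are then ordered by price, so the last group holds the high and the first group the low. A tiny worked sketch with invented numbers:

val pts = Seq(Point(1L, 2.0), Point(2L, 5.0), Point(3L, 1.0), Point(4L, 5.0))
val sorted = pts.groupBy(_.y).map { case (v, s) => (v, s.sortBy(_.x)) }.toSeq.sortBy(_._1)
val high = sorted.last._2.head // earliest point at the maximum: Point(2, 5.0)
val low  = sorted.head._2.last // latest point at the minimum: Point(3, 1.0)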
65 changes: 65 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/TrendCalculus2.scala
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.lamastex.spark.trendcalculus

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

import org.apache.spark.sql.expressions.{Window, WindowSpec}
import java.sql.Timestamp

class TrendCalculus2(timeseries: Dataset[TimePoint], windowSize: Int, spark: SparkSession, initZero: Boolean = true) {

import spark.implicits._

val windowSpec = Window.rowsBetween(Window.unboundedPreceding, 0)

def getReversals: Dataset[(TimePoint, String)] = {

val tmp = if (initZero) {
timeseries
.map(r => (Point(r.x.getTime(), r.y), "dummy"))
.toDF("point","dummy")
.withColumn("fhls", new TsToTrend(windowSize)($"point").over(windowSpec))
} else {
val initPoint = try {
timeseries.first
} catch {
case e: NoSuchElementException => return spark.emptyDataset[(TimePoint, String)].toDF("reversalPoint", "reversal").as[(TimePoint, String)]
}

val init = Some(Row(Point(initPoint.x.getTime, initPoint.y)))

timeseries
.rdd.mapPartitionsWithIndex{ (id_x, iter) => if (id_x == 0) iter.drop(1) else iter }.toDS
.map(r => (Point(r.x.getTime(), r.y), "dummy"))
.toDF("point","dummy")
.withColumn("fhls", new TsToTrend(windowSize, init)($"point").over(windowSpec))
}

tmp
.select(explode($"fhls") as "tmp")
.select($"tmp.fhls".as("fhls"), $"tmp.trend".as("trend"), $"tmp.lastTrend".as("lastTrend"), $"tmp.lastFhls".as("lastFHLS"), $"tmp.reversal".as("reversal"))
.filter($"reversal" =!= 0)
.select($"lastFHLS", $"reversal" as "reversalInt")
.withColumn("reversalPoint", when($"reversalInt" === -1, $"lastFHLS.high").otherwise($"lastFHLS.low"))
.withColumn("reversal", when($"reversalInt" === -1, lit("Top")).otherwise(lit("Bottom")))
.select($"reversalPoint", $"reversal")
.filter($"reversalPoint.x" =!= 0L)
.map( r => (TimePoint(new Timestamp(r.getStruct(0).getLong(0)), r.getStruct(0).getDouble(1)), r.getString(1)))
.select($"_1" as "reversalPoint", $"_2" as "reversal")
.as[(TimePoint,String)]
}
}
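
A hedged usage sketch of the new class, reusing the Dataset[TimePoint] named ts from the sketch above; the window size of 2 is an arbitrary illustrative choice:

val tc = new TrendCalculus2(ts, 2, spark)
val reversals = tc.getReversals // Dataset[(TimePoint, String)]
reversals.show(false) // the reversal column is "Top" or "Bottom"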
179 changes: 179 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/TsToTrend.scala
@@ -0,0 +1,179 @@
package org.lamastex.spark.trendcalculus

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class TsToTrend(windowSize: Int, initLine: Option[Row] = None) extends UserDefinedAggregateFunction {

val FIRST = 0
val HIGH = 1
val LOW = 2
val SECOND = 3

val X = 0
val Y = 1

val pointSchema = StructType(
StructField("x", LongType, false) ::
StructField("y", DoubleType, false) ::
Nil
)

val fhlsSchema = StructType(
StructField("first", pointSchema, false) ::
StructField("high", pointSchema, false) ::
StructField("low", pointSchema, false) ::
StructField("second", pointSchema, false) ::
StructField("sign", IntegerType, false) ::
Nil
)

val emptyPoint = Row(0L, 0.0)

val emptyFhls = Row(
emptyPoint,
emptyPoint,
emptyPoint,
emptyPoint,
0
)

// These are the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType = StructType(
StructField("point", pointSchema) ::
Nil
)

// These are the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
StructField("bufferedPoints", ArrayType(pointSchema)) ::
StructField("counter", IntegerType) ::
StructField("result", dataType) ::
Nil
)

// This is the output type of your aggregation function.
override def dataType: DataType = ArrayType(
StructType(
StructField("fhls", fhlsSchema) ::
StructField("trend", IntegerType) ::
StructField("lastFhls", fhlsSchema) ::
StructField("lastTrend", IntegerType) ::
StructField("reversal", IntegerType) ::
Nil
)
)

override def deterministic: Boolean = true

// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
val init = initLine.getOrElse(Row(Point(0L, 0.0))).getAs[Point](0)
val initPoint = Row(init.x, init.y)
val lastFhls = Row(initPoint, initPoint, initPoint, initPoint, 0)
buffer(0) = (1 to 2*windowSize).map(_ => Row(0L, 0.0)).toSeq
buffer(1) = 0
buffer(2) = Seq[Row](Row(lastFhls,1,emptyFhls,0,0))
}

// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val newPoint = input.getStruct(0)
val newSeq = buffer.getSeq(0).+:(newPoint) //Prepend new point to buffer list
val newCounter = (buffer.getInt(1) + 1) % windowSize
buffer(0) = newSeq.take(2*windowSize)
buffer(1) = newCounter
if (newCounter == 0) {
val prevRes: Row = buffer.getSeq(2).last

val currFHLS = makeFHLS(newSeq.take(windowSize))
val prevFHLS = prevRes.getStruct(0)
val lastTrend = prevRes.getInt(1)
buffer(2) = getTrend(prevFHLS,currFHLS,lastTrend)
}
}

// This is how to merge two objects with the bufferSchema type.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
//buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
}

// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
val counter = buffer.getInt(1)

if (counter == 0) {
buffer.getSeq(2)
} else {
Seq.empty[Row]
}

}

private def makeFHLS(points: Seq[Row]): Row = {
val sortedByVal = points.groupBy(_.getDouble(Y)).map {case (v,vSeq) => (v,vSeq.sortBy(_.getLong(X)))}.toSeq.sortBy(_._1)
val high = sortedByVal.last._2.head // earliest point at the highest value
val low = sortedByVal.head._2.head // earliest point at the lowest value (the original note suggests the latest was intended)
val List(first,second) = if (low.getLong(X) < high.getLong(X)) List(low,high) else List(high,low)
val sign = if (first.getDouble(Y) < second.getDouble(Y)) 1 else -1

Row(first, high, low, second, sign)
}

private def getTrend(prevFHLS: Row, currFHLS: Row, lastTrend: Int): Seq[Row] = {
val prevHigh = prevFHLS.getStruct(HIGH)
val currHigh = currFHLS.getStruct(HIGH)
val prevLow = prevFHLS.getStruct(LOW)
val currLow = currFHLS.getStruct(LOW)

var currTrend = (
(currHigh.getDouble(Y) - prevHigh.getDouble(Y)).signum +
(currLow.getDouble(Y) - prevLow.getDouble(Y)).signum).signum

var currRev = (currTrend - lastTrend).signum

if (currTrend != 0) {
return Seq[Row](Row(currFHLS,currTrend,prevFHLS,lastTrend,currRev))
}

val intrFirst = prevFHLS.getStruct(SECOND)
val intrSecond = currFHLS.getStruct(FIRST)
val List(intrLow,intrHigh) = if (intrFirst.getDouble(Y) < intrSecond.getDouble(Y)) {
List(intrFirst, intrSecond)
} else {
List(intrSecond, intrFirst)
}
val intrSign = if (intrFirst.getDouble(Y) < intrSecond.getDouble(Y)) {
1
} else {
-1
}

var intrTrend = (
(intrHigh.getDouble(Y) - prevHigh.getDouble(Y)).signum +
(intrLow.getDouble(Y) - prevLow.getDouble(Y)).signum).signum

currTrend = (
(currHigh.getDouble(Y) - intrHigh.getDouble(Y)).signum +
(currLow.getDouble(Y) - intrLow.getDouble(Y)).signum).signum

if (intrTrend == 0) intrTrend = lastTrend
if (currTrend == 0) currTrend = intrTrend

val intrFHLS = Row(
intrFirst,
intrHigh,
intrLow,
intrSecond,
intrSign
)

val intrRev = (intrTrend - lastTrend).signum
currRev = (currTrend - intrTrend).signum

Seq[Row](Row(intrFHLS, intrTrend, prevFHLS, lastTrend, intrRev), Row(currFHLS, currTrend, intrFHLS, intrTrend, currRev))
}
}
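
TsToTrend buffers incoming points and emits a non-empty result only once per full window (when the counter wraps to zero); in between, evaluate returns an empty Seq, which is why TrendCalculus2 explodes the output column and filters afterwards. A sketch of invoking the UDAF directly, mirroring the call in getReversals and assuming ts and spark.implicits._ from the earlier sketch:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val windowSpec = Window.rowsBetween(Window.unboundedPreceding, 0)

val fhls = ts
  .map(r => (Point(r.x.getTime, r.y), "dummy"))
  .toDF("point", "dummy")
  .withColumn("fhls", new TsToTrend(2)($"point").over(windowSpec))
  .select(explode($"fhls") as "t") // one row per emitted (fhls, trend, lastFhls, lastTrend, reversal)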
22 changes: 22 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/package.scala
@@ -47,4 +47,26 @@ package object trendcalculus {
yfin(Seq(inputPath): _*)
}
}

case class FrequencyMillisecond(
frequency: Frequency.Value,
milliseconds: Long
)

case class MonthYear(
quarter: Int,
half: Int
)

object FillingStrategy extends Enumeration with Serializable {
val MEAN, LOCF, LINEAR, ZERO = Value
}

object AggregateStrategy extends Enumeration with Serializable {
val MEAN, SUM = Value
}

object Frequency extends Enumeration with Serializable {
val UNKWOWN, MILLI_SECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, HALF_YEAR, YEAR = Value
}
}
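
These enumerations moved here from DateUtils and SeriesUtils (matching the deletions above), so call sites now take them straight from the package object. A minimal sketch of how they read after the move (UNKWOWN is spelled as in the source):

import org.lamastex.spark.trendcalculus._

val freq: Frequency.Value = Frequency.HOUR
val fill: FillingStrategy.Value = FillingStrategy.LOCF
val agg: AggregateStrategy.Value = AggregateStrategy.MEAN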
8 changes: 1 addition & 7 deletions src/test/scala/org/lamastex/spark/trendcalculus/Brent.scala
@@ -13,20 +13,14 @@ class BrentTest extends SparkSpec with Matchers {
import org.apache.spark.sql.functions._
import spark.implicits._

import org.lamastex.spark.trendcalculus.DateUtils.Frequency
import org.lamastex.spark.trendcalculus.SeriesUtils.FillingStrategy
import org.lamastex.spark.trendcalculus._

/* val spark = SparkSession.builder().appName("gdelt-harness").getOrCreate()
val sqlContext = spark.sqlContext */

val dateUDF = udf((s: String) => new java.sql.Date(new java.text.SimpleDateFormat("yyyy-MM-dd").parse(s).getTime))
val valueUDF = udf((s: String) => s.toDouble)

// Only keep brent with data we know could match ours
val filePathRoot: String = "file:///root/GIT/lamastex/spark-trend-calculus/src/test/resources/org/lamastex/spark/trendcalculus/"
val filePathRoot: String = "src/test/resources/org/lamastex/spark/trendcalculus/"
val DF = spark.read.option("header", "true").option("inferSchema", "true").csv(filePathRoot+"brent.csv").filter(year(col("DATE")) >= 2015)
// val DF = Source.fromInputStream(this.getClass.getResourceAsStream("brent.csv"), "UTF-8").filter(year(col("DATE")) >= 2015)
DF.show
DF.createOrReplaceTempView("brent")

@@ -18,8 +18,6 @@ class ForeignExchangeTest extends SparkSpec with Matchers {
sparkTest("Foreign Exchange Trend Calculus") { spark =>
import spark.implicits._

import org.lamastex.spark.trendcalculus.DateUtils.Frequency
import org.lamastex.spark.trendcalculus.SeriesUtils.FillingStrategy
import org.lamastex.spark.trendcalculus._

val fxDF = Source