Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/scalable trend calculus #7

Merged
merged 6 commits into from
Jul 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions src/main/scala/org/lamastex/spark/trendcalculus/DateUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,4 @@ object DateUtils {
}
}

/**
 * Associates a [[Frequency]] value with a number of milliseconds.
 *
 * @param frequency    the frequency being described
 * @param milliseconds millisecond count associated with that frequency
 */
case class FrequencyMillisecond(
frequency: Frequency.Value,
milliseconds: Long
)

/**
 * Holds a quarter and a half-year index.
 *
 * NOTE(review): presumably the quarter / half of the year a given month falls in;
 * the valid ranges are not visible here - confirm against the code that builds it.
 *
 * @param quarter quarter index
 * @param half    half-year index
 */
case class MonthYear(
quarter: Int,
half: Int
)

/**
 * Enumeration of time-series frequencies, declared from finest (MILLI_SECOND) to coarsest (YEAR),
 * preceded by a placeholder for an undetermined frequency.
 *
 * Note: the placeholder is spelled `UNKWOWN` (sic); the name is part of the public API,
 * so it must be referenced with that exact spelling.
 */
object Frequency extends Enumeration with Serializable {
val UNKWOWN, MILLI_SECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, HALF_YEAR, YEAR = Value
}


}
5 changes: 5 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/Point.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ case class Point(
x: Long,
y: Double
)

/**
 * A single observation of a time series whose time coordinate is a SQL timestamp
 * (in contrast to `Point`, whose `x` is a `Long`).
 *
 * @param x time of the observation
 * @param y observed value
 */
case class TimePoint(
x: java.sql.Timestamp,
y: Double
)
10 changes: 0 additions & 10 deletions src/main/scala/org/lamastex/spark/trendcalculus/SeriesUtils.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.lamastex.spark.trendcalculus

import org.lamastex.spark.trendcalculus.DateUtils.Frequency

object SeriesUtils {

def movingAverage(timeseries: Array[Point], grouping: Frequency.Value): Unit = {
Expand Down Expand Up @@ -127,12 +125,4 @@ object SeriesUtils {
.sorted(Ordering.by((p: Point) => p.x))
}

/**
 * Enumeration of the available filling strategies: MEAN, LOCF, LINEAR and ZERO.
 *
 * NOTE(review): presumably these select how missing points of a series are filled
 * (LOCF likely meaning "last observation carried forward") - confirm against the code that consumes them.
 */
object FillingStrategy extends Enumeration with Serializable {
val MEAN, LOCF, LINEAR, ZERO = Value
}

/**
 * Enumeration of the available aggregation strategies: MEAN and SUM.
 *
 * NOTE(review): presumably these select how several points are combined into one
 * when a series is grouped - confirm against the code that consumes them.
 */
object AggregateStrategy extends Enumeration with Serializable {
val MEAN, SUM = Value
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@

package org.lamastex.spark.trendcalculus

import org.lamastex.spark.trendcalculus.DateUtils.Frequency

import scala.util.Try

class TrendCalculus(timeseries: Array[Point], groupingFrequency: Frequency.Value) extends Serializable {

val map: Map[DateUtils.Frequency.Value, Long] = DateUtils.frequencyMillisecond.map(t => (t.frequency, t.milliseconds)).toMap
val samplingFrequency: DateUtils.Frequency.Value = DateUtils.findFrequency(timeseries.map(_.x))
val map: Map[Frequency.Value, Long] = DateUtils.frequencyMillisecond.map(t => (t.frequency, t.milliseconds)).toMap
val samplingFrequency: Frequency.Value = DateUtils.findFrequency(timeseries.map(_.x))
val groupingMilliseconds: Long = map(groupingFrequency)
val samplingMilliseconds: Long = map(samplingFrequency)

Expand Down Expand Up @@ -49,7 +47,7 @@ class TrendCalculus(timeseries: Array[Point], groupingFrequency: Frequency.Value
.sortBy(_._1)

val high = sortedWSeries.last._2.head //earliest high price
val low = sortedWSeries.head._2.last //latest low price
val low = sortedWSeries.head._2.last //latest low price

val List(left, right) = List(high, low).sorted(Ordering.by((p: Point) => p.x))
val leftSeries = wSeries.filter(_.x < left.x)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.lamastex.spark.trendcalculus

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

import org.apache.spark.sql.expressions.{Window, WindowSpec}
import java.sql.Timestamp

/**
 * Finds trend reversals ("Top" / "Bottom") in a time series held in a Spark Dataset.
 *
 * Every [[TimePoint]] is converted to a [[Point]] (timestamp as epoch milliseconds) and fed through
 * the [[TsToTrend]] aggregation over a running window; the records flagged as reversals are then
 * converted back to [[TimePoint]]s.
 *
 * @param timeseries input points; they are consumed in the order Spark delivers them (no sort is applied here)
 * @param windowSize number of points per window, handed to [[TsToTrend]]
 * @param spark      session providing the implicits/encoders and the empty result
 * @param initZero   when true, [[TsToTrend]] starts from its default zero point; when false, the first
 *                   point of `timeseries` seeds the aggregation and is removed from the processed input
 */
class TrendCalculus2(timeseries: Dataset[TimePoint], windowSize: Int, spark: SparkSession, initZero: Boolean = true) {

  import spark.implicits._

  // Running frame: from the first row up to and including the current row.
  // No partitionBy/orderBy is given, so the aggregation sees the whole input as one sequence.
  val windowSpec: WindowSpec = Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)

  /**
   * Computes the reversal points of the series.
   *
   * @return one row per reversal: the point at which it happened and "Top" or "Bottom";
   *         an empty Dataset when `initZero` is false and the input has no rows
   */
  def getReversals: Dataset[(TimePoint, String)] =
    if (initZero) {
      extractReversals(withTrends(timeseries, new TsToTrend(windowSize)))
    } else {
      // take(1) returns an empty array on empty input, so no exception handling is needed.
      timeseries.take(1).headOption match {
        case Some(initPoint) =>
          val init = Some(Row(Point(initPoint.x.getTime, initPoint.y)))
          extractReversals(withTrends(withoutFirst(timeseries), new TsToTrend(windowSize, init)))
        case None =>
          spark.emptyDataset[(TimePoint, String)].toDF("reversalPoint", "reversal").as[(TimePoint, String)]
      }
    }

  // Removes the first element of the Dataset. The element is located by its global index rather than
  // by "first element of partition 0", so it is still removed when the leading partition(s) are empty.
  private def withoutFirst(points: Dataset[TimePoint]): Dataset[TimePoint] =
    points.rdd
      .zipWithIndex()
      .filter(_._2 != 0L)
      .map(_._1)
      .toDS

  // Converts the points to the (x: Long, y: Double) form and attaches the running aggregation result as "fhls".
  private def withTrends(points: Dataset[TimePoint], toTrend: TsToTrend): DataFrame =
    points
      .map(r => (Point(r.x.getTime(), r.y), "dummy"))
      .toDF("point", "dummy")
      .withColumn("fhls", toTrend($"point").over(windowSpec))

  // Flattens the aggregation output, keeps the records flagged as reversals and maps them back to TimePoints.
  private def extractReversals(trends: DataFrame): Dataset[(TimePoint, String)] =
    trends
      .select(explode($"fhls") as "tmp")
      .select($"tmp.fhls".as("fhls"), $"tmp.trend".as("trend"), $"tmp.lastTrend".as("lastTrend"), $"tmp.lastFhls".as("lastFHLS"), $"tmp.reversal".as("reversal"))
      .filter($"reversal" =!= 0)
      .select($"lastFHLS", $"reversal" as "reversalInt")
      // A reversal of -1 is reported at the high of the previous window ("Top"), anything else at its low ("Bottom").
      .withColumn("reversalPoint", when($"reversalInt" === -1, $"lastFHLS.high").otherwise($"lastFHLS.low"))
      .withColumn("reversal", when($"reversalInt" === -1, lit("Top")).otherwise(lit("Bottom")))
      .select($"reversalPoint", $"reversal")
      // x == 0 is the placeholder used for the zero-initialised state; note that a genuine
      // observation at epoch 0 is dropped by this filter as well.
      .filter($"reversalPoint.x" =!= 0L)
      .map(r => (TimePoint(new Timestamp(r.getStruct(0).getLong(0)), r.getStruct(0).getDouble(1)), r.getString(1)))
      .select($"_1" as "reversalPoint", $"_2" as "reversal")
      .as[(TimePoint, String)]
}
179 changes: 179 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/TsToTrend.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package org.lamastex.spark.trendcalculus

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/**
 * Aggregation function that turns a sequence of (x, y) points into trend records.
 *
 * Points are collected into consecutive windows of `windowSize` points. Each time a window is
 * completed it is summarised as an FHLS row (first, high, low, second, sign) and compared with the
 * FHLS of the previous window to obtain a trend (+1 / -1) and a reversal flag.
 *
 * The output is an array of records `(fhls, trend, lastFhls, lastTrend, reversal)`; it is non-empty
 * only when the row being evaluated completes a window.
 *
 * @param windowSize number of points per window; must be positive
 * @param initLine   optional row whose first field is the [[Point]] used as the initial "previous window";
 *                   defaults to the point (0, 0.0)
 */
class TsToTrend(windowSize: Int, initLine: Option[Row] = None) extends UserDefinedAggregateFunction {

  // A non-positive window would otherwise only fail later, inside update (modulo by zero / empty window).
  require(windowSize > 0, s"windowSize must be positive, got $windowSize")

  // Field positions inside an FHLS row.
  val FIRST = 0
  val HIGH = 1
  val LOW = 2
  val SECOND = 3

  // Field positions inside a point row.
  val X = 0
  val Y = 1

  val pointSchema = StructType(
    StructField("x", LongType, false) ::
    StructField("y", DoubleType, false) ::
    Nil
  )

  val fhlsSchema = StructType(
    StructField("first", pointSchema, false) ::
    StructField("high", pointSchema, false) ::
    StructField("low", pointSchema, false) ::
    StructField("second", pointSchema, false) ::
    StructField("sign", IntegerType, false) ::
    Nil
  )

  // Placeholder point and FHLS used before any real data is available.
  val emptyPoint = Row(0L, 0.0)

  val emptyFhls = Row(
    emptyPoint,
    emptyPoint,
    emptyPoint,
    emptyPoint,
    0
  )

  // Input: a single struct column holding the point (x, y).
  override def inputSchema: org.apache.spark.sql.types.StructType = StructType(
    StructField("point", pointSchema) ::
    Nil
  )

  // Buffer: the most recent points (newest first), the position inside the current window,
  // and the records produced by the last completed window.
  override def bufferSchema: StructType = StructType(
    StructField("bufferedPoints", ArrayType(pointSchema)) ::
    StructField("counter", IntegerType) ::
    StructField("result", dataType) ::
    Nil
  )

  // Output: array of (fhls, trend, lastFhls, lastTrend, reversal) records.
  override def dataType: DataType = ArrayType(
    StructType(
      StructField("fhls", fhlsSchema) ::
      StructField("trend", IntegerType) ::
      StructField("lastFhls", fhlsSchema) ::
      StructField("lastTrend", IntegerType) ::
      StructField("reversal", IntegerType) ::
      Nil
    )
  )

  override def deterministic: Boolean = true

  // Starts with placeholder points, a counter of 0 and a single result record whose FHLS
  // consists of the initial point only and whose trend is +1.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    val init = initLine.getOrElse(Row(Point(0L, 0.0))).getAs[Point](0)
    val initPoint = Row(init.x, init.y)
    val lastFhls = Row(initPoint, initPoint, initPoint, initPoint, 0)
    buffer(0) = Seq.fill(2 * windowSize)(emptyPoint)
    buffer(1) = 0
    buffer(2) = Seq[Row](Row(lastFhls, 1, emptyFhls, 0, 0))
  }

  // Adds one point; when it completes a window, recomputes the result from that window
  // and the last record of the previous result.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val newPoint = input.getStruct(0)
    val newSeq = newPoint +: buffer.getSeq[Row](0) // newest point first
    val newCounter = (buffer.getInt(1) + 1) % windowSize
    buffer(0) = newSeq.take(2 * windowSize)
    buffer(1) = newCounter
    if (newCounter == 0) {
      val prevRes = buffer.getSeq[Row](2).last

      val currFHLS = makeFHLS(newSeq.take(windowSize))
      val prevFHLS = prevRes.getStruct(0)
      val lastTrend = prevRes.getInt(1)
      buffer(2) = getTrend(prevFHLS, currFHLS, lastTrend)
    }
  }

  // Deliberately a no-op: two partial buffers are never combined, buffer1 is left untouched.
  // NOTE(review): the function therefore only gives meaningful results where rows reach a single
  // buffer through update, as with the running window frame used by TrendCalculus2 - confirm
  // before using it in a grouped aggregation.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  }

  // Emits the stored records only for a row that completes a window (counter back at 0),
  // and an empty array for every other row.
  override def evaluate(buffer: Row): Any = {
    val counter = buffer.getInt(1)

    if (counter == 0) {
      buffer.getSeq[Row](2)
    } else {
      Seq.empty[Row]
    }
  }

  // Summarises a window as (first, high, low, second, sign): `high` is the earliest point with the
  // maximum value, `low` the latest point with the minimum value, first/second are those two in
  // time order and sign tells whether the value rises (+1) or not (-1) from first to second.
  private def makeFHLS(points: Seq[Row]): Row = {
    val sortedByVal = points
      .groupBy(_.getDouble(Y))
      .map { case (v, vSeq) => (v, vSeq.sortBy(_.getLong(X))) }
      .toSeq
      .sortBy(_._1)
    val high = sortedByVal.last._2.head // earliest high
    val low = sortedByVal.head._2.last // latest low
    val (first, second) = if (low.getLong(X) < high.getLong(X)) (low, high) else (high, low)
    val sign = if (first.getDouble(Y) < second.getDouble(Y)) 1 else -1

    Row(first, high, low, second, sign)
  }

  // Sign of the combined movement of the high and the low between two windows:
  // +1 when they rise overall, -1 when they fall overall, 0 when they disagree or do not move.
  private def trendSign(fromHigh: Row, fromLow: Row, toHigh: Row, toLow: Row): Int =
    ((toHigh.getDouble(Y) - fromHigh.getDouble(Y)).signum +
      (toLow.getDouble(Y) - fromLow.getDouble(Y)).signum).signum

  // Compares the previous and current window. When their trend is unambiguous a single record is
  // returned. Otherwise an intermediate FHLS is built from the previous window's second point and
  // the current window's first point, and two records (previous -> intermediate, intermediate ->
  // current) are returned; a trend that is still 0 inherits the preceding trend.
  private def getTrend(prevFHLS: Row, currFHLS: Row, lastTrend: Int): Seq[Row] = {
    val prevHigh = prevFHLS.getStruct(HIGH)
    val currHigh = currFHLS.getStruct(HIGH)
    val prevLow = prevFHLS.getStruct(LOW)
    val currLow = currFHLS.getStruct(LOW)

    val directTrend = trendSign(prevHigh, prevLow, currHigh, currLow)

    if (directTrend != 0) {
      val directRev = (directTrend - lastTrend).signum
      Seq[Row](Row(currFHLS, directTrend, prevFHLS, lastTrend, directRev))
    } else {
      val intrFirst = prevFHLS.getStruct(SECOND)
      val intrSecond = currFHLS.getStruct(FIRST)
      val rising = intrFirst.getDouble(Y) < intrSecond.getDouble(Y)
      val (intrLow, intrHigh) = if (rising) (intrFirst, intrSecond) else (intrSecond, intrFirst)
      val intrSign = if (rising) 1 else -1

      val rawIntrTrend = trendSign(prevHigh, prevLow, intrHigh, intrLow)
      val rawCurrTrend = trendSign(intrHigh, intrLow, currHigh, currLow)

      val intrTrend = if (rawIntrTrend == 0) lastTrend else rawIntrTrend
      val currTrend = if (rawCurrTrend == 0) intrTrend else rawCurrTrend

      val intrFHLS = Row(
        intrFirst,
        intrHigh,
        intrLow,
        intrSecond,
        intrSign
      )

      val intrRev = (intrTrend - lastTrend).signum
      val currRev = (currTrend - intrTrend).signum

      Seq[Row](
        Row(intrFHLS, intrTrend, prevFHLS, lastTrend, intrRev),
        Row(currFHLS, currTrend, intrFHLS, intrTrend, currRev)
      )
    }
  }
}
22 changes: 22 additions & 0 deletions src/main/scala/org/lamastex/spark/trendcalculus/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,26 @@ package object trendcalculus {
yfin(Seq(inputPath): _*)
}
}

/**
 * Associates a [[Frequency]] value with a number of milliseconds.
 *
 * @param frequency    the frequency being described
 * @param milliseconds millisecond count associated with that frequency
 */
case class FrequencyMillisecond(
frequency: Frequency.Value,
milliseconds: Long
)

/**
 * Holds a quarter and a half-year index.
 *
 * NOTE(review): presumably the quarter / half of the year a given month falls in;
 * the valid ranges are not visible here - confirm against the code that builds it.
 *
 * @param quarter quarter index
 * @param half    half-year index
 */
case class MonthYear(
quarter: Int,
half: Int
)

/**
 * Enumeration of the available filling strategies: MEAN, LOCF, LINEAR and ZERO.
 *
 * NOTE(review): presumably these select how missing points of a series are filled
 * (LOCF likely meaning "last observation carried forward") - confirm against the code that consumes them.
 */
object FillingStrategy extends Enumeration with Serializable {
val MEAN, LOCF, LINEAR, ZERO = Value
}

/**
 * Enumeration of the available aggregation strategies: MEAN and SUM.
 *
 * NOTE(review): presumably these select how several points are combined into one
 * when a series is grouped - confirm against the code that consumes them.
 */
object AggregateStrategy extends Enumeration with Serializable {
val MEAN, SUM = Value
}

/**
 * Enumeration of time-series frequencies, declared from finest (MILLI_SECOND) to coarsest (YEAR),
 * preceded by a placeholder for an undetermined frequency.
 *
 * Note: the placeholder is spelled `UNKWOWN` (sic); the name is part of the public API,
 * so it must be referenced with that exact spelling.
 */
object Frequency extends Enumeration with Serializable {
val UNKWOWN, MILLI_SECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, HALF_YEAR, YEAR = Value
}
}
8 changes: 1 addition & 7 deletions src/test/scala/org/lamastex/spark/trendcalculus/Brent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,14 @@ class BrentTest extends SparkSpec with Matchers {
import org.apache.spark.sql.functions._
import spark.implicits._

import org.lamastex.spark.trendcalculus.DateUtils.Frequency
import org.lamastex.spark.trendcalculus.SeriesUtils.FillingStrategy
import org.lamastex.spark.trendcalculus._

/* val spark = SparkSession.builder().appName("gdelt-harness").getOrCreate()
val sqlContext = spark.sqlContext */

val dateUDF = udf((s: String) => new java.sql.Date(new java.text.SimpleDateFormat("yyyy-MM-dd").parse(s).getTime))
val valueUDF = udf((s: String) => s.toDouble)

// Only keep brent with data we know could match ours
val filePathRoot: String = "file:///root/GIT/lamastex/spark-trend-calculus/src/test/resources/org/lamastex/spark/trendcalculus/"
val filePathRoot: String = "src/test/resources/org/lamastex/spark/trendcalculus/"
val DF = spark.read.option("header", "true").option("inferSchema", "true").csv(filePathRoot+"brent.csv").filter(year(col("DATE")) >= 2015)
// val DF = Source.fromInputStream(this.getClass.getResourceAsStream("brent.csv"), "UTF-8").filter(year(col("DATE")) >= 2015)
DF.show
DF.createOrReplaceTempView("brent")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ class ForeignExchangeTest extends SparkSpec with Matchers {
sparkTest("Foreign Exchange Trend Calculus") { spark =>
import spark.implicits._

import org.lamastex.spark.trendcalculus.DateUtils.Frequency
import org.lamastex.spark.trendcalculus.SeriesUtils.FillingStrategy
import org.lamastex.spark.trendcalculus._

val fxDF = Source
Expand Down
Loading