Third optimization batch — map fusion #95

Merged · 7 commits · Dec 14, 2017
@@ -0,0 +1,69 @@
/*
* Copyright 2017 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cats.effect.benchmarks

import java.util.concurrent.TimeUnit
import cats.effect.IO
import org.openjdk.jmh.annotations._

/** To run comparative benchmarks between versions:
*
* benchmarks/run-benchmark MapCallsBenchmark
*
* This will generate results in `benchmarks/results`.
*
* Or to run the benchmark from within SBT:
*
* jmh:run -i 10 -wi 10 -f 2 -t 1 cats.effect.benchmarks.MapCallsBenchmark
*
* Which means "10 iterations", "10 warm-up iterations", "2 forks", "1 thread".
* Please note that benchmarks should usually be executed with at least
* 10 iterations (as a rule of thumb), but more is better.
*/
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class MapCallsBenchmark {
import MapCallsBenchmark.test

@Benchmark
def one(): Long = test(12000, 1)

@Benchmark
def batch30(): Long = test(12000 / 30, 30)

@Benchmark
def batch120(): Long = test(12000 / 120, 120)
}

object MapCallsBenchmark {

def test(iterations: Int, batch: Int): Long = {
val f = (x: Int) => x + 1
var io = IO(0)

var j = 0
while (j < batch) { io = io.map(f); j += 1 }

var sum = 0L
var i = 0
while (i < iterations) {
sum += io.unsafeRunSync()
i += 1
}
sum
}
}
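
As a rough illustration of what the helper above measures (a sketch assuming the cats-effect API of this release: IO.apply, map, unsafeRunSync), test(400, 30) builds a single chain of 30 fused map operations and then executes it 400 times:

import cats.effect.IO

// Sketch: the chain that test(400, 30) constructs once and then runs repeatedly.
val chainOf30: IO[Int] =
  (1 to 30).foldLeft(IO(0))((io, _) => io.map(_ + 1))

// Each run yields 30, so 400 runs sum to 12000 -- the same total produced
// by the one/batch30/batch120 variants, which keeps them comparable.
val total: Long =
  (1 to 400).foldLeft(0L)((acc, _) => acc + chainOf30.unsafeRunSync())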
@@ -0,0 +1,87 @@
/*
* Copyright 2017 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cats.effect.benchmarks

import java.util.concurrent.TimeUnit
import cats.effect.IO
import org.openjdk.jmh.annotations._

/** To run comparative benchmarks between versions:
*
* benchmarks/run-benchmark MapStreamBenchmark
*
* This will generate results in `benchmarks/results`.
*
* Or to run the benchmark from within SBT:
*
* jmh:run -i 10 -wi 10 -f 2 -t 1 cats.effect.benchmarks.MapStreamBenchmark
*
* Which means "10 iterations", "10 warm-up iterations", "2 forks", "1 thread".
* Please note that benchmarks should usually be executed with at least
* 10 iterations (as a rule of thumb), but more is better.
*/
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class MapStreamBenchmark {
import MapStreamBenchmark.streamTest

@Benchmark
def one(): Long = streamTest(12000, 1)

@Benchmark
def batch30(): Long = streamTest(1000, 30)

@Benchmark
def batch120(): Long = streamTest(100, 120)
}

object MapStreamBenchmark {
def streamTest(times: Int, batchSize: Int): Long = {
var stream = range(0, times)
var i = 0
while (i < batchSize) {
stream = mapStream(addOne)(stream)
i += 1
}
sum(0)(stream).unsafeRunSync()
}

final case class Stream(value: Int, next: IO[Option[Stream]])
val addOne = (x: Int) => x + 1

def range(from: Int, until: Int): Option[Stream] =
if (from < until)
Some(Stream(from, IO(range(from + 1, until))))
else
None

def mapStream(f: Int => Int)(box: Option[Stream]): Option[Stream] =
box match {
case Some(Stream(value, next)) =>
Some(Stream(f(value), next.map(mapStream(f))))
case None =>
None
}

def sum(acc: Long)(box: Option[Stream]): IO[Long] =
box match {
case Some(Stream(value, next)) =>
next.flatMap(sum(acc + value))
case None =>
IO.pure(acc)
}
}
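
To make the helpers above concrete, here is a small hand trace, again a sketch assuming the same API: range(0, 3) produces the values 0, 1, 2; one mapStream(addOne) pass turns them into 1, 2, 3; and sum(0) folds them into 6.

import MapStreamBenchmark._

// Sketch only: a tiny stream, one map pass, then the fold.
val tiny = mapStream(addOne)(range(0, 3))
val result: Long = sum(0)(tiny).unsafeRunSync() // expected: 6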
9 changes: 9 additions & 0 deletions core/js/src/main/scala/cats/effect/internals/IOPlatform.scala
@@ -44,4 +44,13 @@ private[effect] object IOPlatform {
f(a)
}
}

/**
* Establishes the maximum stack depth for `IO#map` operations
* for JavaScript.
*
* The default for JavaScript is 32, from which we subtract 1
* as an optimization.
*/
private[effect] final val fusionMaxStackDepth = 31
}
33 changes: 32 additions & 1 deletion core/jvm/src/main/scala/cats/effect/internals/IOPlatform.scala
@@ -23,7 +23,7 @@ import cats.effect.IO

import scala.concurrent.blocking
import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.util.Either
import scala.util.{Either, Try}

private[effect] object IOPlatform {
/**
@@ -84,4 +84,35 @@ private[effect] object IOPlatform {
true
}
}

/**
* Establishes the maximum stack depth for `IO#map` operations.
*
* The default is `128`, from which we subtract one as an
* optimization. This default was reached as follows:
*
* - according to the official docs, the default stack size on 32-bit
* Windows and Linux was 320 KB, whereas for 64-bit it is 1024 KB
* - according to measurements, chaining `Function1` references uses
* approximately 32 bytes of stack space on a 64-bit system;
* this could be lower if "compressed oops" is activated
* - therefore a "map fusion" that goes 128 levels deep in the stack uses
* about 4 KB of stack space
*
* If this parameter becomes a problem, it can be tuned by setting
* the `cats.effect.fusionMaxStackDepth` system property when
* executing the Java VM:
*
* <pre>
* java -Dcats.effect.fusionMaxStackDepth=32 \
* ...
* </pre>
*/
private[effect] final val fusionMaxStackDepth =
Option(System.getProperty("cats.effect.fusionMaxStackDepth", ""))
.filter(s => s != null && s.nonEmpty)
.flatMap(s => Try(s.toInt).toOption)
Member: We don't want to depend on a logger, but is it worth it to explain on stderr why we choked?

Member Author: I'd prefer not to do it, since it introduces extra code, but I don't care much; if there's popular demand, then OK.

What I'm thinking is that people won't modify this parameter unless they are in big trouble, and we can have two issues:

  1. given my calculations, the default value seems fine, but we might underestimate stack growth in common usage
  2. we don't control all possible virtual machines; I have no idea, for example, what the default stack size is on Android or other non-Oracle JVMs

So increasing it won't improve performance except in very narrow use cases, and if people hit the stack limit because of this default, then we probably need to lower the limit in the library itself, with the overriding option made available only to empower people to fix it without waiting for another release.

Member: Okay, I'll buy that.

.filter(_ > 0)

Just confirming: with a value of 1, which is then reduced to 0, every operation gets flatMapped rather than going through Map?

Member Author: Yes, that's the intention, which made me realize that when the counter gets reset we should use a Map(this, f, 0) instead of a FlatMap(this, f.andThen(pure)).

.map(_ - 1)
.getOrElse(127)
}
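
For reference, a quick sketch of how the property-parsing chain above resolves for various inputs; the parse helper below is hypothetical and merely mirrors the logic of fusionMaxStackDepth, it is not part of the library:

import scala.util.Try

// Hypothetical mirror of the fusionMaxStackDepth parsing chain above.
def parse(raw: String): Int =
  Option(raw)
    .filter(_.nonEmpty)
    .flatMap(s => Try(s.toInt).toOption)
    .filter(_ > 0)
    .map(_ - 1)
    .getOrElse(127)

parse(null)  // 127 -- property not set, fall back to the documented default
parse("abc") // 127 -- not a number, ignored
parse("0")   // 127 -- non-positive values are rejected
parse("32")  // 31  -- one is subtracted, matching the "minus one" optimization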
28 changes: 23 additions & 5 deletions core/shared/src/main/scala/cats/effect/IO.scala
@@ -19,7 +19,7 @@ package effect

import cats.effect.internals.IOFrame.ErrorHandler
import cats.effect.internals.{IOFrame, IOPlatform, IORunLoop, NonFatal}

import cats.effect.internals.IOPlatform.fusionMaxStackDepth
import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.concurrent.duration._
@@ -91,9 +91,14 @@ sealed abstract class IO[+A] {
*/
final def map[B](f: A => B): IO[B] =
this match {
case Pure(a) => try Pure(f(a)) catch { case NonFatal(e) => RaiseError(e) }
case ref @ RaiseError(_) => ref
case _ => flatMap(a => Pure(f(a)))
case Map(source, g, index) =>
// Only a fixed number of map operations are fused before the counter
// is reset, in order to avoid stack overflows;
// see `IOPlatform` for details on this maximum.
if (index != fusionMaxStackDepth) Map(source, g.andThen(f), index + 1)
else Map(this, f, 0)
case _ =>
Map(this, f, 0)
}

/**
@@ -261,12 +266,14 @@
} else {
val lh = F.suspend(source.to[F]).asInstanceOf[F[A]]
F.handleErrorWith(lh) { e =>
m.asInstanceOf[ErrorHandler[IO[A]]].recover(e).to[F]
m.asInstanceOf[ErrorHandler[A]].recover(e).to[F]
}
}
case f =>
F.flatMap(F.suspend(source.to[F]))(e => f(e).to[F])
}
case Map(source, f, _) =>
F.map(source.to[F])(f.asInstanceOf[Any => A])
}

override def toString = this match {
@@ -572,6 +579,17 @@ object IO extends IOInstances {
private[effect] final case class Bind[E, +A](source: IO[E], f: E => IO[A])
extends IO[A]

/** Internal state for representing `map` ops. It is itself a function
* in order to avoid extraneous memory allocations when building the
* internal call-stack.
*/
private[effect] final case class Map[E, +A](source: IO[E], f: E => A, index: Int)
extends IO[A] with (E => IO[A]) {

override def apply(value: E): IO[A] =
IO.pure(f(value))
}

/** Internal reference, used as an optimization for [[IO.attempt]]
* in order to avoid extraneous memory allocations.
*/
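
Taken together, the Map node and its fusion counter mean that consecutive map calls compose their functions instead of allocating a bind per call. The internal representation is not observable from the outside, so the following is only a behavioral sketch, assuming the default depth limits described in IOPlatform (127 on the JVM, 31 on JS):

import cats.effect.IO

// Consecutive maps are fused into a single Map node until the counter hits
// fusionMaxStackDepth, at which point a fresh Map(this, f, 0) group starts.
val fused: IO[Int] = (1 to 1000).foldLeft(IO(0))((io, _) => io.map(_ + 1))
assert(fused.unsafeRunSync() == 1000)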
18 changes: 8 additions & 10 deletions core/shared/src/main/scala/cats/effect/internals/IOFrame.scala
@@ -16,6 +16,8 @@

package cats.effect.internals

import cats.effect.IO

/** A mapping function that is also able to handle errors,
* being the equivalent of:
*
@@ -43,20 +45,16 @@ private[effect] object IOFrame {
/** Builds a [[IOFrame]] instance that maps errors, but that isn't
* defined for successful values (a partial function)
*/
def errorHandler[R](fe: Throwable => R): IOFrame[Any, R] =
def errorHandler[A](fe: Throwable => IO[A]): IOFrame[A, IO[A]] =
new ErrorHandler(fe)

/** [[IOFrame]] reference that only handles errors, useful for
* quick filtering of `onErrorHandleWith` frames.
*/
final class ErrorHandler[+R](fe: Throwable => R)
extends IOFrame[Any, R] {

def recover(e: Throwable): R = fe(e)
def apply(a: Any): R = {
// $COVERAGE-OFF$
throw new NotImplementedError("IOFrame protocol breach")
// $COVERAGE-ON$
}
final class ErrorHandler[A](fe: Throwable => IO[A])
extends IOFrame[A, IO[A]] {

def recover(e: Throwable): IO[A] = fe(e)
def apply(a: A): IO[A] = IO.pure(a)
}
}
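
The reworked ErrorHandler now passes successful values straight through instead of throwing a protocol-breach error. Since IOFrame lives in a private package, the snippet below is only a standalone mirror of that contract, with hypothetical names:

import cats.effect.IO

// Hypothetical mirror of the ErrorHandler contract (not the real internals).
final class ErrorHandlerSketch[A](fe: Throwable => IO[A]) extends (A => IO[A]) {
  def recover(e: Throwable): IO[A] = fe(e) // errors are mapped through fe
  def apply(a: A): IO[A] = IO.pure(a)      // successes pass through unchanged
}

val handler = new ErrorHandlerSketch[Int](_ => IO.pure(-1))
// handler.recover(new RuntimeException).unsafeRunSync() == -1
// handler(42).unsafeRunSync() == 42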
19 changes: 18 additions & 1 deletion core/shared/src/main/scala/cats/effect/internals/IORunLoop.scala
@@ -17,7 +17,8 @@
package cats.effect.internals

import cats.effect.IO
import cats.effect.IO.{Async, Bind, Delay, Pure, RaiseError, Suspend}
import cats.effect.IO.{Async, Bind, Delay, Map, Pure, RaiseError, Suspend}

import scala.collection.mutable.ArrayStack

private[effect] object IORunLoop {
@@ -91,6 +92,14 @@ private[effect] object IORunLoop {
currentIO = fa
}

case bindNext @ Map(fa, _, _) =>
if (bFirst ne null) {
if (bRest eq null) bRest = new ArrayStack()
bRest.push(bFirst)
}
bFirst = bindNext.asInstanceOf[Bind]
currentIO = fa

case Async(register) =>
if (rcb eq null) rcb = RestartCallback(cb.asInstanceOf[Callback])
rcb.prepare(bFirst, bRest)
@@ -162,6 +171,14 @@ private[effect] object IORunLoop {
currentIO = fa
}

case bindNext @ Map(fa, _, _) =>
if (bFirst ne null) {
if (bRest eq null) bRest = new ArrayStack()
bRest.push(bFirst)
}
bFirst = bindNext.asInstanceOf[Bind]
currentIO = fa

case Async(register) =>
// Cannot inline the code of this method — as it would
// box those vars in scala.runtime.ObjectRef!
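
The asInstanceOf cast above works because Map is itself a function from its input to IO (see the apply it defines in IO.scala), so the run loop can push it onto the existing continuation stack without wrapping it in an extra lambda. A minimal standalone sketch of that pattern, with hypothetical names:

// Hypothetical sketch of the "node that is its own continuation" pattern.
sealed trait Op[+A]
final case class PureOp[+A](value: A) extends Op[A]
final case class MapOp[E, +A](source: Op[E], f: E => A)
  extends Op[A] with (E => Op[A]) {
  // Being a Function1 lets an interpreter push `this` directly onto its
  // continuation stack, avoiding one lambda allocation per map call.
  def apply(value: E): Op[A] = PureOp(f(value))
}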
@@ -36,6 +36,8 @@ object arbitrary {
1 -> genFail[A],
5 -> genAsync[A],
5 -> genNestedAsync[A],
5 -> getMapOne[A],
5 -> getMapTwo[A],
10 -> genFlatMap[A])
}

@@ -72,6 +74,19 @@ object arbitrary {
f <- getArbitrary[A => IO[A]]
} yield ioa.flatMap(f)

def getMapOne[A: Arbitrary: Cogen]: Gen[IO[A]] =
for {
ioa <- getArbitrary[IO[A]]
f <- getArbitrary[A => A]
} yield ioa.map(f)

def getMapTwo[A: Arbitrary: Cogen]: Gen[IO[A]] =
for {
ioa <- getArbitrary[IO[A]]
f1 <- getArbitrary[A => A]
f2 <- getArbitrary[A => A]
} yield ioa.map(f1).map(f2)

implicit def catsEffectLawsCogenForIO[A](implicit cgfa: Cogen[Future[A]]): Cogen[IO[A]] =
cgfa.contramap((ioa: IO[A]) => ioa.unsafeToFuture)
}
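
These generators bias the law-checking suite toward IO values that carry one or two map layers, so the fused Map representation gets exercised by every law. As a hedged illustration of the kind of property this enables (the property below is made up for illustration, not taken from the laws module):

import cats.effect.IO
import org.scalacheck.Prop.forAll

// Illustrative property: two stacked maps behave like their composition,
// even though they are now fused into a single Map node internally.
val mapFusionComposes = forAll { (n: Int, f1: Int => Int, f2: Int => Int) =>
  IO.pure(n).map(f1).map(f2).unsafeRunSync() == f2(f1(n))
}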