Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compiler): Optimize math in compile time [LNG-245] #922

Merged
merged 23 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,12 @@ package aqua.model.inline
import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.inline.Inline.MergeMode.*
import aqua.model.*
import aqua.model.inline.raw.{
ApplyBinaryOpRawInliner,
ApplyFunctorRawInliner,
ApplyGateRawInliner,
ApplyPropertiesRawInliner,
ApplyUnaryOpRawInliner,
CallArrowRawInliner,
CollectionRawInliner,
MakeAbilityRawInliner
}
import aqua.model.inline.raw.*
import aqua.raw.ops.*
import aqua.raw.value.*
import aqua.types.{ArrayType, LiteralType, OptionType, StreamType}

import cats.Eval
import cats.syntax.traverse.*
import cats.syntax.monoid.*
import cats.syntax.functor.*
Expand Down Expand Up @@ -68,6 +60,10 @@ object RawValueInliner extends Logging {

case cr: CallArrowRaw =>
CallArrowRawInliner(cr, propertiesAllowed)

case cs: CallServiceRaw =>
CallServiceRawInliner(cs, propertiesAllowed)

}

private[inline] def inlineToTree[S: Mangler: Exports: Arrows](
Expand Down Expand Up @@ -104,10 +100,12 @@ object RawValueInliner extends Logging {
def valueToModel[S: Mangler: Exports: Arrows](
value: ValueRaw,
propertiesAllowed: Boolean = true
): State[S, (ValueModel, Option[OpModel.Tree])] = {
logger.trace("RAW " + value)
toModel(unfold(value, propertiesAllowed))
}
): State[S, (ValueModel, Option[OpModel.Tree])] = for {
_ <- StateT.liftF(Eval.later(logger.trace("RAW " + value)))
optimized <- StateT.liftF(Optimization.optimize(value))
_ <- StateT.liftF(Eval.later(logger.trace("OPTIMIZED " + optimized)))
model <- toModel(unfold(optimized, propertiesAllowed))
} yield model

def valueListToModel[S: Mangler: Exports: Arrows](
values: List[ValueRaw]
Expand Down
11 changes: 8 additions & 3 deletions model/inline/src/main/scala/aqua/model/inline/TagInliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import aqua.errors.Errors.internalError
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.*
import aqua.model.inline.RawValueInliner.collectionToModel
import aqua.model.inline.raw.CallArrowRawInliner
import aqua.model.inline.raw.{CallArrowRawInliner, CallServiceRawInliner}
import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp
import aqua.raw.ops.*
import aqua.raw.value.*
Expand Down Expand Up @@ -308,8 +308,13 @@ object TagInliner extends Logging {
TagInlined.Empty(prefix = parDesugarPrefix(nel.toList.flatMap(_._2)))
})

case CallArrowRawTag(exportTo, value: CallArrowRaw) =>
CallArrowRawInliner.unfoldArrow(value, exportTo).flatMap { case (_, inline) =>
case CallArrowRawTag(exportTo, value: (CallArrowRaw | CallServiceRaw)) =>
(value match {
case ca: CallArrowRaw =>
CallArrowRawInliner.unfold(ca, exportTo)
case cs: CallServiceRaw =>
CallServiceRawInliner.unfold(cs, exportTo)
}).flatMap { case (_, inline) =>
RawValueInliner
.inlineToTree(inline)
.map(tree =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package aqua.model.inline.raw

import aqua.errors.Errors.internalError
import aqua.model.*
import aqua.model.inline.raw.RawInliner
import aqua.model.inline.TagInliner
Expand All @@ -8,8 +9,9 @@ import aqua.raw.value.{AbilityRaw, LiteralRaw, MakeStructRaw}
import cats.data.{NonEmptyList, NonEmptyMap, State}
import aqua.model.inline.Inline
import aqua.model.inline.RawValueInliner.{unfold, valueToModel}
import aqua.types.{ArrowType, ScalarType}
import aqua.types.{ArrowType, ScalarType, Type}
import aqua.raw.value.ApplyBinaryOpRaw
import aqua.raw.value.ApplyBinaryOpRaw.Op
import aqua.raw.value.ApplyBinaryOpRaw.Op.*
import aqua.model.inline.Inline.MergeMode

Expand All @@ -24,9 +26,6 @@ import cats.syntax.applicative.*

object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {

private type BoolOp = And.type | Or.type
private type EqOp = Eq.type | Neq.type

override def apply[S: Mangler: Exports: Arrows](
raw: ApplyBinaryOpRaw,
propertiesAllowed: Boolean
Expand All @@ -37,16 +36,49 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
(rmodel, rinline) = right

result <- raw.op match {
case op @ (And | Or) => inlineBoolOp(lmodel, rmodel, linline, rinline, op)
case op @ (Eq | Neq) =>
case op: Op.Bool =>
inlineBoolOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
case op: Op.Eq =>
for {
// Canonicalize stream operands before comparison
leftStream <- TagInliner.canonicalizeIfStream(lmodel)
(lmodelStream, linlineStream) = leftStream.map(linline.append)
rightStream <- TagInliner.canonicalizeIfStream(rmodel)
(rmodelStream, rinlineStream) = rightStream.map(rinline.append)
result <- inlineEqOp(lmodelStream, rmodelStream, linlineStream, rinlineStream, op)
result <- inlineEqOp(
lmodelStream,
rmodelStream,
linlineStream,
rinlineStream,
op,
raw.baseType
)
} yield result
case op: Op.Cmp =>
inlineCmpOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
case op: Op.Math =>
inlineMathOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
}
} yield result

Expand All @@ -55,7 +87,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: EqOp
op: Op.Eq,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
// Optimize in case compared values are literals
// Semantics should check that types are comparable
Expand All @@ -69,15 +102,16 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
},
linline.mergeWith(rinline, MergeMode.ParMode)
).pure[State[S, *]]
case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op)
case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op, resType)
}

private def fullInlineEqOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: EqOp
op: Op.Eq,
resType: Type
): State[S, (ValueModel, Inline)] = {
val (name, shouldMatch) = op match {
case Eq => ("eq", true)
Expand Down Expand Up @@ -114,15 +148,16 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
)
)

result(name, predo)
result(name, resType, predo)
}

private def inlineBoolOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: BoolOp
op: Op.Bool,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
// Optimize in case of left value is known at compile time
case (LiteralModel.Bool(lvalue), _) =>
Expand All @@ -139,15 +174,16 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
case _ => (lmodel, linline)
}).pure[State[S, *]]
// Produce unoptimized inline
case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op)
case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op, resType)
}

private def fullInlineBoolOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: BoolOp
op: Op.Bool,
resType: Type
): State[S, (ValueModel, Inline)] = {
val (name, compareWith) = op match {
case And => ("and", false)
Expand Down Expand Up @@ -190,19 +226,140 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
)
)

result(name, predo)
result(name, resType, predo)
}

private def inlineCmpOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: Op.Cmp,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
case (LiteralModel.Integer(lv), LiteralModel.Integer(rv)) =>
val res = op match {
case Lt => lv < rv
case Lte => lv <= rv
case Gt => lv > rv
case Gte => lv >= rv
}

(
LiteralModel.bool(res),
Inline(linline.predo ++ rinline.predo)
).pure
case _ =>
val fn = op match {
case Lt => "lt"
case Lte => "lte"
case Gt => "gt"
case Gte => "gte"
}

val predo = (resName: String) =>
SeqModel.wrap(
linline.predo ++ rinline.predo :+ CallServiceModel(
serviceId = LiteralModel.quote("cmp"),
funcName = fn,
call = CallModel(
args = lmodel :: rmodel :: Nil,
exportTo = CallModel.Export(resName, resType) :: Nil
)
).leaf
)

result(fn, resType, predo)
}

private def inlineMathOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: Op.Math,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
case (
LiteralModel.Integer(lv),
LiteralModel.Integer(rv)
) if !mathExceptionalCase(lv, rv, op) =>
val res = op match {
case Add => lv + rv
case Sub => lv - rv
case Mul => lv * rv
case Div => lv / rv
case Rem => lv % rv
case Pow => intPow(lv, rv)
case _ => internalError(s"Unsupported operation $op for $lv and $rv")
}

(
LiteralModel.number(res),
Inline(linline.predo ++ rinline.predo)
).pure
case _ =>
val fn = op match {
case Add => "add"
case Sub => "sub"
case Mul => "mul"
case FMul => "fmul"
case Div => "div"
case Rem => "rem"
case Pow => "pow"
}

val predo = (resName: String) =>
SeqModel.wrap(
linline.predo ++ rinline.predo :+ CallServiceModel(
serviceId = LiteralModel.quote("math"),
funcName = fn,
call = CallModel(
args = lmodel :: rmodel :: Nil,
exportTo = CallModel.Export(resName, resType) :: Nil
)
).leaf
)

result(fn, resType, predo)
}

private def result[S: Mangler](
name: String,
resType: Type,
predo: String => OpModel.Tree
): State[S, (ValueModel, Inline)] =
Mangler[S]
.findAndForbidName(name)
.map(resName =>
(
VarModel(resName, ScalarType.bool),
VarModel(resName, resType),
Inline(Chain.one(predo(resName)))
)
)

private def mathExceptionalCase(
InversionSpaces marked this conversation as resolved.
Show resolved Hide resolved
InversionSpaces marked this conversation as resolved.
Show resolved Hide resolved
left: Long,
right: Long,
op: Op.Math
): Boolean = op match {
case Op.Div | Op.Rem => right == 0
case Op.Pow => right < 0
case _ => false
}

/**
* Integer power
InversionSpaces marked this conversation as resolved.
Show resolved Hide resolved
*
* @param base
* @param exp >= 0
* @return base ** exp
*/
private def intPow(base: Long, exp: Long): Long = {
def intPowTailRec(base: Long, exp: Long, acc: Long): Long =
if (exp <= 0) acc
else intPowTailRec(base * base, exp / 2, if (exp % 2 == 0) acc else acc * base)

intPowTailRec(base, exp, 1)
}
}
Loading