Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compiler): Always generate last argument of fold [LNG-265] #947

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/air/src/main/scala/aqua/backend/air/Air.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object Air {
iterable: DataView,
label: String,
instruction: Air,
lastNextInstruction: Option[Air]
lastNextInstruction: Air
) extends Air(Keyword.Fold)

case class Match(left: DataView, right: DataView, instruction: Air) extends Air(Keyword.Match)
Expand Down Expand Up @@ -137,7 +137,7 @@ object Air {
case Air.Next(label) ⇒ s" $label"
case Air.New(item, inst) ⇒ s" ${item.show}\n${showNext(inst)}$space"
case Air.Fold(iter, label, inst, lastInst) ⇒
val l = lastInst.map(a => show(depth + 1, a)).getOrElse("")
val l = show(depth + 1, lastInst)
s" ${iter.show} $label\n${showNext(inst)}$l$space"
case Air.Match(left, right, inst) ⇒
s" ${left.show} ${right.show}\n${showNext(inst)}$space"
Expand Down
11 changes: 5 additions & 6 deletions backend/air/src/main/scala/aqua/backend/air/AirGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ object AirGen extends Logging {
)

case FoldRes(item, iterable, mode) =>
val m = mode.map {
case ForModel.Mode.Null => NullGen
case ForModel.Mode.Never => NeverGen
val m = mode match {
case FoldRes.Mode.Null => NullGen
case FoldRes.Mode.Never => NeverGen
}
Eval later ForGen(valueToData(iterable), item, opsToSingle(ops), m)
case RestrictionRes(item, itemType) =>
Expand Down Expand Up @@ -202,9 +202,8 @@ case class MatchMismatchGen(
else Air.Mismatch(left, right, body.generate)
}

case class ForGen(iterable: DataView, item: String, body: AirGen, mode: Option[AirGen])
extends AirGen {
override def generate: Air = Air.Fold(iterable, item, body.generate, mode.map(_.generate))
case class ForGen(iterable: DataView, item: String, body: AirGen, mode: AirGen) extends AirGen {
override def generate: Air = Air.Fold(iterable, item, body.generate, mode.generate)
}

case class NewGen(name: String, body: AirGen) extends AirGen {
Expand Down
54 changes: 28 additions & 26 deletions compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,34 +169,36 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers with Inside {
RestrictionRes(results.name, resultsType).wrap(
SeqRes.wrap(
ParRes.wrap(
FoldRes(peer.name, peers, ForModel.Mode.Never.some).wrap(
ParRes.wrap(
XorRes.wrap(
// better if first relay will be outside `for`
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
CallServiceRes(
LiteralModel.fromRaw(LiteralRaw.quote("op")),
"identity",
CallRes(
LiteralModel.fromRaw(LiteralRaw.quote("hahahahah")) :: Nil,
Some(CallModel.Export(retVar.name, retVar.`type`))
),
peer
).leaf,
ApRes(retVar, CallModel.Export(results.name, results.`type`)).leaf,
through(ValueModel.fromRaw(relay)),
through(initPeer)
FoldRes
.lastNever(peer.name, peers)
.wrap(
ParRes.wrap(
XorRes.wrap(
// better if first relay will be outside `for`
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
CallServiceRes(
LiteralModel.fromRaw(LiteralRaw.quote("op")),
"identity",
CallRes(
LiteralModel.fromRaw(LiteralRaw.quote("hahahahah")) :: Nil,
Some(CallModel.Export(retVar.name, retVar.`type`))
),
peer
).leaf,
ApRes(retVar, CallModel.Export(results.name, results.`type`)).leaf,
through(ValueModel.fromRaw(relay)),
through(initPeer)
),
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
through(initPeer),
failErrorRes
)
),
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
through(initPeer),
failErrorRes
)
),
NextRes(peer.name).leaf
NextRes(peer.name).leaf
)
)
)
),
join(results, LiteralModel.number(3)), // Compiler optimized addition
CanonRes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ object TagInliner extends Logging {
)
}
_ <- Exports[S].resolved(item, VarModel(n, elementType))
m = mode.map {
case ForTag.Mode.Wait => ForModel.Mode.Never
case ForTag.Mode.Pass => ForModel.Mode.Null
modeModel = mode match {
case ForTag.Mode.Blocking => ForModel.Mode.Never
case ForTag.Mode.NonBlocking => ForModel.Mode.Null
}
} yield TagInlined.Single(
model = ForModel(n, v, m),
model = ForModel(n, v, modeModel),
prefix = p
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ object StreamGateInliner extends Logging {
val resultCanon = VarModel(canonName, CanonStreamType(streamType.element))

RestrictionModel(varSTest.name, streamType).wrap(
ForModel(iter.name, VarModel(streamName, streamType), ForModel.Mode.Never.some).wrap(
ForModel(iter.name, VarModel(streamName, streamType), ForModel.Mode.Never).wrap(
PushToStreamModel(
iter,
CallModel.Export(varSTest.name, varSTest.`type`)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2064,8 +2064,12 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.leaf
)

val foldOp =
ForTag(iVar.name, array, ForTag.Mode.Wait.some).wrap(inFold, NextTag(iVar.name).leaf)
val foldOp = ForTag
.blocking(iVar.name, array)
.wrap(
inFold,
NextTag(iVar.name).leaf
)

val model: OpModel.Tree = ArrowInliner
.callArrow[InliningState](
Expand All @@ -2091,14 +2095,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
._2

model.equalsOrShowDiff(
ForModel(iVar0.name, ValueModel.fromRaw(array), ForModel.Mode.Never.some).wrap(
CallServiceModel(
LiteralModel.fromRaw(serviceId),
fnName,
CallModel(LiteralModel.number(1) :: Nil, Nil)
).leaf,
NextModel(iVar0.name).leaf
)
ForModel
.neverMode(iVar0.name, ValueModel.fromRaw(array))
.wrap(
CallServiceModel(
LiteralModel.fromRaw(serviceId),
fnName,
CallModel(LiteralModel.number(1) :: Nil, Nil)
).leaf,
NextModel(iVar0.name).leaf
)
) should be(true)
}

Expand Down
13 changes: 9 additions & 4 deletions model/raw/src/main/scala/aqua/raw/ops/RawTag.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ case class RestrictionTag(name: String, `type`: DataType) extends SeqGroupTag {
copy(name = map.getOrElse(name, name))
}

case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = None)
extends SeqGroupTag {
case class ForTag(item: String, iterable: ValueRaw, mode: ForTag.Mode) extends SeqGroupTag {

override def restrictsVarNames: Set[String] = Set(item)

Expand All @@ -185,9 +184,15 @@ case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] =
object ForTag {

enum Mode {
case Wait
case Pass
case Blocking
case NonBlocking
}

def blocking(item: String, iterable: ValueRaw): ForTag =
ForTag(item, iterable, Mode.Blocking)

def nonBlocking(item: String, iterable: ValueRaw): ForTag =
ForTag(item, iterable, Mode.NonBlocking)
}

case class CallArrowRawTag(
Expand Down
8 changes: 7 additions & 1 deletion model/res/src/main/scala/aqua/res/MakeRes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ object MakeRes {
case SeqModel | _: OnModel | _: ApplyTopologyModel => SeqRes.leaf
case MatchMismatchModel(a, b, s) =>
MatchMismatchRes(a, b, s).leaf
case ForModel(item, iter, mode) if !isNillLiteral(iter) => FoldRes(item, iter, mode).leaf
case ForModel(item, iter, mode) if !isNillLiteral(iter) =>
val modeRes = mode match {
case ForModel.Mode.Null => FoldRes.Mode.Null
case ForModel.Mode.Never => FoldRes.Mode.Never
}

FoldRes(item, iter, modeRes).leaf
case RestrictionModel(item, itemType) => RestrictionRes(item, itemType).leaf
case DetachModel => ParRes.leaf
case ParModel => ParRes.leaf
Expand Down
18 changes: 14 additions & 4 deletions model/res/src/main/scala/aqua/res/ResolvedOp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ case class MatchMismatchRes(left: ValueModel, right: ValueModel, shouldMatch: Bo
override def toString: String = s"(${if (shouldMatch) "match" else "mismatch"} $left $right)"
}

case class FoldRes(item: String, iterable: ValueModel, mode: Option[ForModel.Mode] = None)
extends ResolvedOp {
override def toString: String = s"(fold $iterable $item ${mode.map(_.toString).getOrElse("")}"
case class FoldRes(item: String, iterable: ValueModel, mode: FoldRes.Mode) extends ResolvedOp {
override def toString: String = s"(fold $iterable $item ${mode.toString.toLowerCase()}"
}

object FoldRes {
enum Mode { case Null, Never }

def lastNull(item: String, iterable: ValueModel): FoldRes =
FoldRes(item, iterable, Mode.Null)

def lastNever(item: String, iterable: ValueModel): FoldRes =
FoldRes(item, iterable, Mode.Never)
}

case class RestrictionRes(item: String, `type`: DataType) extends ResolvedOp {
Expand All @@ -50,7 +59,8 @@ case class CallServiceRes(
override def toString: String = s"(call $peerId ($serviceId $funcName) $call)"
}

case class ApStreamMapRes(key: ValueModel, value: ValueModel, exportTo: CallModel.Export) extends ResolvedOp {
case class ApStreamMapRes(key: ValueModel, value: ValueModel, exportTo: CallModel.Export)
extends ResolvedOp {
override def toString: String = s"(ap ($key $value) $exportTo)"
}

Expand Down
2 changes: 1 addition & 1 deletion model/res/src/test/scala/aqua/res/ResBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object ResBuilder {
val arrayRes = VarModel(stream.name + "_gate", ArrayType(ScalarType.string))

RestrictionRes(testVM.name, testStreamType).wrap(
FoldRes(iter.name, stream, ForModel.Mode.Never.some).wrap(
FoldRes(iter.name, stream, FoldRes.Mode.Never).wrap(
ApRes(iter, CallModel.Export(testVM.name, testVM.`type`)).leaf,
CanonRes(testVM, peer, CallModel.Export(canon.name, canon.`type`)).leaf,
XorRes.wrap(
Expand Down
17 changes: 14 additions & 3 deletions model/src/main/scala/aqua/model/OpModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ case class MatchMismatchModel(left: ValueModel, right: ValueModel, shouldMatch:
case class ForModel(
item: String,
iterable: ValueModel,
mode: Option[ForModel.Mode] = Some(ForModel.Mode.Null)
mode: ForModel.Mode = ForModel.Mode.Null
) extends SeqGroupModel {

override def toString: String =
s"for $item <- $iterable${mode.map(m => " " + m.toString).getOrElse("")}"
s"for $item <- $iterable${mode.toString}"

override def restrictsVarNames: Set[String] = Set(item)

Expand All @@ -165,6 +165,12 @@ object ForModel {
case Null
case Never
}

def neverMode(item: String, iterable: ValueModel): ForModel =
ForModel(item, iterable, Mode.Never)

def nullMode(item: String, iterable: ValueModel): ForModel =
ForModel(item, iterable, Mode.Null)
}

// TODO how is it used? remove, if it's not
Expand All @@ -175,7 +181,12 @@ case class DeclareStreamModel(value: ValueModel) extends NoExecModel {
}

// key must be only string or number
case class InsertKeyValueModel(key: ValueModel, value: ValueModel, assignTo: String, assignToType: StreamMapType) extends OpModel {
case class InsertKeyValueModel(
key: ValueModel,
value: ValueModel,
assignTo: String,
assignToType: StreamMapType
) extends OpModel {
override def usesVarNames: Set[String] = value.usesVarNames

override def exportsVarNames: Set[String] = Set(assignTo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ case class ArgsFromService(dataServiceId: ValueRaw) extends ArgsProvider {
Call(Nil, Call.Export(iter, ArrayType(t.element)) :: Nil)
)
.leaf,
ForTag(item, VarRaw(iter, ArrayType(t.element))).wrap(
SeqTag.wrap(
PushToStreamTag(VarRaw(item, t.element), Call.Export(varName, t)).leaf,
NextTag(item).leaf
ForTag
.nonBlocking(item, VarRaw(iter, ArrayType(t.element)))
.wrap(
SeqTag.wrap(
PushToStreamTag(VarRaw(item, t.element), Call.Export(varName, t)).leaf,
NextTag(item).leaf
)
)
)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,11 @@ object Topology extends Logging {
NextRes(itemName).leaf
)

FoldRes(itemName, v).wrap(if (reversed) steps.reverse else steps)
FoldRes
.lastNull(itemName, v)
.wrap(
if (reversed) steps.reverse else steps
)
case _ =>
MakeRes.hop(v)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,16 @@ object ModelBuilder {
failErrorModel
)

def fold(item: String, iter: ValueRaw, mode: Option[ForModel.Mode], body: OpModel.Tree*) = {
def fold(item: String, iter: ValueRaw, mode: ForModel.Mode, body: OpModel.Tree*) = {
val ops = SeqModel.wrap(body: _*)
ForModel(item, ValueModel.fromRaw(iter), mode).wrap(ops, NextModel(item).leaf)
}

def foldPar(item: String, iter: ValueRaw, body: OpModel.Tree*) = {
val ops = SeqModel.wrap(body: _*)
DetachModel.wrap(
ForModel(item, ValueModel.fromRaw(iter), ForModel.Mode.Never.some)
ForModel
.neverMode(item, ValueModel.fromRaw(iter))
.wrap(ParModel.wrap(ops, NextModel(item).leaf))
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package aqua.model.transform.topology

import aqua.model.transform.ModelBuilder
import aqua.model.{CallModel, OnModel, SeqModel}
import aqua.model.{CallModel, ForModel, OnModel, SeqModel}
import aqua.model.transform.cursor.ChainZipper
import aqua.raw.value.{LiteralRaw, ValueRaw, VarRaw}
import aqua.raw.ops.{Call, FuncOp, OnTag}
Expand Down Expand Up @@ -137,7 +137,7 @@ class OpModelTreeCursorSpec extends AnyFlatSpec with Matchers {
fold(
"item",
VarRaw("iterable", ArrayType(ScalarType.string)),
None,
ForModel.Mode.Null,
OnModel(
VarRaw("-in-fold-", ScalarType.string),
Chain.one(VarRaw("-fold-relay-", ScalarType.string))
Expand Down
Loading
Loading