Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return arrows from functions #693

Merged
merged 9 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions aqua-src/antithesis.aqua
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
service Console("run-console"):
print(any: ⊤)
get() -> string
zzz() -> string

data Azazaz:
s: string

func exec(peers: []string) -> []string:
on "":
closure = (s: Azazaz) -> Azazaz:
Console.get()
func returnCall() -> string -> string:
closure = (s: string) -> string:
<- s
Console.zzz()
<- peers
closure("123asdf")
<- closure

func test() -> string:
a = returnCall()
b = a("arg")
<- b
32 changes: 23 additions & 9 deletions model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aqua.model.inline
import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.*
import aqua.raw.ops.RawTag
import aqua.types.ArrowType
import aqua.raw.value.{ValueRaw, VarRaw}
import aqua.types.{BoxType, StreamType}
import cats.data.{Chain, State, StateT}
Expand Down Expand Up @@ -131,14 +132,24 @@ object ArrowInliner extends Logging {

argsToArrows = argsToArrowsRaw.map { case (k, v) => argsShouldRename.getOrElse(k, k) -> v }

returnedArrows = fn.ret.collect { case VarRaw(name, ArrowType(_, _)) =>
name
}.toSet

returnedArrowsShouldRename <- Mangler[S].findNewNames(returnedArrows)
renamedCapturedArrows = fn.capturedArrows.map { case (k, v) =>
returnedArrowsShouldRename.getOrElse(k, k) -> v
}

// Going to resolve arrows: collect them all. Names should never collide: it's semantically checked
_ <- Arrows[S].purge
_ <- Arrows[S].resolved(fn.capturedArrows ++ argsToArrows)
_ <- Arrows[S].resolved(renamedCapturedArrows ++ argsToArrows)

// Rename all renamed arguments in the body
treeRenamed =
fn.body
.rename(argsShouldRename)
.rename(returnedArrowsShouldRename)
.map(_.mapValues(_.map {
// if an argument is a BoxType (Array or Option), but we pass a stream,
// change a type as stream to not miss `$` sign in air
Expand Down Expand Up @@ -172,7 +183,7 @@ object ArrowInliner extends Logging {
tree = treeRenamed.rename(shouldRename)

// Result could be renamed; take care about that
} yield (tree, fn.ret.map(_.renameVars(shouldRename)))
} yield (tree, fn.ret.map(_.renameVars(shouldRename ++ returnedArrowsShouldRename)))

private[inline] def callArrowRet[S: Exports: Arrows: Mangler](
arrow: FuncArrow,
Expand All @@ -185,12 +196,15 @@ object ArrowInliner extends Logging {
for {
_ <- Arrows[S].resolved(passArrows)
av <- ArrowInliner.inline(arrow, call)
} yield av
// find and get resolved arrows if we return them from the function
returnedArrows = av._2.collect { case VarModel(name, ArrowType(_, _), _) =>
name
}
arrowsToSave <- Arrows[S].pickArrows(returnedArrows.toSet)
} yield av -> arrowsToSave
)
(appliedOp, value) = av

_ <- Exports[S].resolved(call.exportTo.map(_.name).zip(value).toMap)

} yield appliedOp -> value

((appliedOp, values), arrowsToSave) = av
_ <- Arrows[S].resolved(arrowsToSave)
_ <- Exports[S].resolved(call.exportTo.map(_.name).zip(values).toMap)
} yield appliedOp -> values
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.*
import aqua.model.inline.RawValueInliner.collectionToModel
import aqua.model.inline.raw.{CallArrowRawInliner, CollectionRawInliner}
import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.*
import aqua.raw.value.*
import aqua.types.{ArrayType, BoxType, CanonStreamType, StreamType}
import aqua.types.{ArrayType, ArrowType, BoxType, CanonStreamType, StreamType}
import cats.syntax.traverse.*
import cats.syntax.applicative.*
import cats.instances.list.*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package aqua.model.inline.raw

import aqua.model.inline.Inline.parDesugarPrefixOpt
import aqua.model.{CallServiceModel, SeqModel, ValueModel, VarModel}
import aqua.model.{CallServiceModel, FuncArrow, SeqModel, ValueModel, VarModel}
import aqua.model.inline.{ArrowInliner, Inline, TagInliner}
import aqua.model.inline.RawValueInliner.{callToModel, valueToModel}
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.raw.ops.Call
import aqua.types.ArrowType
import aqua.raw.value.CallArrowRaw
import cats.data.{Chain, State}
import scribe.Logging
Expand Down Expand Up @@ -42,31 +43,58 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
*/
val funcName = value.ability.fold(value.name)(_ + "." + value.name)
logger.trace(s" $funcName")
Arrows[S].arrows.flatMap(arrows =>
arrows.get(funcName) match {
case Some(fn) =>
logger.trace(Console.YELLOW + s"Call arrow $funcName" + Console.RESET)
callToModel(call, false).flatMap { case (cm, p) =>
ArrowInliner
.callArrowRet(fn, cm)
.map { case (body, vars) =>
vars -> Inline(
ListMap.empty,
Chain.one(SeqModel.wrap(p.toList :+ body: _*))
)
}
}
case None =>
logger.error(
s"Inlining, cannot find arrow ${funcName}, available: ${arrows.keys
.mkString(", ")}"
)
State.pure(Nil -> Inline.empty)
}
)

resolveArrow(funcName, call)
}
}

private def resolveFuncArrow[S: Mangler: Exports: Arrows](fn: FuncArrow, call: Call) = {
logger.trace(Console.YELLOW + s"Call arrow ${fn.funcName}" + Console.RESET)
callToModel(call, false).flatMap { case (cm, p) =>
ArrowInliner
.callArrowRet(fn, cm)
.map { case (body, vars) =>
vars -> Inline(
ListMap.empty,
Chain.one(SeqModel.wrap(p.toList :+ body: _*))
)
}
}
}

private def resolveArrow[S: Mangler: Exports: Arrows](funcName: String, call: Call) =
Arrows[S].arrows.flatMap(arrows =>
arrows.get(funcName) match {
case Some(fn) =>
resolveFuncArrow(fn, call)
case None =>
Exports[S].exports.flatMap { exps =>
// if there is no arrow, check if it is stored in Exports as variable and try to resolve it
exps.get(funcName) match {
case Some(VarModel(name, ArrowType(_, _), _)) =>
Arrows[S].arrows.flatMap(arrows =>
arrows.get(name) match {
case Some(fn) =>
resolveFuncArrow(fn, call)
case _ =>
logger.error(
s"Inlining, cannot find arrow $funcName, available: ${arrows.keys
.mkString(", ")}"
)
State.pure(Nil -> Inline.empty)
}
)
case _ =>
logger.error(
s"Inlining, cannot find arrow $funcName, available: ${arrows.keys
.mkString(", ")}"
)
State.pure(Nil -> Inline.empty)
}
}
}
)

override def apply[S: Mangler: Exports: Arrows](
raw: CallArrowRaw,
propertiesAllowed: Boolean
Expand Down
8 changes: 6 additions & 2 deletions model/raw/src/main/scala/aqua/raw/ops/RawTag.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package aqua.raw.ops
import aqua.raw.Raw
import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.RawTag.Tree
import aqua.raw.value.{CallArrowRaw, ValueRaw}
import aqua.raw.value.{CallArrowRaw, ValueRaw, VarRaw}
import aqua.tree.{TreeNode, TreeNodeCompanion}
import aqua.types.{ArrowType, ProductType}
import cats.{Eval, Show}
Expand Down Expand Up @@ -104,7 +104,8 @@ case class MatchMismatchTag(left: ValueRaw, right: ValueRaw, shouldMatch: Boolea
MatchMismatchTag(left.map(f), right.map(f), shouldMatch)
}

case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = None) extends SeqGroupTag {
case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = None)
extends SeqGroupTag {

override def restrictsVarNames: Set[String] = Set(item)

Expand Down Expand Up @@ -195,6 +196,9 @@ case class ClosureTag(
detach: Boolean
) extends NoExecTag {

override def renameExports(map: Map[String, String]): RawTag =
copy(func = func.copy(name = map.getOrElse(func.name, func.name)))

override def mapValues(f: ValueRaw => ValueRaw): RawTag =
copy(
func.copy(arrow =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package aqua.parser.expr.func
import aqua.parser.Expr
import aqua.parser.expr.func.DeclareStreamExpr
import aqua.parser.lexer.Token.*
import aqua.parser.lexer.{Name, Token, TypeToken}
import aqua.parser.lexer.{DataTypeToken, Name, Token, TypeToken}
import aqua.parser.lift.LiftParser
import cats.parse.Parser
import cats.{~>, Comonad}
import cats.parse.Parser as P
import cats.{Comonad, ~>}
import aqua.parser.lift.Span
import aqua.parser.lift.Span.{P0ToSpan, PToSpan}

case class DeclareStreamExpr[F[_]](name: Name[F], `type`: TypeToken[F])
case class DeclareStreamExpr[F[_]](name: Name[F], `type`: DataTypeToken[F])
extends Expr[F](DeclareStreamExpr, name) {

override def mapK[K[_]: Comonad](fk: F ~> K): DeclareStreamExpr[K] =
Expand All @@ -19,8 +19,8 @@ case class DeclareStreamExpr[F[_]](name: Name[F], `type`: TypeToken[F])

object DeclareStreamExpr extends Expr.Leaf {

override val p: Parser[DeclareStreamExpr[Span.S]] =
((Name.p <* ` : `) ~ TypeToken.`typedef`).map { case (name, t) =>
override val p: P[DeclareStreamExpr[Span.S]] =
((Name.p <* ` : `) ~ DataTypeToken.`datatypedef`).map { case (name, t) =>
DeclareStreamExpr(name, t)
}

Expand Down
1 change: 1 addition & 0 deletions parser/src/main/scala/aqua/parser/expr/func/IfExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ case class IfExpr[F[_]](left: ValueToken[F], eqOp: EqOp[F], right: ValueToken[F]

object IfExpr extends Expr.AndIndented {

// list of expressions that can be used inside this block
override def validChildren: List[Expr.Lexem] = ForExpr.validChildren

override val p: P[IfExpr[Span.S]] =
Expand Down
14 changes: 10 additions & 4 deletions parser/src/main/scala/aqua/parser/lexer/TypeToken.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import cats.syntax.comonad.*
import cats.syntax.functor.*
import cats.~>
import aqua.parser.lift.Span
import aqua.parser.lift.Span.{P0ToSpan, PToSpan}
import aqua.parser.lift.Span.{P0ToSpan, PToSpan, S}

sealed trait TypeToken[S[_]] extends Token[S] {
def mapK[K[_]: Comonad](fk: S ~> K): TypeToken[K]
Expand Down Expand Up @@ -102,7 +102,7 @@ object BasicTypeToken {
case class ArrowTypeToken[S[_]: Comonad](
override val unit: S[Unit],
args: List[(Option[Name[S]], TypeToken[S])],
res: List[DataTypeToken[S]]
res: List[TypeToken[S]]
) extends TypeToken[S] {
override def as[T](v: T): S[T] = unit.as(v)

Expand All @@ -117,9 +117,15 @@ case class ArrowTypeToken[S[_]: Comonad](

object ArrowTypeToken {

def typeDef(): P[TypeToken[S]] = P.defer(TypeToken.`typedef`.between(`(`, `)`).backtrack | TypeToken.`typedef`)

def returnDef(): P[List[TypeToken[S]]] = comma(
typeDef().backtrack
).map(_.toList)

def `arrowdef`(argTypeP: P[TypeToken[Span.S]]): P[ArrowTypeToken[Span.S]] =
(comma0(argTypeP).with1 ~ ` -> `.lift ~
(comma(DataTypeToken.`datatypedef`).map(_.toList)
(returnDef().backtrack
| `()`.as(Nil))).map { case ((args, point), res) ⇒
ArrowTypeToken(point, args.map(Option.empty[Name[Span.S]] -> _), res)
}
Expand All @@ -129,7 +135,7 @@ object ArrowTypeToken {
(Name.p.map(Option(_)) ~ (` : ` *> (argTypeP | argTypeP.between(`(`, `)`))))
.surroundedBy(`/s*`)
) <* (`/s*` *> `)` <* ` `.?)) ~
(` -> ` *> comma(DataTypeToken.`datatypedef`)).?).map { case ((point, args), res) =>
(` -> ` *> returnDef()).?).map { case ((point, args), res) =>
ArrowTypeToken(point, args, res.toList.flatMap(_.toList))
}
}
Expand Down
Loading