Skip to content

Commit

Permalink
fix(compiler): Fix closure passing [fixes LNG-92] (#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
InversionSpaces authored Jun 14, 2023
1 parent 739854a commit f1abd58
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 33 deletions.
204 changes: 172 additions & 32 deletions model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -581,33 +581,20 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
}

/**
* func innerName(arg: u16) -> u16 -> u16:
* closureName = (x: u16) -> u16:
* retval = x + arg
* <- retval
* <- closureName
* closureName = (x: u16) -> u16:
* retval = x + add
* <- retval
*
* func outer() -> u16:
* outterClosureName <- inner(42)
* <body(outterClosureName.type)>
* <- outterResultName
* @return (closure func, closure type, closure type labelled)
*/
def closureReturnModel(
innerName: String,
def addClosure(
closureName: String,
outterClosureName: String,
outterResultName: String,
body: (ArrowType) => List[RawTag.Tree]
) = {
add: ValueRaw
): (FuncRaw, ArrowType, ArrowType) = {
val closureArg = VarRaw(
"x",
ScalarType.u16
)
val innerArg = VarRaw(
"arg",
ScalarType.u16
)

val closureRes = VarRaw(
"retval",
ScalarType.u16
Expand All @@ -616,24 +603,15 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
domain = ProductType(List(closureArg.`type`)),
codomain = ProductType(List(ScalarType.u16))
)
val closureTypeLablled = closureType.copy(
val closureTypeLabelled = closureType.copy(
domain = ProductType.labelled(List(closureArg.name -> closureArg.`type`))
)

val innerRes = VarRaw(
closureName,
closureTypeLablled
)
val innerType = ArrowType(
domain = ProductType.labelled(List(innerArg.name -> innerArg.`type`)),
codomain = ProductType(List(closureType))
)

val closureBody = SeqTag.wrap(
AssignmentTag(
RawBuilder.add(
closureArg,
innerArg
add
),
closureRes.name
).leaf,
Expand All @@ -645,12 +623,51 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
val closureFunc = FuncRaw(
name = closureName,
arrow = ArrowRaw(
`type` = closureTypeLablled,
`type` = closureTypeLabelled,
ret = List(closureRes),
body = closureBody
)
)

(closureFunc, closureType, closureTypeLabelled)
}

/**
* func innerName(arg: u16) -> u16 -> u16:
* closureName = (x: u16) -> u16:
* retval = x + arg
* <- retval
* <- closureName
*
* func outer() -> u16:
* outterClosureName <- inner(42)
* <body(outterClosureName.type)>
* <- outterResultName
*/
def closureReturnModel(
innerName: String,
closureName: String,
outterClosureName: String,
outterResultName: String,
body: (ArrowType) => List[RawTag.Tree]
) = {
val innerArg = VarRaw(
"arg",
ScalarType.u16
)

val (closureFunc, closureType, closureTypeLabelled) =
addClosure(closureName, innerArg)

val innerRes = VarRaw(
closureName,
closureTypeLabelled
)
val innerType = ArrowType(
domain = ProductType.labelled(List(innerArg.name -> innerArg.`type`)),
codomain = ProductType(List(closureType))
)

val innerBody = SeqTag.wrap(
ClosureTag(
func = closureFunc,
Expand Down Expand Up @@ -1121,6 +1138,129 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
model.equalsOrShowDiff(expected) shouldEqual true
}

/**
* func accept_closure(closure: u16 -> u16) -> u16:
* resA <- closure(42)
* <- resA
*
* func test() -> u16:
* closure = (x: u16) -> u16:
* resC = x + 37
* <- resC
* resT <- accept_closure(closure)
* <- resT
*/
"arrow inliner" should "correctly handle closure as argument [bug LNG-92]" in {
val acceptName = "accept_closure"
val closureName = "closure"
val testName = "test"
val acceptRes = VarRaw("resA", ScalarType.u16)
val testRes = VarRaw("resT", ScalarType.u16)

val (closureFunc, closureType, closureTypeLabelled) =
addClosure(closureName, LiteralRaw("37", LiteralType.number))

val acceptType = ArrowType(
domain = ProductType.labelled(List(closureName -> closureType)),
codomain = ProductType(ScalarType.u16 :: Nil)
)

val acceptBody = SeqTag.wrap(
CallArrowRawTag(
List(Call.Export(acceptRes.name, acceptRes.baseType)),
CallArrowRaw(
ability = None,
name = closureName,
arguments = List(LiteralRaw("42", LiteralType.number)),
baseType = closureType,
serviceId = None
)
).leaf,
ReturnTag(
NonEmptyList.one(acceptRes)
).leaf
)

val acceptFunc = FuncArrow(
funcName = acceptName,
body = acceptBody,
arrowType = ArrowType(
ProductType.labelled(List(closureName -> closureType)),
ProductType(List(ScalarType.u16))
),
ret = List(acceptRes),
capturedArrows = Map.empty,
capturedValues = Map.empty,
capturedTopology = None
)

val testBody = SeqTag.wrap(
ClosureTag(
func = closureFunc,
detach = false
).leaf,
CallArrowRawTag(
List(Call.Export(testRes.name, testRes.baseType)),
CallArrowRaw(
ability = None,
name = acceptName,
arguments = List(VarRaw(closureName, closureTypeLabelled)),
baseType = acceptFunc.arrowType,
serviceId = None
)
).leaf,
ReturnTag(
NonEmptyList.one(testRes)
).leaf
)

val testFunc = FuncArrow(
funcName = testName,
body = testBody,
arrowType = ArrowType(
ProductType(Nil),
ProductType(List(ScalarType.u16))
),
ret = List(testRes),
capturedArrows = Map(acceptName -> acceptFunc),
capturedValues = Map.empty,
capturedTopology = None
)

val model = ArrowInliner
.callArrow[InliningState](
testFunc,
CallModel(Nil, Nil)
)
.runA(InliningState())
.value

/* WARNING: This naming is unstable */
val tempAdd = VarModel("add", ScalarType.u16)

val expected = SeqModel.wrap(
CaptureTopologyModel(closureName).leaf,
MetaModel
.CallArrowModel(acceptName)
.wrap(
MetaModel
.CallArrowModel(closureName)
.wrap(
ApplyTopologyModel(closureName).wrap(
ModelBuilder
.add(
LiteralModel("42", LiteralType.number),
LiteralModel("37", LiteralType.number)
)(tempAdd)
.leaf
)
)
)
)

model.equalsOrShowDiff(expected) shouldEqual true
}

/*
data Prod:
value: string
Expand Down
6 changes: 5 additions & 1 deletion model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ case class CallArrowRaw(
override def varNames: Set[String] = arguments.flatMap(_.varNames).toSet

override def renameVars(map: Map[String, String]): ValueRaw =
copy(arguments = arguments.map(_.renameVars(map)), serviceId = serviceId.map(_.renameVars(map)))
copy(
name = map.getOrElse(name, name),
arguments = arguments.map(_.renameVars(map)),
serviceId = serviceId.map(_.renameVars(map))
)

override def toString: String =
s"(call ${ability.fold("")(a => s"|$a| ")} (${serviceId.fold("")(_.toString + " ")}$name) [${arguments
Expand Down

0 comments on commit f1abd58

Please sign in to comment.