Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(compiler): Recursively find abilities [LNG-338] #1086

Merged
merged 3 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions aqua-src/antithesis.aqua
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
aqua A

import "aqua-src/gen/OneMore.aqua"
export haveFun

export main
ability Compute:
job() -> string

alias SomeAlias: string
func lift() -> Compute:
job = () -> string:
<- "job done"
<- Compute(job)

data NestedStruct:
a: SomeAlias
ability Function:
run() -> string

data SomeStruct:
al: SomeAlias
nested: NestedStruct
func roundtrip{Function}() -> string:
res <- Function.run()
<- res

ability SomeAbility:
someStr: SomeStruct
nested: NestedStruct
al: SomeAlias
someFunc(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> NestedStruct, SomeStruct, SomeAlias
func disjoint_run{Compute}() -> Function:
run = func () -> string:
<- Compute.job()
<- Function(run = run)

service Srv("a"):
check(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> NestedStruct
check2() -> SomeStruct
check3() -> SomeAlias

func withAb{SomeAbility}() -> SomeStruct:
<- SomeAbility.someStr

func main(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> string:
<- ""
func haveFun() -> string:
comp = lift()
fn = disjoint_run{comp}()
res <- roundtrip{fn}()
<- res
30 changes: 28 additions & 2 deletions integration-tests/aqua/examples/abilitiesClosure.aqua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
aqua M

export bugLNG314
export bugLNG314, bugLNG338

ability WorkerJob:
runOnSingleWorker(w: string) -> string
Expand All @@ -20,4 +20,30 @@ func bugLNG314() -> string:
worker_job = WorkerJob(runOnSingleWorker = job2)
subnet_job <- disjoint_run{worker_job}()
res <- runJob(subnet_job)
<- res
<- res

ability Compute:
job() -> string

func lift() -> Compute:
job = () -> string:
<- "job done"
<- Compute(job)

ability Function:
run() -> string

func roundtrip{Function}() -> string:
res <- Function.run()
<- res

func disjoint_run{Compute}() -> Function:
run = func () -> string:
<- Compute.job()
<- Function(run = run)

func bugLNG338() -> string:
comp = lift()
fn = disjoint_run{comp}()
res <- roundtrip{fn}()
<- res
7 changes: 6 additions & 1 deletion integration-tests/src/__test__/examples.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import {
multipleAbilityWithClosureCall,
returnSrvAsAbilityCall,
} from "../examples/abilityCall.js";
import { bugLNG314Call } from "../examples/abilityClosureCall.js";
import { bugLNG314Call, bugLNG338Call } from "../examples/abilityClosureCall.js";
import {
nilLengthCall,
nilLiteralCall,
Expand Down Expand Up @@ -665,6 +665,11 @@ describe("Testing examples", () => {
expect(result).toEqual("strstrstr");
});

it("abilitiesClosure.aqua bug LNG-338", async () => {
let result = await bugLNG338Call();
expect(result).toEqual("job done");
});

it("functors.aqua LNG-119 bug", async () => {
let result = await bugLng119Call();
expect(result).toEqual([1]);
Expand Down
6 changes: 5 additions & 1 deletion integration-tests/src/examples/abilityClosureCall.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import {
bugLNG314
bugLNG314, bugLNG338
} from "../compiled/examples/abilitiesClosure.js";

export async function bugLNG314Call(): Promise<string> {
return await bugLNG314();
}

export async function bugLNG338Call(): Promise<string> {
return await bugLNG338();
}
79 changes: 69 additions & 10 deletions model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import cats.syntax.option.*
import cats.syntax.semigroup.*
import cats.syntax.traverse.*
import cats.{Eval, Monoid}
import scala.annotation.tailrec
import scribe.Logging

/**
Expand Down Expand Up @@ -82,6 +83,62 @@ object ArrowInliner extends Logging {
arrowsToSave: Map[String, FuncArrow]
)

/**
* Find abilities recursively, because ability can hold arrow with another ability in it.
* @param abilitiesToGather gather all fields for these abilities
* @param varsFromAbs already gathered variables
* @param arrowsFromAbs already gathered arrows
* @param processedAbs already processed abilities
* @return all needed variables and arrows
*/
@tailrec
private def arrowsAndVarsFromAbilities(
abilitiesToGather: Map[String, GeneralAbilityType],
exports: Map[String, ValueModel],
arrows: Map[String, FuncArrow],
varsFromAbs: Map[String, ValueModel] = Map.empty,
arrowsFromAbs: Map[String, FuncArrow] = Map.empty,
processedAbs: Set[String] = Set.empty
): (Map[String, ValueModel], Map[String, FuncArrow]) = {
val varsFromAbilities = abilitiesToGather.flatMap { case (name, at) =>
getAbilityVars(name, None, at, exports)
}
val arrowsFromAbilities = abilitiesToGather.flatMap { case (name, at) =>
getAbilityArrows(name, None, at, exports, arrows)
}

val allProcessed = abilitiesToGather.keySet ++ processedAbs

// find all names that is used in arrows
val namesUsage = arrowsFromAbilities.values.flatMap(_.body.usesVarNames.value).toSet

// check if there is abilities that we didn't gather
val abilitiesUsage = namesUsage.toList
.flatMap(exports.get)
.collect {
case ValueModel.Ability(vm, at) if !allProcessed.contains(vm.name) =>
vm.name -> at
}
.toMap

val allVars = varsFromAbilities ++ varsFromAbs
val allArrows = arrowsFromAbilities ++ arrowsFromAbs

if (abilitiesUsage.isEmpty) {
(allVars, allArrows)
} else {
arrowsAndVarsFromAbilities(
abilitiesUsage,
exports,
arrows,
allVars,
allArrows,
allProcessed
)
}

}

// Apply a callable function, get its fully resolved body & optional value, if any
private def inline[S: Mangler: Arrows: Exports](
fn: FuncArrow,
Expand All @@ -104,15 +161,15 @@ object ArrowInliner extends Logging {
exports <- Exports[S].exports
arrows <- Arrows[S].arrows
// gather all arrows and variables from abilities
returnedAbilities = rets.collect { case ValueModel.Ability(vm, at) =>
abilitiesToGather = rets.collect { case ValueModel.Ability(vm, at) =>
vm.name -> at
}
varsFromAbilities = returnedAbilities.flatMap { case (name, at) =>
getAbilityVars(name, None, at, exports)
}.toMap
arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) =>
getAbilityArrows(name, None, at, exports, arrows)
}.toMap
arrsVars = arrowsAndVarsFromAbilities(
abilitiesToGather.toMap,
exports,
arrows
)
(varsFromAbilities, arrowsFromAbilities) = arrsVars

// find and get resolved arrows if we return them from the function
returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet
Expand Down Expand Up @@ -172,9 +229,11 @@ object ArrowInliner extends Logging {
abilityType,
exports
)
val abilityExport =
exports.get(abilityName).map(vm => abilityNewName.getOrElse(abilityName) -> vm).toMap

get(_.variables) ++ get(_.arrows).flatMap {
case arrow @ (_, vm @ ValueModel.Arrow(_, _)) =>
abilityExport ++ get(_.variables) ++ get(_.arrows).flatMap {
case arrow @ (_, ValueModel.Arrow(_, _)) =>
arrow.some
case (_, m) =>
internalError(s"($m) cannot be an arrow")
Expand Down Expand Up @@ -497,7 +556,7 @@ object ArrowInliner extends Logging {
exports <- Exports[S].exports
streams <- getOutsideStreamNames
arrows = passArrows ++ arrowsFromAbilities

inlineResult <- Exports[S].scope(
Arrows[S].scope(
for {
Expand Down
Loading