Skip to content

Commit

Permalink
Refactor refined function logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasstucki committed Aug 22, 2023
1 parent 6e370a9 commit 18114f0
Show file tree
Hide file tree
Showing 28 changed files with 226 additions and 210 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
def isStructuralTermSelect(tree: Select) =
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
case defn.PolyFunctionOf(_) =>
case defn.FunctionOf(_) =>
false
case RefinedType(parent, rname, rinfo) =>
rname == tree.name || hasRefinement(parent)
Expand Down
9 changes: 3 additions & 6 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1152,13 +1152,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, _) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true)(argTypes, resType),
coord = ctx.owner.coord)
case defn.FunctionOf(mt: MethodType) if mt.isContextualMethod && !mt.isResultDependent => // TODO handle result-dependent functions?
val anonFun = newAnonFun(ctx.owner, mt, coord = ctx.owner.coord)
def lambdaBody(refss: List[List[Tree]]) =
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
expand(target.select(nme.apply).appliedToArgss(refss), mt.resType)(
using ctx.withOwner(anonFun))
Closure(anonFun, lambdaBody)
case _ =>
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -876,13 +876,13 @@ object CaptureSet:
empty
case CapturingType(parent, refs) =>
recur(parent) ++ refs
case tpd @ RefinedType(parent, _, rinfo: MethodType)
if followResult && defn.isFunctionNType(tpd) =>
ofType(parent, followResult = false) // pick up capture set from parent type
case tpd @ defn.RefinedFunctionOf(rinfo: MethodType)
if followResult =>
ofType(tpd.parent, followResult = false) // pick up capture set from parent type
++ (recur(rinfo.resType) // add capture set of result
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
case tpd @ AppliedType(tycon, args) =>
if followResult && defn.isNonRefinedFunction(tpd) then
if followResult && defn.isFunctionNType(tpd) then
recur(args.last)
// must be (pure) FunctionN type since ImpureFunctions have already
// been eliminated in selector's dealias. Use capture set of result.
Expand Down
58 changes: 28 additions & 30 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer:
capt.println(i"solving $t")
refs.solve()
traverse(parent)
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) =>
case defn.RefinedFunctionOf(rinfo) =>
traverse(rinfo)
case tp: TypeVar =>
case tp: TypeRef =>
Expand Down Expand Up @@ -302,8 +302,8 @@ class CheckCaptures extends Recheck, SymTransformer:
t
case _ =>
val t1 = t match
case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) =>
t.derivedRefinedType(parent, rname, this(rinfo))
case t @ defn.RefinedFunctionOf(rinfo: MethodType) =>
t.derivedRefinedType(t.parent, t.refinedName, this(rinfo))
case _ =>
mapOver(t)
if variance > 0 then t1
Expand Down Expand Up @@ -408,10 +408,10 @@ class CheckCaptures extends Recheck, SymTransformer:
else if meth == defn.Caps_unsafeUnbox then
mapArgUsing(_.forceBoxStatus(false))
else if meth == defn.Caps_unsafeBoxFunArg then
mapArgUsing:
case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual)

mapArgUsing: tp =>
val defn.FunctionOf(mt: MethodType) = tp.dealias: @unchecked
mt.derivedLambdaType(resType = mt.resType.forceBoxStatus(true))
.toFunctionType()
else
super.recheckApply(tree, pt) match
case appType @ CapturingType(appType1, refs) =>
Expand Down Expand Up @@ -502,8 +502,9 @@ class CheckCaptures extends Recheck, SymTransformer:
block match
case closureDef(mdef) =>
pt.dealias match
case defn.FunctionOf(ptformals, _, _)
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
case defn.FunctionOf(mt0: MethodType)
if mt0.paramInfos.nonEmpty && mt0.paramInfos.forall(_.captureSet.isAlwaysEmpty) =>
val ptformals = mt0.paramInfos
// Redo setup of the anonymous function so that formal parameters don't
// get capture sets. This is important to avoid false widenings to `cap`
// when taking the base type of the actual closures's dependent function
Expand Down Expand Up @@ -696,21 +697,19 @@ class CheckCaptures extends Recheck, SymTransformer:
//println(i"check conforms $actual1 <<< $expected1")
super.checkConformsExpr(actual1, expected1, tree)

private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual)(args, resultType)
.toFunctionType(isJava = false, alwaysDependent = true)

/** Turn `expected` into a dependent function when `actual` is dependent. */
private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type =
def recur(expected: Type): Type = expected.dealias match
case expected0 @ CapturingType(eparent, refs) =>
val eparent1 = recur(eparent)
if eparent1 eq eparent then expected
else CapturingType(eparent1, refs, boxed = expected0.isBoxed)
case expected @ defn.FunctionOf(args, resultType, isContextual)
if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) =>
val expected1 = toDepFun(args, resultType, isContextual)
expected1
case defn.FunctionOf(mt: MethodType) =>
actual.dealias match
case defn.FunctionOf(mt2: MethodType) if mt2.isResultDependent =>
mt.toFunctionType(alwaysDependent = true)
case _ =>
expected
case _ =>
expected
recur(expected)
Expand Down Expand Up @@ -781,9 +780,8 @@ class CheckCaptures extends Recheck, SymTransformer:

try
val (eargs, eres) = expected.dealias.stripCapturing match
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType)
case defn.FunctionOf(mt: MethodType) => (mt.paramInfos, mt.resType)
case _ => (aargs.map(_ => WildcardType), WildcardType)
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
val ares1 = adapt(ares, eres, covariant)
Expand All @@ -808,7 +806,7 @@ class CheckCaptures extends Recheck, SymTransformer:

try
val eres = expected.dealias.stripCapturing match
case RefinedType(_, _, rinfo: PolyType) => rinfo.resType
case defn.PolyFunctionOf(rinfo: PolyType) => rinfo.resType
case expected: PolyType => expected.resType
case _ => WildcardType

Expand Down Expand Up @@ -842,26 +840,26 @@ class CheckCaptures extends Recheck, SymTransformer:

// Adapt the inner shape type: get the adapted shape type, and the capture set leaked during adaptation
val (styp1, leaked) = styp match {
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
.toFunctionType(isJava = false, alwaysDependent = true))
case actual: MethodType =>
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) =>
.toFunctionType(alwaysDependent = true))
case actual @ defn.RefinedFunctionOf(rinfo: PolyType) =>
adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox,
ares1 =>
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1)
actual1
)
case actual: MethodType =>
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
case _ =>
(styp, CaptureSet())
}
Expand Down Expand Up @@ -1080,7 +1078,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case CapturingType(parent, refs) =>
healCaptureSet(refs)
traverse(parent)
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case defn.RefinedFunctionOf(rinfo: MethodType) =>
traverse(rinfo)
case tp: TermLambda =>
val saved = allowed
Expand Down
16 changes: 8 additions & 8 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ extends tpd.TreeTraverser:
MethodType.companion(
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
)(argTypes, resType)
.toFunctionType(isJava = false, alwaysDependent = true)
.toFunctionType(alwaysDependent = true)

/** If `tp` is an unboxed capturing type or a function returning an unboxed capturing type,
* convert it to be boxed.
Expand All @@ -49,15 +49,15 @@ extends tpd.TreeTraverser:
def recur(tp: Type): Type = tp.dealias match
case tp @ CapturingType(parent, refs) if !tp.isBoxed =>
tp.boxed
case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) =>
case tp1 @ AppliedType(tycon, args) if defn.isFunctionNType(tp1) =>
val res = args.last
val boxedRes = recur(res)
if boxedRes eq res then tp
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) =>
case defn.RefinedFunctionOf(rinfo: MethodType) =>
val boxedRinfo = recur(rinfo)
if boxedRinfo eq rinfo then tp
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
else boxedRinfo.toFunctionType(alwaysDependent = true)
case tp1: MethodOrPoly =>
val res = tp1.resType
val boxedRes = recur(res)
Expand Down Expand Up @@ -129,7 +129,7 @@ extends tpd.TreeTraverser:
apply(parent)
case tp @ AppliedType(tycon, args) =>
val tycon1 = this(tycon)
if defn.isNonRefinedFunction(tp) then
if defn.isFunctionNType(tp) then
// Convert toplevel generic function types to dependent functions
if !defn.isFunctionSymbol(tp.typeSymbol) && (tp.dealias ne tp) then
// This type is a function after dealiasing, so we dealias and recurse.
Expand All @@ -149,9 +149,9 @@ extends tpd.TreeTraverser:
tp.derivedAppliedType(tycon1, args1 :+ res1)
else
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case defn.RefinedFunctionOf(rinfo: MethodType) =>
val rinfo1 = apply(rinfo)
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
if rinfo1 ne rinfo then rinfo1.toFunctionType(alwaysDependent = true)
else tp
case tp: MethodType =>
tp.derivedLambdaType(
Expand Down Expand Up @@ -197,7 +197,7 @@ extends tpd.TreeTraverser:
val mt = ContextualMethodType(paramName :: Nil)(
_ => paramType :: Nil,
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
val fntpe = defn.PolyFunctionOf(mt)
val fntpe = mt.toFunctionType()
if !encl.isEmpty && isLast then
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
CapturingType(fntpe, cs, boxed = false)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/Synthetics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ object Synthetics:
val (et: ExprType) = symd.info: @unchecked
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
def mapFinalResult(tp: Type, f: Type => Type): Type =
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
val defn.FunctionNOf(args, res, isContextual) = tp: @unchecked
if defn.isFunctionNType(res) then
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
defn.FunctionNOf(args, mapFinalResult(res, f), isContextual)
else
f(tp)
val resType1 =
Expand Down
Loading

0 comments on commit 18114f0

Please sign in to comment.