Skip to content

Commit

Permalink
Fix higher-order unification incorrectly substituting tparams
Browse files Browse the repository at this point in the history
When creating a fresh type lambda for the purpose of higher-order type
inference, we incorrectly substituted references to type parameters
before this commit. We want to construct:

    bodyArgs := otherArgs.take(d), T_0, ..., T_k-1
    [T_0, ..., T_k-1] =>> otherTycon[bodyArgs]

For this type to be valid, we need the bounds of `T_i` to be the bounds of
the (d+i) type parameter of `otherTycon` after substituting references to
each type parameter of `otherTycon` by the corresponding argument in `bodyArgs`.

The previous implementation incorrectly substituted only the last `k` type
parameters, this was not enough for correctness. It could also lead to a crash
because it called `integrate` which implicitly assumes it is passed a full list
of type parameters (this is now documented).

Fixes #15983.
  • Loading branch information
smarter committed Oct 13, 2022
1 parent 62684d0 commit 48c7964
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 8 deletions.
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ class TypeApplications(val self: Type) extends AnyVal {
}
}

/** Substitute in `self` the type parameters of `tycon` by some other types. */
final def substTypeParams(tycon: Type, to: List[Type])(using Context): Type =
(tycon.typeParams: @unchecked) match
case LambdaParam(lam, _) :: _ => self.substParams(lam, to)
case params: List[Symbol @unchecked] => self.subst(params, to)

/** If `self` is a higher-kinded type, its type parameters, otherwise Nil */
final def hkTypeParams(using Context): List[TypeParamInfo] =
if (isLambdaSub) typeParams else Nil
Expand Down
25 changes: 18 additions & 7 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1066,12 +1066,16 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
*
* - k := args.length
* - d := otherArgs.length - k
* - T_0, ..., T_k-1 fresh type parameters
* - bodyArgs := otherArgs.take(d), T_0, ..., T_k-1
*
* `adaptedTycon` will be:
* Then,
*
* [T_0, ..., T_k-1] =>> otherTycon[otherArgs(0), ..., otherArgs(d-1), T_0, ..., T_k-1]
* adaptedTycon := [T_0, ..., T_k-1] =>> otherTycon[bodyArgs]
*
* where `T_n` has the same bounds as `otherTycon.typeParams(d+n)`
* where the bounds of `T_i` are set based on the bounds of `otherTycon.typeParams(d+i)`
* after substituting type parameter references by the corresponding argument
* in `bodyArgs` (see `adaptedBounds` in the implementation).
*
* Historical note: this strategy is known in Scala as "partial unification"
* (even though the type constructor variable isn't actually unified but only
Expand All @@ -1096,11 +1100,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
variancesConform(remainingTparams, tparams) && {
val adaptedTycon =
if d > 0 then
val initialArgs = otherArgs.take(d)
/** The arguments passed to `otherTycon` in the body of `tl` */
def bodyArgs(tl: HKTypeLambda) = initialArgs ++ tl.paramRefs
/** The bounds of the type parameters of `tl` */
def adaptedBounds(tl: HKTypeLambda) =
val bodyArgsComputed = bodyArgs(tl)
remainingTparams.map(_.paramInfo)
.mapconserve(_.substTypeParams(otherTycon, bodyArgsComputed).bounds)

HKTypeLambda(remainingTparams.map(_.paramName))(
tl => remainingTparams.map(remainingTparam =>
tl.integrate(remainingTparams, remainingTparam.paramInfo).bounds),
tl => otherTycon.appliedTo(
otherArgs.take(d) ++ tl.paramRefs))
adaptedBounds,
tl => otherTycon.appliedTo(bodyArgs(tl)))
else
otherTycon
(assumedTrue(tycon) || directionalIsSubType(tycon, adaptedTycon)) &&
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3600,10 +3600,14 @@ object Types {

/** The type `[tparams := paramRefs] tp`, where `tparams` can be
* either a list of type parameter symbols or a list of lambda parameters
*
* @pre If `tparams` is a list of lambda parameters, then it must be the
* full, in-order list of type parameters of some type constructor, as
* can be obtained using `TypeApplications#typeParams`.
*/
def integrate(tparams: List[ParamInfo], tp: Type)(using Context): Type =
(tparams: @unchecked) match {
case LambdaParam(lam, _) :: _ => tp.subst(lam, this)
case LambdaParam(lam, _) :: _ => tp.subst(lam, this) // This is where the precondition is necessary.
case params: List[Symbol @unchecked] => tp.subst(params, paramRefs)
}

Expand Down
11 changes: 11 additions & 0 deletions tests/pos/i15983a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class OtherC[A, B, C <: B]

trait crash {
type OtherT[A, B, C <: B]

def indexK[F[_]]: F[Any] = ???

def res: OtherT[Any, Any, Any] = indexK

def res2: OtherC[Any, Any, Any] = indexK
}
11 changes: 11 additions & 0 deletions tests/pos/i15983b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class OtherC[A, B, C <: B, D <: C]

trait crash {
type OtherT[A, B, C <: B, D <: C]

def indexK[F[X, Y <: X]]: F[Any, Any] = ???

def res: OtherT[Any, Any, Any, Any] = indexK

def res2: OtherC[Any, Any, Any, Any] = indexK
}

0 comments on commit 48c7964

Please sign in to comment.