From 48c79644ebc7363d631263d50183151653e3268f Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Wed, 12 Oct 2022 22:02:22 +0200 Subject: [PATCH] Fix higher-order unification incorrectly substituting tparams When creating a fresh type lambda for the purpose of higher-order type inference, we incorrectly substituted references to type parameters before this commit. We want to construct: bodyArgs := otherArgs.take(d), T_0, ..., T_k-1 [T_0, ..., T_k-1] =>> otherTycon[bodyArgs] For this type to be valid, we need the bounds of `T_i` to be the bounds of the (d+i) type parameter of `otherTycon` after substituting references to each type parameter of `otherTycon` by the corresponding argument in `bodyArgs`. The previous implementation incorrectly substituted only the last `k` type parameters, this was not enough for correctness. It could also lead to a crash because it called `integrate` which implicitly assumes it is passed a full list of type parameters (this is now documented). Fixes #15983. --- .../tools/dotc/core/TypeApplications.scala | 6 +++++ .../dotty/tools/dotc/core/TypeComparer.scala | 25 +++++++++++++------ .../src/dotty/tools/dotc/core/Types.scala | 6 ++++- tests/pos/i15983a.scala | 11 ++++++++ tests/pos/i15983b.scala | 11 ++++++++ 5 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 tests/pos/i15983a.scala create mode 100644 tests/pos/i15983b.scala diff --git a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala index 58f9732edf1f..81f822811456 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala @@ -204,6 +204,12 @@ class TypeApplications(val self: Type) extends AnyVal { } } + /** Substitute in `self` the type parameters of `tycon` by some other types. */ + final def substTypeParams(tycon: Type, to: List[Type])(using Context): Type = + (tycon.typeParams: @unchecked) match + case LambdaParam(lam, _) :: _ => self.substParams(lam, to) + case params: List[Symbol @unchecked] => self.subst(params, to) + /** If `self` is a higher-kinded type, its type parameters, otherwise Nil */ final def hkTypeParams(using Context): List[TypeParamInfo] = if (isLambdaSub) typeParams else Nil diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 9365bafd8282..73ecb620cfea 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1066,12 +1066,16 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * * - k := args.length * - d := otherArgs.length - k + * - T_0, ..., T_k-1 fresh type parameters + * - bodyArgs := otherArgs.take(d), T_0, ..., T_k-1 * - * `adaptedTycon` will be: + * Then, * - * [T_0, ..., T_k-1] =>> otherTycon[otherArgs(0), ..., otherArgs(d-1), T_0, ..., T_k-1] + * adaptedTycon := [T_0, ..., T_k-1] =>> otherTycon[bodyArgs] * - * where `T_n` has the same bounds as `otherTycon.typeParams(d+n)` + * where the bounds of `T_i` are set based on the bounds of `otherTycon.typeParams(d+i)` + * after substituting type parameter references by the corresponding argument + * in `bodyArgs` (see `adaptedBounds` in the implementation). * * Historical note: this strategy is known in Scala as "partial unification" * (even though the type constructor variable isn't actually unified but only @@ -1096,11 +1100,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling variancesConform(remainingTparams, tparams) && { val adaptedTycon = if d > 0 then + val initialArgs = otherArgs.take(d) + /** The arguments passed to `otherTycon` in the body of `tl` */ + def bodyArgs(tl: HKTypeLambda) = initialArgs ++ tl.paramRefs + /** The bounds of the type parameters of `tl` */ + def adaptedBounds(tl: HKTypeLambda) = + val bodyArgsComputed = bodyArgs(tl) + remainingTparams.map(_.paramInfo) + .mapconserve(_.substTypeParams(otherTycon, bodyArgsComputed).bounds) + HKTypeLambda(remainingTparams.map(_.paramName))( - tl => remainingTparams.map(remainingTparam => - tl.integrate(remainingTparams, remainingTparam.paramInfo).bounds), - tl => otherTycon.appliedTo( - otherArgs.take(d) ++ tl.paramRefs)) + adaptedBounds, + tl => otherTycon.appliedTo(bodyArgs(tl))) else otherTycon (assumedTrue(tycon) || directionalIsSubType(tycon, adaptedTycon)) && diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index f60bba88dde2..605310989384 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -3600,10 +3600,14 @@ object Types { /** The type `[tparams := paramRefs] tp`, where `tparams` can be * either a list of type parameter symbols or a list of lambda parameters + * + * @pre If `tparams` is a list of lambda parameters, then it must be the + * full, in-order list of type parameters of some type constructor, as + * can be obtained using `TypeApplications#typeParams`. */ def integrate(tparams: List[ParamInfo], tp: Type)(using Context): Type = (tparams: @unchecked) match { - case LambdaParam(lam, _) :: _ => tp.subst(lam, this) + case LambdaParam(lam, _) :: _ => tp.subst(lam, this) // This is where the precondition is necessary. case params: List[Symbol @unchecked] => tp.subst(params, paramRefs) } diff --git a/tests/pos/i15983a.scala b/tests/pos/i15983a.scala new file mode 100644 index 000000000000..f49c15158798 --- /dev/null +++ b/tests/pos/i15983a.scala @@ -0,0 +1,11 @@ +class OtherC[A, B, C <: B] + +trait crash { + type OtherT[A, B, C <: B] + + def indexK[F[_]]: F[Any] = ??? + + def res: OtherT[Any, Any, Any] = indexK + + def res2: OtherC[Any, Any, Any] = indexK +} diff --git a/tests/pos/i15983b.scala b/tests/pos/i15983b.scala new file mode 100644 index 000000000000..dd21bc425578 --- /dev/null +++ b/tests/pos/i15983b.scala @@ -0,0 +1,11 @@ +class OtherC[A, B, C <: B, D <: C] + +trait crash { + type OtherT[A, B, C <: B, D <: C] + + def indexK[F[X, Y <: X]]: F[Any, Any] = ??? + + def res: OtherT[Any, Any, Any, Any] = indexK + + def res2: OtherC[Any, Any, Any, Any] = indexK +}