From c1f35aa593278bdd3349af732593f1721b5e06b8 Mon Sep 17 00:00:00 2001 From: odersky Date: Sun, 10 Jul 2022 14:28:50 +0200 Subject: [PATCH 1/5] Instantiate more type variables to hard unions Fixes #14770 --- .../tools/dotc/core/ConstraintHandling.scala | 39 ++++++++--- .../dotty/tools/dotc/core/TypeComparer.scala | 67 +++++++++++++------ .../src/dotty/tools/dotc/core/TypeOps.scala | 4 +- .../dotty/tools/dotc/core/TyperState.scala | 19 +++++- .../src/dotty/tools/dotc/core/Types.scala | 6 +- .../src/dotty/tools/dotc/typer/Namer.scala | 2 +- .../dotty/tools/dotc/typer/Synthesizer.scala | 2 +- .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/pos/i14770.scala | 25 +++++++ 9 files changed, 128 insertions(+), 38 deletions(-) create mode 100644 tests/pos/i14770.scala diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 747500465c0a..05333bfb778c 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -13,6 +13,7 @@ import typer.ProtoTypes.{newTypeVar, representedParamRef} import UnificationDirection.* import NameKinds.AvoidNameKind import util.SimpleIdentitySet +import NullOpsDecorator.stripNull /** Methods for adding constraints and solving them. * @@ -627,8 +628,11 @@ trait ConstraintHandling { * 1. If `inst` is a singleton type, or a union containing some singleton types, * widen (all) the singleton type(s), provided the result is a subtype of `bound`. * (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint) - * 2. If `inst` is a union type, approximate the union type from above by an intersection - * of all common base types, provided the result is a subtype of `bound`. + * 2a. If `inst` is a union type and `widenUnions` is true, approximate the union type + * from above by an intersection of all common base types, provided the result + * is a subtype of `bound`. + * 2b. If `inst` is a union type and `widenUnions` is false, turn it into a hard + * union type (except for unions | Null, which are kept in the state they were). * 3. Widen some irreducible applications of higher-kinded types to wildcard arguments * (see @widenIrreducible). * 4. Drop transparent traits from intersections (see @dropTransparentTraits). @@ -641,10 +645,12 @@ trait ConstraintHandling { * At this point we also drop the @Repeated annotation to avoid inferring type arguments with it, * as those could leak the annotation to users (see run/inferred-repeated-result). */ - def widenInferred(inst: Type, bound: Type)(using Context): Type = + def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type = def widenOr(tp: Type) = - val tpw = tp.widenUnion - if (tpw ne tp) && (tpw <:< bound) then tpw else tp + if widenUnions then + val tpw = tp.widenUnion + if (tpw ne tp) && (tpw <:< bound) then tpw else tp + else tp.hardenUnions def widenSingle(tp: Type) = val tpw = tp.widenSingletons @@ -664,6 +670,23 @@ trait ConstraintHandling { wideInst.dropRepeatedAnnot end widenInferred + /** Convert all toplevel union types in `tp` to hard unions */ + extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match + case tp: AndType => + tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions) + case tp: RefinedType => + tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo) + case tp: RecType => + tp.rebind(tp.parent.hardenUnions) + case tp: HKTypeLambda => + tp.derivedLambdaType(resType = tp.resType.hardenUnions) + case tp: OrType => + val tp1 = tp.stripNull + if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType) + else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false) + case _ => + tp + /** The instance type of `param` in the current constraint (which contains `param`). * If `fromBelow` is true, the instance type is the lub of the parameter's * lower bounds; otherwise it is the glb of its upper bounds. However, @@ -672,10 +695,10 @@ trait ConstraintHandling { * The instance type is not allowed to contain references to types nested deeper * than `maxLevel`. */ - def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = { + def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int)(using Context): Type = { val approx = approximation(param, fromBelow, maxLevel).simplified if fromBelow then - val widened = widenInferred(approx, param) + val widened = widenInferred(approx, param, widenUnions) // Widening can add extra constraints, in particular the widened type might // be a type variable which is now instantiated to `param`, and therefore // cannot be used as an instantiation of `param` without creating a loop. @@ -683,7 +706,7 @@ trait ConstraintHandling { // (we do not check for non-toplevel occurences: those should never occur // since `addOneBound` disallows recursive lower bounds). if constraint.occursAtToplevel(param, widened) then - instanceType(param, fromBelow, maxLevel) + instanceType(param, fromBelow, widenUnions, maxLevel) else widened else diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index b5fe41095516..5c8c7633945b 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -487,7 +487,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // If LHS is a hard union, constrain any type variables of the RHS with it as lower bound // before splitting the LHS into its constituents. That way, the RHS variables are - // constraint by the hard union and can be instantiated to it. If we just split and add + // constrained by the hard union and can be instantiated to it. If we just split and add // the two parts of the LHS separately to the constraint, the lower bound would become // a soft union. def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match @@ -495,23 +495,46 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22) case _ => true - widenOK - || joinOK - || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2) - || containsAnd(tp1) - && !joined - && { - joined = true - try inFrozenGadt(recur(tp1.join, tp2)) - finally joined = false - } - // An & on the left side loses information. We compensate by also trying the join. - // This is less ad-hoc than it looks since we produce joins in type inference, - // and then need to check that they are indeed supertypes of the original types - // under -Ycheck. Test case is i7965.scala. - // On the other hand, we could get a combinatorial explosion by applying such joins - // recursively, so we do it only once. See i14870.scala as a test case, which would - // loop for a very long time without the recursion brake. + /** Mark toplevel type vars in `tp2` as hard in the current typerState */ + def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match + case tvar: TypeVar if constraint.contains(tvar.origin) => + state.hardVars += tvar + case tp2: TypeParamRef if constraint.contains(tp2) => + hardenTypeVars(constraint.typeVarOfParam(tp2)) + case tp2: AndOrType => + hardenTypeVars(tp2.tp1) + hardenTypeVars(tp2.tp2) + case _ => + + val res = widenOK + || joinOK + || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2) + || containsAnd(tp1) + && !joined + && { + joined = true + try inFrozenGadt(recur(tp1.join, tp2)) + finally joined = false + } + // An & on the left side loses information. We compensate by also trying the join. + // This is less ad-hoc than it looks since we produce joins in type inference, + // and then need to check that they are indeed supertypes of the original types + // under -Ycheck. Test case is i7965.scala. + // On the other hand, we could get a combinatorial explosion by applying such joins + // recursively, so we do it only once. See i14870.scala as a test case, which would + // loop for a very long time without the recursion brake. + + if res && !tp1.isSoft then + // We use a heuristic here where every toplevel type variable on the right hand side + // is marked so that it converts all soft unions in its lower bound to hard unions + // before it is instantiated. The reason is that the union might have come from + // (decomposed and reconstituted) `tp1`. But of course there might be false positives + // where we also treat unions that come from elsewhere as hard unions. Or the constraint + // that created the union is ultimately thrown away, but the type variable will + // stay marked. So it is a coarse measure to take. But it works in the obvious cases. + hardenTypeVars(tp2) + + res case CapturingType(parent1, refs1) => if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK && sameBoxed(tp1, tp2, refs1) @@ -2960,8 +2983,8 @@ object TypeComparer { def subtypeCheckInProgress(using Context): Boolean = comparing(_.subtypeCheckInProgress) - def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = - comparing(_.instanceType(param, fromBelow, maxLevel)) + def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = + comparing(_.instanceType(param, fromBelow, widenUnions, maxLevel)) def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = comparing(_.approximation(param, fromBelow, maxLevel)) @@ -2981,8 +3004,8 @@ object TypeComparer { def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean = comparing(_.addToConstraint(tl, tvars)) - def widenInferred(inst: Type, bound: Type)(using Context): Type = - comparing(_.widenInferred(inst, bound)) + def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type = + comparing(_.widenInferred(inst, bound, widenUnions)) def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type = comparing(_.dropTransparentTraits(tp, bound)) diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index c087aac83cb8..35afab770259 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -537,7 +537,9 @@ object TypeOps: override def apply(tp: Type): Type = tp match case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) => val lo = TypeComparer.instanceType( - tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx) + tp.origin, + fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound, + widenUnions = tp.widenUnions)(using mapCtx) val lo1 = apply(lo) if (lo1 ne lo) lo1 else tp case _ => diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index 81b60c608e28..94f85d5c1dd7 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -25,19 +25,20 @@ object TyperState { type LevelMap = SimpleIdentityMap[TypeVar, Integer] - opaque type Snapshot = (Constraint, TypeVars, LevelMap) + opaque type Snapshot = (Constraint, TypeVars, TypeVars, LevelMap) extension (ts: TyperState) def snapshot()(using Context): Snapshot = - (ts.constraint, ts.ownedVars, ts.upLevels) + (ts.constraint, ts.ownedVars, ts.hardVars, ts.upLevels) def resetTo(state: Snapshot)(using Context): Unit = - val (constraint, ownedVars, upLevels) = state + val (constraint, ownedVars, hardVars, upLevels) = state for tv <- ownedVars do if !ts.ownedVars.contains(tv) then // tv has been instantiated tv.resetInst(ts) ts.constraint = constraint ts.ownedVars = ownedVars + ts.hardVars = hardVars ts.upLevels = upLevels } @@ -91,6 +92,14 @@ class TyperState() { def ownedVars: TypeVars = myOwnedVars def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs + /** The set of type variables `tv` such that, if `tv` is instantiated to + * its lower bound, top-level soft unions in the instance type are converted + * to hard unions instead of being widened in `widenOr`. + */ + private var myHardVars: TypeVars = _ + def hardVars: TypeVars = myHardVars + def hardVars_=(tvs: TypeVars): Unit = myHardVars = tvs + private var upLevels: LevelMap = _ /** Initializes all fields except reporter, isCommittable, which need to be @@ -103,6 +112,7 @@ class TyperState() { this.myConstraint = constraint this.previousConstraint = constraint this.myOwnedVars = SimpleIdentitySet.empty + this.myHardVars = SimpleIdentitySet.empty this.upLevels = SimpleIdentityMap.empty this.isCommitted = false this @@ -114,6 +124,7 @@ class TyperState() { val ts = TyperState().init(this, this.constraint) .setReporter(reporter) .setCommittable(committable) + ts.hardVars = this.hardVars ts.upLevels = upLevels ts @@ -180,6 +191,7 @@ class TyperState() { constr.println(i"committing $this to $targetState, fromConstr = $constraint, toConstr = ${targetState.constraint}") if targetState.constraint eq previousConstraint then targetState.constraint = constraint + targetState.hardVars = hardVars if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar) else targetState.mergeConstraintWith(this) @@ -238,6 +250,7 @@ class TyperState() { val otherLos = other.lower(p) val otherHis = other.upper(p) val otherEntry = other.entry(p) + if that.hardVars.contains(tv) then this.myHardVars += tv ( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) && ( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) && ((otherEntry eq constraint.entry(p)) || otherEntry.match diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 935d22d14de2..b2a9e876b8a3 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4714,12 +4714,16 @@ object Types { * is also a singleton type. */ def instantiate(fromBelow: Boolean)(using Context): Type = - val tp = TypeComparer.instanceType(origin, fromBelow, nestingLevel) + val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel) if myInst.exists then // The line above might have triggered instantiation of the current type variable myInst else instantiateWith(tp) + /** Widen unions when instantiating this variable in the current context? */ + def widenUnions(using Context): Boolean = + !ctx.typerState.hardVars.contains(this) + /** For uninstantiated type variables: the entry in the constraint (either bounds or * provisional instance value) */ diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index e6426cc54cd5..ad8d0e50d348 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1888,7 +1888,7 @@ class Namer { typer: Typer => TypeOps.simplify(tp.widenTermRefExpr, if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match case ctp: ConstantType if sym.isInlineVal => ctp - case tp => TypeComparer.widenInferred(tp, pt) + case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true) // Replace aliases to Unit by Unit itself. If we leave the alias in // it would be erased to BoxedUnit. diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 2e49e0c8bb58..e3f5382ecad7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -514,7 +514,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val tparams = poly.paramRefs val variances = childClass.typeParams.map(_.paramVarianceSign) val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) => - TypeComparer.instanceType(tparam, fromBelow = variance < 0) + TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true) ) val instanceType = resType.substParams(poly, instanceTypes) // this is broken in tests/run/i13332intersection.scala, diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 36e600faefc6..d802ef0df973 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2847,7 +2847,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if (ctx.mode.is(Mode.Pattern)) app1 else { val elemTpes = elems.lazyZip(pts).map((elem, pt) => - TypeComparer.widenInferred(elem.tpe, pt)) + TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true)) val resTpe = TypeOps.nestedPairs(elemTpes) app1.cast(resTpe) } diff --git a/tests/pos/i14770.scala b/tests/pos/i14770.scala new file mode 100644 index 000000000000..182ccba21fdf --- /dev/null +++ b/tests/pos/i14770.scala @@ -0,0 +1,25 @@ +type UndefOr[A] = A | Unit + +extension [A](maybe: UndefOr[A]) + def foreach(f: A => Unit): Unit = + maybe match + case () => () + case a: A => f(a) + +trait Foo +trait Bar + +object Baz: + var booBap: Foo | Bar = _ + +def z: UndefOr[Foo | Bar] = ??? + +@main +def main = + z.foreach(x => Baz.booBap = x) + +def test[A](v: A | Unit): A | Unit = v +val x1 = test(5: Int | Unit) +val x2 = test(5: String | Int | Unit) +val _: Int | Unit = x1 +val _: String | Int | Unit = x2 From f1772668d0e3cbf4b028193687e6491a973a2012 Mon Sep 17 00:00:00 2001 From: odersky Date: Sun, 21 Aug 2022 11:14:33 +0200 Subject: [PATCH 2/5] Refactor hardenedVars handling --- .../src/dotty/tools/dotc/core/TypeComparer.scala | 2 +- .../src/dotty/tools/dotc/core/TyperState.scala | 14 +++++++++----- compiler/src/dotty/tools/dotc/core/Types.scala | 3 +-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 5c8c7633945b..3756ff2fba44 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -498,7 +498,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling /** Mark toplevel type vars in `tp2` as hard in the current typerState */ def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match case tvar: TypeVar if constraint.contains(tvar.origin) => - state.hardVars += tvar + state.hardenTypeVar(tvar) case tp2: TypeParamRef if constraint.contains(tp2) => hardenTypeVars(constraint.typeVarOfParam(tp2)) case tp2: AndOrType => diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index 94f85d5c1dd7..2145919f124b 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -96,9 +96,7 @@ class TyperState() { * its lower bound, top-level soft unions in the instance type are converted * to hard unions instead of being widened in `widenOr`. */ - private var myHardVars: TypeVars = _ - def hardVars: TypeVars = myHardVars - def hardVars_=(tvs: TypeVars): Unit = myHardVars = tvs + private var hardVars: TypeVars = _ private var upLevels: LevelMap = _ @@ -112,7 +110,7 @@ class TyperState() { this.myConstraint = constraint this.previousConstraint = constraint this.myOwnedVars = SimpleIdentitySet.empty - this.myHardVars = SimpleIdentitySet.empty + this.hardVars = SimpleIdentitySet.empty this.upLevels = SimpleIdentityMap.empty this.isCommitted = false this @@ -131,6 +129,12 @@ class TyperState() { /** The uninstantiated variables */ def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars + /** Register type variable `tv` as hard. */ + def hardenTypeVar(tv: TypeVar): Unit = hardVars += tv + + /** Is type variable `tv` registered as hard? */ + def isHard(tv: TypeVar): Boolean = hardVars.contains(tv) + /** The nestingLevel of `tv` in this typer state */ def nestingLevel(tv: TypeVar): Int = val own = upLevels(tv) @@ -195,6 +199,7 @@ class TyperState() { if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar) else targetState.mergeConstraintWith(this) + for tv <- hardVars do targetState.hardVars += tv upLevels.foreachBinding { (tv, level) => if level < targetState.nestingLevel(tv) then @@ -250,7 +255,6 @@ class TyperState() { val otherLos = other.lower(p) val otherHis = other.upper(p) val otherEntry = other.entry(p) - if that.hardVars.contains(tv) then this.myHardVars += tv ( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) && ( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) && ((otherEntry eq constraint.entry(p)) || otherEntry.match diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index b2a9e876b8a3..037b1cc2556e 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4721,8 +4721,7 @@ object Types { instantiateWith(tp) /** Widen unions when instantiating this variable in the current context? */ - def widenUnions(using Context): Boolean = - !ctx.typerState.hardVars.contains(this) + def widenUnions(using Context): Boolean = !ctx.typerState.isHard(this) /** For uninstantiated type variables: the entry in the constraint (either bounds or * provisional instance value) From d15558deb47fef821d63178a25e7d0e003d33db2 Mon Sep 17 00:00:00 2001 From: odersky Date: Tue, 30 Aug 2022 13:28:08 +0200 Subject: [PATCH 3/5] Try alternative to track hard typevars in constraint Try alternative to track hard typevars in constraint instead of typer state. This avoids the problem that a failing subtype comparison can also mark type variables as hard. --- .../dotty/tools/dotc/core/Constraint.scala | 6 ++++ .../tools/dotc/core/GadtConstraint.scala | 4 +-- .../tools/dotc/core/OrderingConstraint.scala | 35 ++++++++++++------- .../dotty/tools/dotc/core/TypeComparer.scala | 12 +++---- .../dotty/tools/dotc/core/TyperState.scala | 3 +- .../src/dotty/tools/dotc/core/Types.scala | 4 ++- 6 files changed, 41 insertions(+), 23 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index c35c93886cd8..07b6e71cdcc9 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -126,6 +126,12 @@ abstract class Constraint extends Showable { */ def subst(from: TypeLambda, to: TypeLambda)(using Context): This + /** Is `tv` marked as hard in the constraint? */ + def isHard(tv: TypeVar): Boolean + + /** The same as this constraint, but with `tv` marked as hard. */ + def withHard(tv: TypeVar)(using Context): This + /** Gives for each instantiated type var that does not yet have its `inst` field * set, the instance value stored in the constraint. Storing instances in constraints * is done only in a temporary way for contexts that may be retracted diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index a8b5eee4902d..d8e1c5276ab6 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -6,7 +6,7 @@ import Decorators._ import Contexts._ import Types._ import Symbols._ -import util.SimpleIdentityMap +import util.{SimpleIdentitySet, SimpleIdentityMap} import collection.mutable import printing._ @@ -68,7 +68,7 @@ final class ProperGadtConstraint private( import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} def this() = this( - myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty), + myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty), mapping = SimpleIdentityMap.empty, reverseMapping = SimpleIdentityMap.empty, wasConstrained = false diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 5f267a1c242a..1341fac7d735 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -3,7 +3,7 @@ package dotc package core import Types._, Contexts._, Symbols._, Decorators._, TypeApplications._ -import util.SimpleIdentityMap +import util.{SimpleIdentitySet, SimpleIdentityMap} import collection.mutable import printing.Printer import printing.Texts._ @@ -24,12 +24,14 @@ object OrderingConstraint { /** The type of `OrderingConstraint#lowerMap`, `OrderingConstraint#upperMap` */ type ParamOrdering = ArrayValuedMap[List[TypeParamRef]] - /** A new constraint with given maps */ - private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint = + /** A new constraint with given maps and given set of hard typevars */ + private def newConstraint( + boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering, + hardVars: TypeVars)(using Context) : OrderingConstraint = if boundsMap.isEmpty && lowerMap.isEmpty && upperMap.isEmpty then empty else - val result = new OrderingConstraint(boundsMap, lowerMap, upperMap) + val result = new OrderingConstraint(boundsMap, lowerMap, upperMap, hardVars) if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size) result @@ -91,7 +93,7 @@ object OrderingConstraint { def entries(c: OrderingConstraint, poly: TypeLambda): Array[Type] | Null = c.boundsMap(poly) def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[Type])(using Context): OrderingConstraint = - newConstraint(c.boundsMap.updated(poly, entries), c.lowerMap, c.upperMap) + newConstraint(c.boundsMap.updated(poly, entries), c.lowerMap, c.upperMap, c.hardVars) def initial = NoType } @@ -99,7 +101,7 @@ object OrderingConstraint { def entries(c: OrderingConstraint, poly: TypeLambda): Array[List[TypeParamRef]] | Null = c.lowerMap(poly) def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[List[TypeParamRef]])(using Context): OrderingConstraint = - newConstraint(c.boundsMap, c.lowerMap.updated(poly, entries), c.upperMap) + newConstraint(c.boundsMap, c.lowerMap.updated(poly, entries), c.upperMap, c.hardVars) def initial = Nil } @@ -107,12 +109,12 @@ object OrderingConstraint { def entries(c: OrderingConstraint, poly: TypeLambda): Array[List[TypeParamRef]] | Null = c.upperMap(poly) def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[List[TypeParamRef]])(using Context): OrderingConstraint = - newConstraint(c.boundsMap, c.lowerMap, c.upperMap.updated(poly, entries)) + newConstraint(c.boundsMap, c.lowerMap, c.upperMap.updated(poly, entries), c.hardVars) def initial = Nil } @sharable - val empty = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty) + val empty = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty) } import OrderingConstraint._ @@ -134,10 +136,13 @@ import OrderingConstraint._ * @param upperMap a map from TypeLambdas to arrays. Each array entry corresponds * to a parameter P of the type lambda; it contains all constrained parameters * Q that are known to be greater than P, i.e. P <: Q. + * @param hardVars a set of type variables that are marked as hard and therefore will not + * undergo a `widenUnion` when instantiated to their lower bound. */ class OrderingConstraint(private val boundsMap: ParamBounds, private val lowerMap : ParamOrdering, - private val upperMap : ParamOrdering) extends Constraint { + private val upperMap : ParamOrdering, + private val hardVars : TypeVars) extends Constraint { import UnificationDirection.* @@ -277,7 +282,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds, val entries1 = new Array[Type](nparams * 2) poly.paramInfos.copyToArray(entries1, 0) tvars.copyToArray(entries1, nparams) - newConstraint(boundsMap.updated(poly, entries1), lowerMap, upperMap).init(poly) + newConstraint(boundsMap.updated(poly, entries1), lowerMap, upperMap, hardVars).init(poly) } /** Split dependent parameters off the bounds for parameters in `poly`. @@ -478,7 +483,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds, } po.remove(pt).mapValuesNow(removeFromBoundss) } - newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap)) + val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param)) + newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1) .checkNonCyclic() } @@ -505,7 +511,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def swapKey[T](m: ArrayValuedMap[T]) = val info = m(from) if info == null then m else m.remove(from).updated(to, info) - var current = newConstraint(swapKey(boundsMap), swapKey(lowerMap), swapKey(upperMap)) + var current = newConstraint(swapKey(boundsMap), swapKey(lowerMap), swapKey(upperMap), hardVars) def subst[T <: Type](x: T): T = x.subst(from, to).asInstanceOf[T] current.foreachParam {(p, i) => current = boundsLens.map(this, current, p, i, subst) @@ -515,6 +521,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds, constr.println(i"renamed $this to $current") current.checkNonCyclic() + def isHard(tv: TypeVar) = hardVars.contains(tv) + + def withHard(tv: TypeVar)(using Context) = + newConstraint(boundsMap, lowerMap, upperMap, hardVars + tv) + def instType(tvar: TypeVar): Type = entry(tvar.origin) match case _: TypeBounds => NoType case tp: TypeParamRef => typeVarOfParam(tp).orElse(tp) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 3756ff2fba44..a146b2dbd677 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -495,10 +495,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22) case _ => true - /** Mark toplevel type vars in `tp2` as hard in the current typerState */ + /** Mark toplevel type vars in `tp2` as hard in the current constraint */ def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match case tvar: TypeVar if constraint.contains(tvar.origin) => state.hardenTypeVar(tvar) + constraint = constraint.withHard(tvar) case tp2: TypeParamRef if constraint.contains(tp2) => hardenTypeVars(constraint.typeVarOfParam(tp2)) case tp2: AndOrType => @@ -524,14 +525,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // recursively, so we do it only once. See i14870.scala as a test case, which would // loop for a very long time without the recursion brake. - if res && !tp1.isSoft then + if res && !tp1.isSoft && state.isCommittable then // We use a heuristic here where every toplevel type variable on the right hand side // is marked so that it converts all soft unions in its lower bound to hard unions - // before it is instantiated. The reason is that the union might have come from - // (decomposed and reconstituted) `tp1`. But of course there might be false positives - // where we also treat unions that come from elsewhere as hard unions. Or the constraint - // that created the union is ultimately thrown away, but the type variable will - // stay marked. So it is a coarse measure to take. But it works in the obvious cases. + // before it is instantiated. The reason is that the variable's instance type will + // be a supertype of (decomposed and reconstituted) `tp1`. hardenTypeVars(tp2) res diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index 2145919f124b..9d7021f5c644 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -246,7 +246,8 @@ class TyperState() { constraint.contains(tl) || other.isRemovable(tl) || { val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv } if this.isCommittable then - tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar)) + tvars.foreach(tvar => + if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar)) typeComparer.addToConstraint(tl, tvars) }) && // Integrate the additional constraints on type variables from `other` diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 037b1cc2556e..c7c8b594b165 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4721,7 +4721,9 @@ object Types { instantiateWith(tp) /** Widen unions when instantiating this variable in the current context? */ - def widenUnions(using Context): Boolean = !ctx.typerState.isHard(this) + def widenUnions(using Context): Boolean = + if true then !ctx.typerState.constraint.isHard(this) + else !ctx.typerState.isHard(this) /** For uninstantiated type variables: the entry in the constraint (either bounds or * provisional instance value) From d7521bf5e5ede8495b82d636dd24904f77fd38f2 Mon Sep 17 00:00:00 2001 From: odersky Date: Tue, 30 Aug 2022 15:41:53 +0200 Subject: [PATCH 4/5] Drop old TyperState based scheme and drop constrainRHSVars It seems constrainRHSVars is no longer needed with the new way to keep track of hard type variables. --- .../dotty/tools/dotc/core/TypeComparer.scala | 18 +++---------- .../dotty/tools/dotc/core/TyperState.scala | 27 +++++-------------- .../src/dotty/tools/dotc/core/Types.scala | 4 +-- 3 files changed, 10 insertions(+), 39 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index a146b2dbd677..78fbea352bf3 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -485,20 +485,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling false } - // If LHS is a hard union, constrain any type variables of the RHS with it as lower bound - // before splitting the LHS into its constituents. That way, the RHS variables are - // constrained by the hard union and can be instantiated to it. If we just split and add - // the two parts of the LHS separately to the constraint, the lower bound would become - // a soft union. - def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match - case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2) - case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22) - case _ => true - /** Mark toplevel type vars in `tp2` as hard in the current constraint */ def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match case tvar: TypeVar if constraint.contains(tvar.origin) => - state.hardenTypeVar(tvar) constraint = constraint.withHard(tvar) case tp2: TypeParamRef if constraint.contains(tp2) => hardenTypeVars(constraint.typeVarOfParam(tp2)) @@ -507,9 +496,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling hardenTypeVars(tp2.tp2) case _ => - val res = widenOK - || joinOK - || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2) + val res = widenOK || joinOK + || recur(tp11, tp2) && recur(tp12, tp2) || containsAnd(tp1) && !joined && { @@ -525,7 +513,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // recursively, so we do it only once. See i14870.scala as a test case, which would // loop for a very long time without the recursion brake. - if res && !tp1.isSoft && state.isCommittable then + if res && !tp1.isSoft && state.isCommittable then // We use a heuristic here where every toplevel type variable on the right hand side // is marked so that it converts all soft unions in its lower bound to hard unions // before it is instantiated. The reason is that the variable's instance type will diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index 9d7021f5c644..8a07981f1be7 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -25,20 +25,19 @@ object TyperState { type LevelMap = SimpleIdentityMap[TypeVar, Integer] - opaque type Snapshot = (Constraint, TypeVars, TypeVars, LevelMap) + opaque type Snapshot = (Constraint, TypeVars, LevelMap) extension (ts: TyperState) def snapshot()(using Context): Snapshot = - (ts.constraint, ts.ownedVars, ts.hardVars, ts.upLevels) + (ts.constraint, ts.ownedVars, ts.upLevels) def resetTo(state: Snapshot)(using Context): Unit = - val (constraint, ownedVars, hardVars, upLevels) = state + val (constraint, ownedVars, upLevels) = state for tv <- ownedVars do if !ts.ownedVars.contains(tv) then // tv has been instantiated tv.resetInst(ts) ts.constraint = constraint ts.ownedVars = ownedVars - ts.hardVars = hardVars ts.upLevels = upLevels } @@ -92,12 +91,6 @@ class TyperState() { def ownedVars: TypeVars = myOwnedVars def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs - /** The set of type variables `tv` such that, if `tv` is instantiated to - * its lower bound, top-level soft unions in the instance type are converted - * to hard unions instead of being widened in `widenOr`. - */ - private var hardVars: TypeVars = _ - private var upLevels: LevelMap = _ /** Initializes all fields except reporter, isCommittable, which need to be @@ -110,7 +103,6 @@ class TyperState() { this.myConstraint = constraint this.previousConstraint = constraint this.myOwnedVars = SimpleIdentitySet.empty - this.hardVars = SimpleIdentitySet.empty this.upLevels = SimpleIdentityMap.empty this.isCommitted = false this @@ -122,19 +114,12 @@ class TyperState() { val ts = TyperState().init(this, this.constraint) .setReporter(reporter) .setCommittable(committable) - ts.hardVars = this.hardVars ts.upLevels = upLevels ts /** The uninstantiated variables */ def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars - /** Register type variable `tv` as hard. */ - def hardenTypeVar(tv: TypeVar): Unit = hardVars += tv - - /** Is type variable `tv` registered as hard? */ - def isHard(tv: TypeVar): Boolean = hardVars.contains(tv) - /** The nestingLevel of `tv` in this typer state */ def nestingLevel(tv: TypeVar): Int = val own = upLevels(tv) @@ -195,11 +180,9 @@ class TyperState() { constr.println(i"committing $this to $targetState, fromConstr = $constraint, toConstr = ${targetState.constraint}") if targetState.constraint eq previousConstraint then targetState.constraint = constraint - targetState.hardVars = hardVars if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar) else targetState.mergeConstraintWith(this) - for tv <- hardVars do targetState.hardVars += tv upLevels.foreachBinding { (tv, level) => if level < targetState.nestingLevel(tv) then @@ -247,7 +230,9 @@ class TyperState() { val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv } if this.isCommittable then tvars.foreach(tvar => - if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar)) + if !tvar.inst.exists then + if !isOwnedAnywhere(this, tvar) then includeVar(tvar) + if constraint.isHard(tvar) then constraint = constraint.withHard(tvar)) typeComparer.addToConstraint(tl, tvars) }) && // Integrate the additional constraints on type variables from `other` diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index c7c8b594b165..0df3fa368d5a 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4721,9 +4721,7 @@ object Types { instantiateWith(tp) /** Widen unions when instantiating this variable in the current context? */ - def widenUnions(using Context): Boolean = - if true then !ctx.typerState.constraint.isHard(this) - else !ctx.typerState.isHard(this) + def widenUnions(using Context): Boolean = !ctx.typerState.constraint.isHard(this) /** For uninstantiated type variables: the entry in the constraint (either bounds or * provisional instance value) From dfcfb6b9ae9f334d88a9fb1b346860b1104799d0 Mon Sep 17 00:00:00 2001 From: odersky Date: Wed, 31 Aug 2022 18:17:03 +0200 Subject: [PATCH 5/5] Address review comment --- compiler/src/dotty/tools/dotc/core/TyperState.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index 8a07981f1be7..d2df2a2aebef 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -230,13 +230,13 @@ class TyperState() { val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv } if this.isCommittable then tvars.foreach(tvar => - if !tvar.inst.exists then - if !isOwnedAnywhere(this, tvar) then includeVar(tvar) - if constraint.isHard(tvar) then constraint = constraint.withHard(tvar)) + if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar)) typeComparer.addToConstraint(tl, tvars) }) && // Integrate the additional constraints on type variables from `other` + // and merge hardness markers constraint.uninstVars.forall(tv => + if other.isHard(tv) then constraint = constraint.withHard(tv) val p = tv.origin val otherLos = other.lower(p) val otherHis = other.upper(p)