Skip to content

Commit

Permalink
fix scala#14432: check if scala 2 case class is accessible
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Apr 22, 2022
1 parent b636633 commit 8699a84
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 19 deletions.
52 changes: 33 additions & 19 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,18 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
monoMap(mirroredType.resultType)

private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): Tree =

/** for a case class, if it is Scala2x then
* check if its constructor can be accessed
* from the calling scope.
*/
def canAccessCtor(cls: Symbol): Boolean =
!cls.is(Scala2x) || {
val ctor = cls.primaryConstructor
!ctor.isOneOf(Private | Protected) // we will never generate the mirror inside a Scala 2 class
&& (!ctor.privateWithin.exists || ctx.owner.isContainedIn(ctor.privateWithin)) // check scope is compatible
}

mirroredType match
case AndType(tp1, tp2) =>
productMirror(tp1, formal, span).orElse(productMirror(tp2, formal, span))
Expand All @@ -291,25 +303,27 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
modulePath.cast(mirrorType)
else if mirroredType.classSymbol.isGenericProduct then
val cls = mirroredType.classSymbol
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
case _ =>
(mirroredType, nestedPairs)
val elemsLabels = TypeOps.nestedPairs(elemLabels)
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
val mirrorType =
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
val mirrorRef =
if (cls.is(Scala2x)) anonymousMirror(monoType, ExtendsProductMirror, span)
else companionPath(mirroredType, span)
mirrorRef.cast(mirrorType)
if !canAccessCtor(cls) then EmptyTree
else
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
case _ =>
(mirroredType, nestedPairs)
val elemsLabels = TypeOps.nestedPairs(elemLabels)
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
val mirrorType =
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
val mirrorRef =
if (cls.is(Scala2x)) anonymousMirror(monoType, ExtendsProductMirror, span)
else companionPath(mirroredType, span)
mirrorRef.cast(mirrorType)
else EmptyTree
end productMirror

Expand Down
3 changes: 3 additions & 0 deletions sbt-test/scala2-compat/i14432/app1fail/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import deriving.Mirror

val mFoo = summon[Mirror.Of[Foo]] // error: `Foo.<init>(Int)` is not accessible from `<empty>`.
8 changes: 8 additions & 0 deletions sbt-test/scala2-compat/i14432/app1ok/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import deriving.Mirror

package example {
val mFoo = summon[Mirror.Of[Foo]] // ok, we can access Foo's ctor from here.
}

@main def Test: Unit =
assert(example.mFoo.fromProduct(Some(23)) == example.Foo(23))
5 changes: 5 additions & 0 deletions sbt-test/scala2-compat/i14432/app2fail/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package example

import deriving.Mirror

val mFoo = summon[Mirror.Of[Foo]] // error: `Foo.<init>(Int)` is not accessible from any class.
30 changes: 30 additions & 0 deletions sbt-test/scala2-compat/i14432/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
val scala3Version = sys.props("plugin.scalaVersion")
val scala2Version = sys.props("plugin.scala2Version")

lazy val lib1 = project.in(file("lib1"))
.settings(
scalaVersion := scala2Version
)

lazy val lib2 = project.in(file("lib2"))
.settings(
scalaVersion := scala2Version
)

lazy val app1fail = project.in(file("app1fail"))
.dependsOn(lib1)
.settings(
scalaVersion := scala3Version
)

lazy val app1ok = project.in(file("app1ok"))
.dependsOn(lib1)
.settings(
scalaVersion := scala3Version
)

lazy val app2fail = project.in(file("app2fail"))
.dependsOn(lib2)
.settings(
scalaVersion := scala3Version
)
3 changes: 3 additions & 0 deletions sbt-test/scala2-compat/i14432/lib1/Foo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package example

case class Foo private[example] (i: Int)
3 changes: 3 additions & 0 deletions sbt-test/scala2-compat/i14432/lib2/Foo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package example

case class Foo private (i: Int)
5 changes: 5 additions & 0 deletions sbt-test/scala2-compat/i14432/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
> lib1/compile
> lib2/compile
-> app1fail/compile
> app1ok/run
-> app2fail/compile

0 comments on commit 8699a84

Please sign in to comment.