Skip to content

Commit

Permalink
Fix #110661: stop accidentally supported auto-tupling for case classes
Browse files Browse the repository at this point in the history
This commit refactor the PR #5259 so that it will not impact pattern
matching code.

#5259

- tests/pos/automatic-tupling-of-function-parameters.scala

  This test requires dealiasing.

- tests/run/function-arity.scala

  This test requires widen ExprType.
  • Loading branch information
liufengyun committed Jan 15, 2021
1 parent f5e84f4 commit 6a47ca3
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 18 deletions.
28 changes: 15 additions & 13 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,23 @@ object Applications {
}

def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {
def tupleSelectors(n: Int, tp: Type): List[Type] = {
val sel = extractorMemberType(tp, nme.selectorName(n), errorPos)
// extractorMemberType will return NoType if this is the tail of tuple with an unknown tail
// such as `Int *: T` where `T <: Tuple`.
if (sel.exists) sel :: tupleSelectors(n + 1, tp) else Nil
}
def genTupleSelectors(n: Int, tp: Type): List[Type] = tp match {
case tp: AppliedType if !defn.isTupleClass(tp.tycon.typeSymbol) && tp.derivesFrom(defn.PairClass) =>
val List(head, tail) = tp.args
head :: genTupleSelectors(n, tail)
case _ => tupleSelectors(n, tp)
}
genTupleSelectors(0, tp)
val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos)
sels.takeWhile(_.exists).toList
}

def tupleComponentTypes(tp: Type)(using Context): List[Type] =
tp.widenExpr.dealias match
case tp: AppliedType =>
if defn.isTupleClass(tp.tycon.typeSymbol) then
tp.args
else if tp.tycon.derivesFrom(defn.PairClass) then
val List(head, tail) = tp.args
head :: tupleComponentTypes(tail)
else
Nil
case _ =>
Nil

def productArity(tp: Type, errorPos: SrcPos = NoSourcePosition)(using Context): Int =
if (defn.isProductSubType(tp)) productSelectorTypes(tp, errorPos).size else -1

Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import TypeComparer.CompareResult
import util.Spans._
import util.common._
import util.{Property, SimpleIdentityMap, SrcPos}
import Applications.{ExtMethodApply, productSelectorTypes, wrapDefs, defaultArgument}
import Applications.{ExtMethodApply, tupleComponentTypes, wrapDefs, defaultArgument}

import collection.mutable
import annotation.tailrec
Expand Down Expand Up @@ -1254,8 +1254,8 @@ class Typer extends Namer
/** Is `formal` a product type which is elementwise compatible with `params`? */
def ptIsCorrectProduct(formal: Type) =
isFullyDefined(formal, ForceDegree.flipBottom) &&
(defn.isProductSubType(formal) || formal.derivesFrom(defn.PairClass)) &&
productSelectorTypes(formal, tree.srcPos).corresponds(params) {
defn.isProductSubType(formal) &&
tupleComponentTypes(formal).corresponds(params) {
(argType, param) =>
param.tpt.isEmpty || argType.widenExpr <:< typedAheadType(param.tpt).tpe
}
Expand Down
2 changes: 1 addition & 1 deletion scala3doc/src/dotty/renderers/MemberRenderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MemberRenderer(signatureRenderer: SignatureRenderer, buildNode: ContentNod
case Origin.Overrides(defs) =>
def renderDef(d: Overriden): Seq[TagArg] =
Seq(" -> ", signatureRenderer.renderLink(d.name, d.dri))
val headNode = m.inheritedFrom.map(signatureRenderer.renderLink(_, _))
val headNode = m.inheritedFrom.map(form => signatureRenderer.renderLink(form.name, form.dri))
val tailNodes = defs.flatMap(renderDef)
val nodes = headNode.fold(tailNodes.drop(1))(_ +: tailNodes)
tableRow("Definition Classes", div(nodes:_*))
Expand Down
7 changes: 7 additions & 0 deletions tests/neg/i11061.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
case class Foo(a: Int, b: Int)

object Test {
def foo(x: Foo) = List(x).map(_ + _) // error

def main(args: Array[String]): Unit = println(foo(Foo(3, 4)))
}
2 changes: 1 addition & 1 deletion tests/pos/i6199a.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ case class ValueWithEncoder[T](value: T, encoder: Encoder[T])

object Test {
val a: Seq[ValueWithEncoder[_]] = Seq.empty
val b = a.map((value, encoder) => encoder.encode(value))
val b = a.map(ve => ve.encoder.encode(ve.value))
val c: Seq[String] = b
}

0 comments on commit 6a47ca3

Please sign in to comment.