diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000000..724ce9af49f77 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import scala.tools.nsc.interpreter.{ExprTyper, IR} + +trait SparkExprTyper extends ExprTyper { + + import repl._ + import global.{reporter => _, Import => _, _} + import naming.freshInternalVarName + + def doInterpret(code: String): IR.Result = { + // interpret/interpretSynthetic may change the phase, + // which would have unintended effects on types. + val savedPhase = phase + try interpretSynthetic(code) finally phase = savedPhase + } + + override def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. + val line = "def " + name + " = " + code + + doInterpret(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + doInterpret(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + + def asError(): Symbol = { + doInterpret(code) + NoSymbol + } + + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 3ce7cc7c85f74..e69441a475e9a 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + override def createInterpreter(): Unit = { + intp = new SparkILoopInterpreter(settings, out) + } + val initializationCommands: Seq[String] = Seq( """ @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala new file mode 100644 index 0000000000000..0803426403af5 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter._ + +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { + self => + + override lazy val memberHandlers = new { + val intp: self.type = self + } with MemberHandlers { + import intp.global._ + + override def chooseHandler(member: intp.global.Tree): MemberHandler = member match { + case member: Import => new SparkImportHandler(member) + case _ => super.chooseHandler (member) + } + + class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) { + + override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match { + case NoSymbol => intp.typeOfExpression("" + expr) + case sym => sym.tpe + } + + private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s)) + private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx + private def pos(name: Name, s: String): Int = { + var i = name.pos(s.charAt(0), 0) + val sLen = s.length() + if (sLen == 1) return i + while (i + sLen <= name.length) { + var j = 1 + while (s.charAt(j) == name.charAt(i + j)) { + j += 1 + if (j == sLen) return i + } + i = name.pos(s.charAt(0), i + 1) + } + name.length + } + + private def isFlattenedSymbol(sym: Symbol): Boolean = + sym.owner.isPackageClass && + sym.name.containsName(nme.NAME_JOIN_STRING) && + sym.owner.info.member(sym.name.take( + safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol + + private def importableTargetMembers = + importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList + + def isIndividualImport(s: ImportSelector): Boolean = + s.name != nme.WILDCARD && s.rename != nme.WILDCARD + def isWildcardImport(s: ImportSelector): Boolean = + s.name == nme.WILDCARD + + // non-wildcard imports + private def individualSelectors = selectors filter isIndividualImport + + override val importsWildcard: Boolean = selectors exists isWildcardImport + + lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = { + val selectorRenameMap = + individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap + importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _)) + } + + override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1) + override lazy val wildcardSymbols: List[Symbol] = + if (importsWildcard) importableTargetMembers else Nil + + } + + } + + object expressionTyper extends { + val repl: SparkILoopInterpreter.this.type = self + } with SparkExprTyper { } + + override def symbolOfLine(code: String): global.Symbol = + expressionTyper.symbolOfLine(code) + + override def typeOfExpression(expr: String, silent: Boolean): global.Type = + expressionTyper.typeOfExpression(expr, silent) + +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 905b41cdc1594..a5053521f8e31 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -227,4 +227,14 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error: not found: value sc", output) } + test("spark-shell should find imported types in class constructors and extends clause") { + val output = runInterpreter("local", + """ + |import org.apache.spark.Partition + |class P(p: Partition) + |class P(val index: Int) extends Partition + """.stripMargin) + assertDoesNotContain("error: not found: type Partition", output) + } + }