Merge pull request #10 from HyukjinKwon/address-from_csv
Address from csv
HyukjinKwon authored Oct 15, 2018
2 parents 88e3b10 + a32bbcb commit b26e49e
Showing 12 changed files with 50 additions and 25 deletions.
11 changes: 9 additions & 2 deletions R/pkg/R/functions.R
@@ -2223,12 +2223,19 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType")
#' schema <- "city STRING, year INT"
#' head(select(df, from_csv(df$csv, schema)))}
#' @note from_csv since 3.0.0
setMethod("from_csv", signature(x = "Column", schema = "character"),
setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"),
function(x, schema, ...) {
if (class(schema) == "Column") {
jschema <- schema@jc
} else if (is.character(schema)) {
jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema)
} else {
stop("schema argument should be a column or character")
}
options <- varargsToStrEnv(...)
jc <- callJStatic("org.apache.spark.sql.functions",
"from_csv",
x@jc, schema, options)
x@jc, jschema, options)
column(jc)
})
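The R change above, like the matching Python change below, normalizes the schema argument before handing it to the JVM: a plain schema string is wrapped into a literal Column, while a Column is passed through unchanged. A minimal Scala sketch of that pattern, using only the public functions API; the helper name is illustrative and not part of Spark:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

// Illustrative helper: accept either a DDL schema string or a Column and
// normalize it to the Column form expected by the new from_csv overload.
def normalizeCsvSchema(schema: Any): Column = schema match {
  case c: Column => c          // already a Column, pass through
  case s: String => lit(s)     // wrap a DDL string such as "a INT" as a literal Column
  case other =>
    throw new IllegalArgumentException(
      s"schema argument should be a column or string, got: $other")
}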

2 changes: 2 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1651,6 +1651,8 @@ test_that("column functions", {
df <- as.DataFrame(list(list("col" = "1")))
c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv")))
expect_equal(c[[1]][[1]]$a, 1)
c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv")))
expect_equal(c[[1]][[1]]$a, 1)

# Test to_json(), from_json()
df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people")
15 changes: 14 additions & 1 deletion python/pyspark/sql/functions.py
@@ -25,9 +25,12 @@
if sys.version < "3":
from itertools import imap as map

if sys.version >= '3':
basestring = str

from pyspark import since, SparkContext
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
@@ -2693,9 +2696,19 @@ def from_csv(col, schema, options={}):
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
[Row(csv=Row(a=1))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
[Row(csv=Row(a=1))]
"""

sc = SparkContext._active_spark_context
if isinstance(schema, basestring):
schema = _create_column_from_literal(schema)
elif isinstance(schema, Column):
schema = _to_java_column(schema)
else:
raise TypeError("schema argument should be a column or string")

jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options)
return Column(jc)

@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.csv

object CSVUtils {
object CSVExpressionUtils {
/**
* Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
* This is currently being used in CSV reading path and CSV schema inference.
@@ -123,7 +123,7 @@ class CSVHeaderChecker(
// Note: if there are only comments in the first block, the header would probably
// be not extracted.
if (options.headerFlag && isStartOfFile) {
CSVUtils.extractHeader(lines, options).foreach { header =>
CSVExpressionUtils.extractHeader(lines, options).foreach { header =>
checkHeaderColumnNames(tokenizer.parseLine(header))
}
}
@@ -83,7 +83,7 @@ class CSVOptions(
}
}

val delimiter = CSVUtils.toChar(
val delimiter = CSVExpressionUtils.toChar(
parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
val parseMode: ParseMode =
parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode)
@@ -338,7 +338,7 @@ private[sql] object UnivocityParser {

val options = parser.options

val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
val filteredLines: Iterator[String] = CSVExpressionUtils.filterCommentAndEmpty(lines, options)

val safeParser = new FailureSafeParser[String](
input => Seq(parser.parse(input)),
@@ -19,42 +19,42 @@ package org.apache.spark.sql.catalyst.csv

import org.apache.spark.SparkFunSuite

class CSVUtilsSuite extends SparkFunSuite {
class CSVExpressionUtilsSuite extends SparkFunSuite {
test("Can parse escaped characters") {
assert(CSVUtils.toChar("""\t""") === '\t')
assert(CSVUtils.toChar("""\r""") === '\r')
assert(CSVUtils.toChar("""\b""") === '\b')
assert(CSVUtils.toChar("""\f""") === '\f')
assert(CSVUtils.toChar("""\"""") === '\"')
assert(CSVUtils.toChar("""\'""") === '\'')
assert(CSVUtils.toChar("""\u0000""") === '\u0000')
assert(CSVUtils.toChar("""\\""") === '\\')
assert(CSVExpressionUtils.toChar("""\t""") === '\t')
assert(CSVExpressionUtils.toChar("""\r""") === '\r')
assert(CSVExpressionUtils.toChar("""\b""") === '\b')
assert(CSVExpressionUtils.toChar("""\f""") === '\f')
assert(CSVExpressionUtils.toChar("""\"""") === '\"')
assert(CSVExpressionUtils.toChar("""\'""") === '\'')
assert(CSVExpressionUtils.toChar("""\u0000""") === '\u0000')
assert(CSVExpressionUtils.toChar("""\\""") === '\\')
}

test("Does not accept delimiter larger than one character") {
val exception = intercept[IllegalArgumentException]{
CSVUtils.toChar("ab")
CSVExpressionUtils.toChar("ab")
}
assert(exception.getMessage.contains("cannot be more than one character"))
}

test("Throws exception for unsupported escaped characters") {
val exception = intercept[IllegalArgumentException]{
CSVUtils.toChar("""\1""")
CSVExpressionUtils.toChar("""\1""")
}
assert(exception.getMessage.contains("Unsupported special character for delimiter"))
}

test("string with one backward slash is prohibited") {
val exception = intercept[IllegalArgumentException]{
CSVUtils.toChar("""\""")
CSVExpressionUtils.toChar("""\""")
}
assert(exception.getMessage.contains("Single backslash is prohibited"))
}

test("output proper error message for empty string") {
val exception = intercept[IllegalArgumentException]{
CSVUtils.toChar("")
CSVExpressionUtils.toChar("")
}
assert(exception.getMessage.contains("Delimiter cannot be empty string"))
}
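For orientation, a rough Scala sketch of the escape handling these tests exercise; it approximates what CSVExpressionUtils.toChar is expected to do and is not the exact implementation:

// Approximate sketch of the single-character delimiter parsing tested above.
def toCharSketch(str: String): Char = str.toSeq match {
  case Seq() =>
    throw new IllegalArgumentException("Delimiter cannot be empty string")
  case Seq('\\') =>
    throw new IllegalArgumentException("Single backslash is prohibited")
  case Seq(c) => c                    // any single character is taken as-is
  case Seq('\\', 't') => '\t'
  case Seq('\\', 'r') => '\r'
  case Seq('\\', 'b') => '\b'
  case Seq('\\', 'f') => '\f'
  case Seq('\\', '"') => '"'
  case Seq('\\', '\'') => '\''
  case Seq('\\', '\\') => '\\'
  case Seq('\\', _) =>
    throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
  case _ =>
    throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
}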
@@ -35,7 +35,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
import org.apache.spark.sql.catalyst.csv.CSVUtils.filterCommentAndEmpty
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -130,7 +129,7 @@ object TextInputCSVDataSource extends CSVDataSource {
val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
val tokenRDD = sampled.rdd.mapPartitions { iter =>
val filteredLines = filterCommentAndEmpty(iter, parsedOptions)
val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
val linesWithoutHeader =
CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions)
val parser = new CsvParser(parsedOptions.asParserSettings)
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.csv.CSVExpressionUtils
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.functions._

@@ -125,4 +126,7 @@ object CSVUtils {
csv.sample(withReplacement = false, options.samplingRatio, 1)
}
}

def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] =
CSVExpressionUtils.filterCommentAndEmpty(iter, options)
}
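The new filterCommentAndEmpty method above only forwards to the catalyst-side helper, so the execution package keeps its existing entry point. A minimal sketch of the behaviour being delegated, assuming the helper drops blank lines and lines starting with the configured comment character; the real CSVExpressionUtils implementation may differ in detail:

// Illustrative only: mirrors the filtering delegated to CSVExpressionUtils.filterCommentAndEmpty.
def filterCommentAndEmptySketch(iter: Iterator[String], comment: Char): Iterator[String] =
  iter.filter { line =>
    line.trim.nonEmpty && !line.startsWith(comment.toString)
  }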
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3882,8 +3882,8 @@ object functions {
* @group collection_funcs
* @since 3.0.0
*/
def from_csv(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
withExpr(new CsvToStructs(e.expr, lit(schema).expr, options.asScala.toMap))
def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = {
withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap))
}

// scalastyle:off line.size.limit
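A brief usage sketch of the Column-schema overload shown above, mirroring the updated CsvFunctionsSuite test that follows; the DataFrame and column names are illustrative and a running SparkSession named spark is assumed:

import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.{col, from_csv, lit}

val df = spark.createDataFrame(Seq(Tuple1("1,abc"))).toDF("value")

// The schema is now passed as a Column; a plain DDL string can be wrapped with lit().
val parsed = df.select(
  from_csv(col("value"), lit("a INT, b STRING"), Map.empty[String, String].asJava).as("csv"))
parsed.show()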
@@ -31,7 +31,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
val schema = "a int"

checkAnswer(
df.select(from_csv($"value", schema, Map[String, String]().asJava)),
df.select(from_csv($"value", lit(schema), Map[String, String]().asJava)),
Row(Row(1)) :: Nil)
}
