From 08808fb507947b51ea7656496612a81e11fe66bd Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Mon, 18 Jul 2022 15:55:55 +0800
Subject: [PATCH] [SPARK-39760][PYTHON] Support Varchar in PySpark

### What changes were proposed in this pull request?
Support Varchar (`VarcharType`) in PySpark.

### Why are the changes needed?
Function parity: the Scala/Java API already supports `VarcharType`.

### Does this PR introduce _any_ user-facing change?
Yes, a new data type, `VarcharType`, is supported.

### How was this patch tested?
1. Added unit tests.
2. Manually checked against the Scala side:

```python
In [1]: from pyspark.sql.types import *
   ...: from pyspark.sql.functions import *
   ...:
   ...: df = spark.createDataFrame([(1,), (11,)], ["value"])
   ...: ret = df.select(col("value").cast(VarcharType(10))).collect()
   ...:
22/07/13 17:17:07 WARN CharVarcharUtils: The Spark cast operator does not support char/varchar type and simply treats them as string type. Please use string type directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString to true, so that Spark treat them as string type as same as Spark 3.0 and earlier

In [2]: schema = StructType([StructField("a", IntegerType(), True), (StructField("v", VarcharType(10), True))])
   ...: description = "this a table created via Catalog.createTable()"
   ...: table = spark.catalog.createTable("tab3_via_catalog", schema=schema, description=description)
   ...: table.schema
   ...:
Out[2]: StructType([StructField('a', IntegerType(), True), StructField('v', StringType(), True)])
```

```scala
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

scala> val df = spark.range(0, 10).selectExpr(" id AS value")
df: org.apache.spark.sql.DataFrame = [value: bigint]

scala> val ret = df.select(col("value").cast(VarcharType(10))).collect()
22/07/13 17:28:56 WARN CharVarcharUtils: The Spark cast operator does not support char/varchar type and simply treats them as string type. Please use string type directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString to true, so that Spark treat them as string type as same as Spark 3.0 and earlier
ret: Array[org.apache.spark.sql.Row] = Array([0], [1], [2], [3], [4], [5], [6], [7], [8], [9])

scala> val schema = StructType(StructField("a", IntegerType, true) :: (StructField("v", VarcharType(10), true) :: Nil))
schema: org.apache.spark.sql.types.StructType = StructType(StructField(a,IntegerType,true),StructField(v,VarcharType(10),true))

scala> val description = "this a table created via Catalog.createTable()"
description: String = this a table created via Catalog.createTable()

scala> val table = spark.catalog.createTable("tab3_via_catalog", source="json", schema=schema, description=description, options=Map.empty[String, String])
table: org.apache.spark.sql.DataFrame = [a: int, v: string]

scala> table.schema
res0: org.apache.spark.sql.types.StructType = StructType(StructField(a,IntegerType,true),StructField(v,StringType,true))
```

Closes #37173 from zhengruifeng/py_add_varchar.
Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 .../reference/pyspark.sql/data_types.rst |  1 +
 python/pyspark/sql/tests/test_types.py   | 26 ++++++++++-
 python/pyspark/sql/types.py              | 46 +++++++++++++++++--
 3 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/data_types.rst b/python/docs/source/reference/pyspark.sql/data_types.rst
index d146c640477d6..775f0bf394a49 100644
--- a/python/docs/source/reference/pyspark.sql/data_types.rst
+++ b/python/docs/source/reference/pyspark.sql/data_types.rst
@@ -40,6 +40,7 @@ Data Types
     NullType
     ShortType
     StringType
+    VarcharType
     StructField
     StructType
     TimestampType
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index ef0ad82dbb97a..218cfc413db5f 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -38,6 +38,7 @@
     DayTimeIntervalType,
     MapType,
     StringType,
+    VarcharType,
     StructType,
     StructField,
     ArrayType,
@@ -739,8 +740,12 @@ def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string

         for k, t in _all_atomic_types.items():
-            self.assertEqual(t(), _parse_datatype_string(k))
+            if k != "varchar":
+                self.assertEqual(t(), _parse_datatype_string(k))
         self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(VarcharType(1), _parse_datatype_string("varchar(1)"))
+        self.assertEqual(VarcharType(10), _parse_datatype_string("varchar( 10 )"))
+        self.assertEqual(VarcharType(11), _parse_datatype_string("varchar( 11)"))
         self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
         self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
         self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
@@ -1028,6 +1033,7 @@ def test_repr(self):
         instances = [
             NullType(),
             StringType(),
+            VarcharType(10),
             BinaryType(),
             BooleanType(),
             DateType(),
@@ -1132,6 +1138,15 @@ def test_decimal_type(self):
         t3 = DecimalType(8)
         self.assertNotEqual(t2, t3)

+    def test_varchar_type(self):
+        v1 = VarcharType(10)
+        v2 = VarcharType(20)
+        self.assertTrue(v2 is not v1)
+        self.assertNotEqual(v1, v2)
+        v3 = VarcharType(10)
+        self.assertEqual(v1, v3)
+        self.assertFalse(v1 is v3)
+
     # regression test for SPARK-10392
     def test_datetype_equal_zero(self):
         dt = DateType()
@@ -1211,6 +1226,13 @@ def __init__(self, **kwargs):
             (1.0, StringType()),
             ([], StringType()),
             ({}, StringType()),
+            # Varchar
+            ("", VarcharType(10)),
+            ("", VarcharType(10)),
+            (1, VarcharType(10)),
+            (1.0, VarcharType(10)),
+            ([], VarcharType(10)),
+            ({}, VarcharType(10)),
             # UDT
             (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
             # Boolean
@@ -1267,6 +1289,8 @@ def __init__(self, **kwargs):
         failure_spec = [
             # String (match anything but None)
             (None, StringType(), ValueError),
+            # VarcharType (match anything but None)
+            (None, VarcharType(10), ValueError),
             # UDT
             (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
             # Boolean
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fa3f3dd7d881e..7ab8f7c9c2d00 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -57,6 +57,7 @@
     "DataType",
     "NullType",
     "StringType",
+    "VarcharType",
     "BinaryType",
     "BooleanType",
     "DateType",
@@ -181,6 +182,28 @@ class StringType(AtomicType, metaclass=DataTypeSingleton):
     pass


+class VarcharType(AtomicType):
+    """Varchar data type
+
+    Parameters
+    ----------
+    length : int
+        the length limitation.
+    """
+
+    def __init__(self, length: int):
+        self.length = length
+
+    def simpleString(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def __repr__(self) -> str:
+        return "VarcharType(%d)" % (self.length)
+
+
 class BinaryType(AtomicType, metaclass=DataTypeSingleton):
     """Binary (byte array) data type."""

@@ -625,6 +648,10 @@ class StructType(DataType):
     >>> struct2 = StructType([StructField("f1", StringType(), True)])
     >>> struct1 == struct2
     True
+    >>> struct1 = StructType([StructField("f1", VarcharType(10), True)])
+    >>> struct2 = StructType([StructField("f1", VarcharType(10), True)])
+    >>> struct1 == struct2
+    True
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct2 = StructType([StructField("f1", StringType(), True),
     ...     StructField("f2", IntegerType(), False)])
@@ -944,6 +971,7 @@ def __eq__(self, other: Any) -> bool:

 _atomic_types: List[Type[DataType]] = [
     StringType,
+    VarcharType,
     BinaryType,
     BooleanType,
     DecimalType,
@@ -965,7 +993,7 @@ def __eq__(self, other: Any) -> bool:
     (v.typeName(), v) for v in _complex_types
 )

-
+_LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")

@@ -987,6 +1015,8 @@ def _parse_datatype_string(s: str) -> DataType:
     StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
     >>> _parse_datatype_string("a DOUBLE, b STRING")
     StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
     >>> _parse_datatype_string("a: array< short>")
     StructType([StructField('a', ArrayType(ShortType(), True), True)])
     >>> _parse_datatype_string(" map<string , string > ")
@@ -1055,7 +1085,10 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
     ...     assert datatype == python_datatype
     >>> for cls in _all_atomic_types.values():
-    ...     check_datatype(cls())
+    ...     if cls is not VarcharType:
+    ...         check_datatype(cls())
+    ...     else:
+    ...         check_datatype(cls(1))

     >>> # Simple ArrayType.
     >>> simple_arraytype = ArrayType(StringType(), True)
@@ -1079,6 +1112,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     StructField("simpleMap", simple_maptype, True),
     ...     StructField("simpleStruct", simple_structtype, True),
     ...     StructField("boolean", BooleanType(), False),
+    ...     StructField("words", VarcharType(10), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
@@ -1111,6 +1145,9 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             if first_field is not None and second_field is None:
                 return DayTimeIntervalType(first_field)
             return DayTimeIntervalType(first_field, second_field)
+        elif _LENGTH_VARCHAR.match(json_value):
+            m = _LENGTH_VARCHAR.match(json_value)
+            return VarcharType(int(m.group(1)))  # type: ignore[union-attr]
         else:
             raise ValueError("Could not parse datatype: %s" % json_value)
     else:
@@ -1549,6 +1586,7 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str,),
+    VarcharType: (str,),
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
     TimestampType: (datetime.datetime,),
@@ -1659,8 +1697,8 @@ def verify_acceptable_types(obj: Any) -> None:
                 new_msg("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
             )

-    if isinstance(dataType, StringType):
-        # StringType can work with any types
+    if isinstance(dataType, (StringType, VarcharType)):
+        # StringType and VarcharType can work with any types
         def verify_value(obj: Any) -> None:
             pass