[SPARK-39760][PYTHON] Support Varchar in PySpark
### What changes were proposed in this pull request?
Add a `VarcharType` data type to PySpark, matching the existing Scala type.

### Why are the changes needed?
Feature parity: the Scala API already supports `VarcharType`.

### Does this PR introduce _any_ user-facing change?
Yes, a new data type (`VarcharType`) is supported.
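
For illustration, a minimal sketch of the new user-facing type (assumes a PySpark build that includes this change; no running session required):

```python
from pyspark.sql.types import StructType, StructField, VarcharType

# VarcharType can appear anywhere a DataType is accepted, e.g. in a schema.
schema = StructType([StructField("v", VarcharType(10), True)])
print(schema.simpleString())  # struct<v:varchar(10)>
```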

### How was this patch tested?
1. Added unit tests.
2. Manually checked against the Scala side:

```python

In [1]: from pyspark.sql.types import *
   ...: from pyspark.sql.functions import *
   ...:
   ...: df = spark.createDataFrame([(1,), (11,)], ["value"])
   ...: ret = df.select(col("value").cast(VarcharType(10))).collect()
   ...:
22/07/13 17:17:07 WARN CharVarcharUtils: The Spark cast operator does not support char/varchar type and simply treats them as string type. Please use string type directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString to true, so that Spark treat them as string type as same as Spark 3.0 and earlier

In [2]:

In [2]: schema = StructType([StructField("a", IntegerType(), True), (StructField("v", VarcharType(10), True))])
   ...: description = "this a table created via Catalog.createTable()"
   ...: table = spark.catalog.createTable("tab3_via_catalog", schema=schema, description=description)
   ...: table.schema
   ...:
Out[2]: StructType([StructField('a', IntegerType(), True), StructField('v', StringType(), True)])
```

```scala
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

scala> val df = spark.range(0, 10).selectExpr(" id AS value")
df: org.apache.spark.sql.DataFrame = [value: bigint]

scala> val ret = df.select(col("value").cast(VarcharType(10))).collect()
22/07/13 17:28:56 WARN CharVarcharUtils: The Spark cast operator does not support char/varchar type and simply treats them as string type. Please use string type directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString to true, so that Spark treat them as string type as same as Spark 3.0 and earlier
ret: Array[org.apache.spark.sql.Row] = Array([0], [1], [2], [3], [4], [5], [6], [7], [8], [9])

scala>

scala> val schema = StructType(StructField("a", IntegerType, true) :: (StructField("v", VarcharType(10), true) :: Nil))
schema: org.apache.spark.sql.types.StructType = StructType(StructField(a,IntegerType,true),StructField(v,VarcharType(10),true))

scala> val description = "this a table created via Catalog.createTable()"
description: String = this a table created via Catalog.createTable()

scala> val table = spark.catalog.createTable("tab3_via_catalog", source="json", schema=schema, description=description, options=Map.empty[String, String])
table: org.apache.spark.sql.DataFrame = [a: int, v: string]

scala> table.schema
res0: org.apache.spark.sql.types.StructType = StructType(StructField(a,IntegerType,true),StructField(v,StringType,true))
```

Closes apache#37173 from zhengruifeng/py_add_varchar.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jul 18, 2022
1 parent e4ca842 commit 08808fb
Showing 3 changed files with 68 additions and 5 deletions.
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/data_types.rst
```diff
@@ -40,6 +40,7 @@ Data Types
     NullType
     ShortType
     StringType
+    VarcharType
     StructField
     StructType
     TimestampType
```
26 changes: 25 additions & 1 deletion python/pyspark/sql/tests/test_types.py
```diff
@@ -38,6 +38,7 @@
     DayTimeIntervalType,
     MapType,
     StringType,
+    VarcharType,
     StructType,
     StructField,
     ArrayType,
```
```diff
@@ -739,8 +740,12 @@ def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
 
         for k, t in _all_atomic_types.items():
-            self.assertEqual(t(), _parse_datatype_string(k))
+            if k != "varchar":
+                self.assertEqual(t(), _parse_datatype_string(k))
         self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(VarcharType(1), _parse_datatype_string("varchar(1)"))
+        self.assertEqual(VarcharType(10), _parse_datatype_string("varchar( 10 )"))
+        self.assertEqual(VarcharType(11), _parse_datatype_string("varchar( 11)"))
         self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
         self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
         self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
```
```diff
@@ -1028,6 +1033,7 @@ def test_repr(self):
         instances = [
             NullType(),
             StringType(),
+            VarcharType(10),
             BinaryType(),
             BooleanType(),
             DateType(),
```
```diff
@@ -1132,6 +1138,15 @@ def test_decimal_type(self):
         t3 = DecimalType(8)
         self.assertNotEqual(t2, t3)
 
+    def test_varchar_type(self):
+        v1 = VarcharType(10)
+        v2 = VarcharType(20)
+        self.assertTrue(v2 is not v1)
+        self.assertNotEqual(v1, v2)
+        v3 = VarcharType(10)
+        self.assertEqual(v1, v3)
+        self.assertFalse(v1 is v3)
+
     # regression test for SPARK-10392
     def test_datetype_equal_zero(self):
         dt = DateType()
```
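The new test pins down the equality semantics: unlike `StringType`, `VarcharType` is not a `DataTypeSingleton`, so instances with the same length compare equal while each call builds a fresh object. A standalone sketch of the same behavior:

```python
from pyspark.sql.types import StringType, VarcharType

# StringType is a singleton: repeated calls return the same object.
assert StringType() is StringType()

# VarcharType is a plain class: equal lengths are equal but not identical.
v1, v2 = VarcharType(10), VarcharType(10)
assert v1 == v2 and v1 is not v2
assert VarcharType(10) != VarcharType(20)
```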
```diff
@@ -1211,6 +1226,13 @@ def __init__(self, **kwargs):
             (1.0, StringType()),
             ([], StringType()),
             ({}, StringType()),
+            # Varchar
+            ("", VarcharType(10)),
+            ("", VarcharType(10)),
+            (1, VarcharType(10)),
+            (1.0, VarcharType(10)),
+            ([], VarcharType(10)),
+            ({}, VarcharType(10)),
             # UDT
             (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
             # Boolean
```
```diff
@@ -1267,6 +1289,8 @@ def __init__(self, **kwargs):
         failure_spec = [
             # String (match anything but None)
             (None, StringType(), ValueError),
+            # VarcharType (match anything but None)
+            (None, VarcharType(10), ValueError),
             # UDT
             (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
             # Boolean
```
46 changes: 42 additions & 4 deletions python/pyspark/sql/types.py
```diff
@@ -57,6 +57,7 @@
     "DataType",
     "NullType",
     "StringType",
+    "VarcharType",
     "BinaryType",
     "BooleanType",
     "DateType",
```
```diff
@@ -181,6 +182,28 @@ class StringType(AtomicType, metaclass=DataTypeSingleton):
     pass
 
 
+class VarcharType(AtomicType):
+    """Varchar data type
+
+    Parameters
+    ----------
+    length : int
+        the length limitation.
+    """
+
+    def __init__(self, length: int):
+        self.length = length
+
+    def simpleString(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def __repr__(self) -> str:
+        return "VarcharType(%d)" % (self.length)
+
+
 class BinaryType(AtomicType, metaclass=DataTypeSingleton):
     """Binary (byte array) data type."""
 
```

```diff
@@ -625,6 +648,10 @@ class StructType(DataType):
     >>> struct2 = StructType([StructField("f1", StringType(), True)])
     >>> struct1 == struct2
     True
+    >>> struct1 = StructType([StructField("f1", VarcharType(10), True)])
+    >>> struct2 = StructType([StructField("f1", VarcharType(10), True)])
+    >>> struct1 == struct2
+    True
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct2 = StructType([StructField("f1", StringType(), True),
     ...     StructField("f2", IntegerType(), False)])
```
```diff
@@ -944,6 +971,7 @@ def __eq__(self, other: Any) -> bool:
 
 _atomic_types: List[Type[DataType]] = [
     StringType,
+    VarcharType,
     BinaryType,
     BooleanType,
     DecimalType,
```
```diff
@@ -965,7 +993,7 @@ def __eq__(self, other: Any) -> bool:
     (v.typeName(), v) for v in _complex_types
 )
 
-
+_LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
 
```
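The new `_LENGTH_VARCHAR` pattern mirrors `_FIXED_DECIMAL` and tolerates whitespace around the length, which is exactly what the new parser tests exercise. A standalone check of the pattern:

```python
import re

_LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")

for s in ["varchar(1)", "varchar( 10 )", "varchar( 11)"]:
    m = _LENGTH_VARCHAR.match(s)
    assert m is not None
    print(s, "->", int(m.group(1)))  # 1, 10, 11
```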

```diff
@@ -987,6 +1015,8 @@ def _parse_datatype_string(s: str) -> DataType:
     StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
     >>> _parse_datatype_string("a DOUBLE, b STRING")
     StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
     >>> _parse_datatype_string("a: array< short>")
     StructType([StructField('a', ArrayType(ShortType(), True), True)])
     >>> _parse_datatype_string(" map<string , string > ")
```
```diff
@@ -1055,7 +1085,10 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
     ...     assert datatype == python_datatype
     >>> for cls in _all_atomic_types.values():
-    ...     check_datatype(cls())
+    ...     if cls is not VarcharType:
+    ...         check_datatype(cls())
+    ...     else:
+    ...         check_datatype(cls(1))
 
     >>> # Simple ArrayType.
     >>> simple_arraytype = ArrayType(StringType(), True)
```
```diff
@@ -1079,6 +1112,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     StructField("simpleMap", simple_maptype, True),
     ...     StructField("simpleStruct", simple_structtype, True),
     ...     StructField("boolean", BooleanType(), False),
+    ...     StructField("words", VarcharType(10), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
 
```
```diff
@@ -1111,6 +1145,9 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             if first_field is not None and second_field is None:
                 return DayTimeIntervalType(first_field)
             return DayTimeIntervalType(first_field, second_field)
+        elif _LENGTH_VARCHAR.match(json_value):
+            m = _LENGTH_VARCHAR.match(json_value)
+            return VarcharType(int(m.group(1)))  # type: ignore[union-attr]
         else:
             raise ValueError("Could not parse datatype: %s" % json_value)
     else:
```
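With this branch in place, a `VarcharType` survives a JSON round trip. A sketch using the internal helper (`_parse_datatype_json_string` is not public API):

```python
from pyspark.sql.types import VarcharType, _parse_datatype_json_string

v = VarcharType(10)
assert v.json() == '"varchar(10)"'
assert _parse_datatype_json_string(v.json()) == v
```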
```diff
@@ -1549,6 +1586,7 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str,),
+    VarcharType: (str,),
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
     TimestampType: (datetime.datetime,),
```
```diff
@@ -1659,8 +1697,8 @@ def verify_acceptable_types(obj: Any) -> None:
                 new_msg("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
             )
 
-    if isinstance(dataType, StringType):
-        # StringType can work with any types
+    if isinstance(dataType, (StringType, VarcharType)):
+        # StringType and VarcharType can work with any types
         def verify_value(obj: Any) -> None:
             pass
 
```
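Since the verifier now treats `VarcharType` exactly like `StringType`, any non-None value is accepted for a varchar field. A sketch against the internal `_make_type_verifier` helper (not public API; keyword usage assumed from this file):

```python
from pyspark.sql.types import VarcharType, _make_type_verifier

verify = _make_type_verifier(VarcharType(10), nullable=False)
verify("hello")  # accepted
verify(123)      # accepted: varchar, like string, matches any non-None value
try:
    verify(None)
except ValueError:
    print("None is rejected when nullable=False")
```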
