Handle null values in inputs to UDFs #10

Merged 1 commit on Mar 4, 2022
46 changes: 45 additions & 1 deletion src/h3_pyspark/__init__.py
@@ -2,7 +2,7 @@
from pyspark.sql import functions as F, types as T
import json
from inspect import getmembers, isfunction
from .utils import sanitize_types
from .utils import sanitize_types, handle_nulls
import sys
from shapely import geometry

@@ -13,16 +13,19 @@


@F.udf(returnType=T.StringType())
@handle_nulls
def geo_to_h3(lat, lng, resolution):
return sanitize_types(h3.geo_to_h3(lat, lng, resolution))


@F.udf(returnType=T.ArrayType(T.DoubleType()))
@handle_nulls
def h3_to_geo(h):
return sanitize_types(h3.h3_to_geo(h))


@F.udf(returnType=T.StringType())
@handle_nulls
def h3_to_geo_boundary(h, geo_json):
# NOTE: this behavior differs from default
# h3-pyspark return type will be a valid GeoJSON string if geo_json is set to True
@@ -38,41 +41,49 @@ def h3_to_geo_boundary(h, geo_json):


@F.udf(returnType=T.IntegerType())
@handle_nulls
def h3_get_resolution(h):
return sanitize_types(h3.h3_get_resolution(h))


@F.udf(returnType=T.IntegerType())
@handle_nulls
def h3_get_base_cell(h):
return sanitize_types(h3.h3_get_base_cell(h))


@F.udf(returnType=T.LongType())
@handle_nulls
def string_to_h3(h):
return sanitize_types(h3.string_to_h3(h))


@F.udf(returnType=T.StringType())
@handle_nulls
def h3_to_string(h):
return sanitize_types(h3.h3_to_string(h))


@F.udf(returnType=T.BooleanType())
@handle_nulls
def h3_is_valid(h):
return sanitize_types(h3.h3_is_valid(h))


@F.udf(returnType=T.BooleanType())
@handle_nulls
def h3_is_res_class_III(h):
return sanitize_types(h3.h3_is_res_class_III(h))


@F.udf(returnType=T.BooleanType())
@handle_nulls
def h3_is_pentagon(h):
return sanitize_types(h3.h3_is_pentagon(h))


@F.udf(returnType=T.ArrayType(T.IntegerType()))
@handle_nulls
def h3_get_faces(h):
return sanitize_types(h3.h3_get_faces(h))

@@ -83,51 +94,61 @@ def h3_get_faces(h):


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def k_ring(origin, k):
return sanitize_types(h3.k_ring(origin, k))


@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
@handle_nulls
def k_ring_distances(origin, k):
return sanitize_types(h3.k_ring_distances(origin, k))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def hex_range(h, k):
return sanitize_types(h3.hex_range(h, k))


@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
@handle_nulls
def hex_range_distances(h, k):
return sanitize_types(h3.hex_range_distances(h, k))


@F.udf(returnType=T.MapType(T.StringType(), T.ArrayType(T.ArrayType(T.StringType()))))
@handle_nulls
def hex_ranges(h, k):
return sanitize_types(h3.hex_ranges(h, k))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def hex_ring(h, k):
return sanitize_types(h3.hex_ring(h, k))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def h3_line(start, end):
return sanitize_types(h3.h3_line(start, end))


@F.udf(returnType=T.IntegerType())
@handle_nulls
def h3_distance(h1, h2):
return sanitize_types(h3.h3_distance(h1, h2))


@F.udf(returnType=T.ArrayType(T.IntegerType()))
@handle_nulls
def experimental_h3_to_local_ij(origin, h):
return sanitize_types(h3.experimental_h3_to_local_ij(origin, h))


@F.udf(returnType=T.StringType())
@handle_nulls
def experimental_local_ij_to_h3(origin, i, j):
return sanitize_types(h3.experimental_local_ij_to_h3(origin, i, j))

@@ -138,26 +159,31 @@ def experimental_local_ij_to_h3(origin, i, j):


@F.udf(returnType=T.StringType())
@handle_nulls
def h3_to_parent(h, parent_res):
return sanitize_types(h3.h3_to_parent(h, parent_res))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def h3_to_children(h, child_res):
return sanitize_types(h3.h3_to_children(h, child_res))


@F.udf(returnType=T.StringType())
@handle_nulls
def h3_to_center_child(h, child_res):
return sanitize_types(h3.h3_to_center_child(h, child_res))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def compact(hexes):
return sanitize_types(h3.compact(hexes))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def uncompact(hexes, res):
return sanitize_types(h3.uncompact(hexes, res))

@@ -168,6 +194,7 @@ def uncompact(hexes, res):


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def polyfill(polygons, res, geo_json_conformant):
# NOTE: this behavior differs from default
# h3-pyspark expects the `polygons` argument to be a valid GeoJSON string
@@ -176,6 +203,7 @@ def polyfill(polygons, res, geo_json_conformant):


@F.udf(returnType=T.StringType())
@handle_nulls
def h3_set_to_multi_polygon(hexes, geo_json):
# NOTE: this behavior differs from default
# h3-pyspark return type will be a valid GeoJSON string if geo_json is set to True
@@ -191,41 +219,49 @@ def h3_set_to_multi_polygon(hexes, geo_json):


@F.udf(returnType=T.BooleanType())
@handle_nulls
def h3_indexes_are_neighbors(origin, destination):
return sanitize_types(h3.h3_indexes_are_neighbors(origin, destination))


@F.udf(returnType=T.StringType())
@handle_nulls
def get_h3_unidirectional_edge(origin, destination):
return sanitize_types(h3.get_h3_unidirectional_edge(origin, destination))


@F.udf(returnType=T.BooleanType())
@handle_nulls
def h3_unidirectional_edge_is_valid(edge):
return sanitize_types(h3.h3_unidirectional_edge_is_valid(edge))


@F.udf(returnType=T.StringType())
@handle_nulls
def get_origin_h3_index_from_unidirectional_edge(edge):
return sanitize_types(h3.get_origin_h3_index_from_unidirectional_edge(edge))


@F.udf(returnType=T.StringType())
@handle_nulls
def get_destination_h3_index_from_unidirectional_edge(edge):
return sanitize_types(h3.get_destination_h3_index_from_unidirectional_edge(edge))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def get_h3_indexes_from_unidirectional_edge(edge):
return sanitize_types(h3.get_h3_indexes_from_unidirectional_edge(edge))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def get_h3_unidirectional_edges_from_hexagon(h):
return sanitize_types(h3.get_h3_unidirectional_edges_from_hexagon(h))


@F.udf(returnType=T.ArrayType(T.ArrayType(T.DoubleType())))
@handle_nulls
def get_h3_unidirectional_edge_boundary(h, geo_json):
return sanitize_types(h3.get_h3_unidirectional_edge_boundary(h, geo_json))

@@ -236,41 +272,49 @@ def get_h3_unidirectional_edge_boundary(h, geo_json):


@F.udf(returnType=T.DoubleType())
@handle_nulls
def hex_area(res, unit):
return sanitize_types(h3.hex_area(res, unit))


@F.udf(returnType=T.DoubleType())
@handle_nulls
def cell_area(h, unit):
return sanitize_types(h3.cell_area(h, unit))


@F.udf(returnType=T.DoubleType())
@handle_nulls
def edge_length(res, unit):
return sanitize_types(h3.edge_length(res, unit))


@F.udf(returnType=T.DoubleType())
@handle_nulls
def exact_edge_length(res, unit):
return sanitize_types(h3.exact_edge_length(res, unit))


@F.udf(returnType=T.IntegerType())
@handle_nulls
def num_hexagons(res):
return sanitize_types(h3.num_hexagons(res))


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def get_res0_indexes():
return sanitize_types(h3.get_res0_indexes())


@F.udf(returnType=T.ArrayType(T.StringType()))
@handle_nulls
def get_pentagon_indexes(res):
return sanitize_types(h3.get_pentagon_indexes(res))


@F.udf(returnType=T.DoubleType())
@handle_nulls
def point_dist(point1, point2, unit):
return sanitize_types(h3.point_dist(point1, point2, unit))

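For context (not part of this diff): with `@handle_nulls` applied, a null in any input column now propagates to a null output instead of failing inside the Python worker. A minimal usage sketch, assuming an active SparkSession and made-up sample data:

```python
from pyspark.sql import SparkSession, functions as F
from h3_pyspark import geo_to_h3

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(37.7752, -122.4186), (None, -122.4186)],
    ["lat", "lng"],
)

# The second row has a null lat, so its h3_9 value comes back null.
df = df.withColumn("h3_9", geo_to_h3("lat", "lng", F.lit(9)))
df.show()
```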
3 changes: 2 additions & 1 deletion src/h3_pyspark/indexing.py
@@ -12,7 +12,7 @@
MultiPolygon,
)
from pyspark.sql import functions as F, types as T
from .utils import flatten, densify
from .utils import flatten, densify, handle_nulls


def _index_point_object(point: Point, resolution: int):
@@ -117,6 +117,7 @@ def _index_shape(shape: str, resolution: int):


@F.udf(T.ArrayType(T.StringType()))
@handle_nulls
def index_shape(geometry: Column, resolution: Column):
"""
Generate an H3 spatial index for an input GeoJSON geometry column.
2 changes: 2 additions & 0 deletions src/h3_pyspark/traversal.py
@@ -2,6 +2,7 @@
from pyspark.sql import functions as F, types as T
from pyspark.sql.column import Column
from typing import List
from .utils import handle_nulls


def _k_ring_distinct(cells: List[str], distance: int = 1):
@@ -15,6 +16,7 @@ def _k_ring_distinct(cells: List[str], distance: int = 1):


@F.udf(T.ArrayType(T.StringType()))
@handle_nulls
def k_ring_distinct(cells: Column, distance: Column):
"""
Perform a k-ring operation on every input cell and return the distinct set of output cells.
13 changes: 13 additions & 0 deletions src/h3_pyspark/utils.py
@@ -2,6 +2,19 @@
from shapely.geometry import LineString


def handle_nulls(function):
"""
Decorator to return null if any of the input arguments are null.
"""

def inner(*args, **kwargs):
if any(arg is None for arg in args):
return None
return function(*args, **kwargs)

return inner


def flatten(t):
return [item for sublist in t for item in sublist]

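A quick sketch (outside the diff) of how the decorator behaves on an ordinary function; `add` here is a hypothetical example, not part of the library:

```python
from h3_pyspark.utils import handle_nulls

@handle_nulls
def add(a, b):
    return a + b

print(add(1, 2))     # 3
print(add(1, None))  # None -- short-circuits before calling add
# Note: only positional arguments are checked; keyword arguments pass through.
```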
12 changes: 12 additions & 0 deletions tests/test_core.py
@@ -92,6 +92,18 @@ def test_geo_to_h3(self):
expected = sanitize_types(h3.geo_to_h3(*h3_test_args))
assert sort(actual) == sort(expected)

def test_geo_to_h3_single_null_input(self):
actual = df.withColumn("actual", h3_pyspark.geo_to_h3(F.lit(100), F.lit(None), F.lit(9)))
actual = actual.collect()[0]["actual"]
expected = None
assert actual == expected

def test_geo_to_h3_all_null_inputs(self):
actual = df.withColumn("actual", h3_pyspark.geo_to_h3(F.lit(None), F.lit(None), F.lit(None)))
actual = actual.collect()[0]["actual"]
expected = None
assert actual == expected

def test_h3_to_geo(self):
h3_test_args, h3_pyspark_test_args = get_test_args(h3.h3_to_geo)
