From 10e69f3f2334352b7c8556fec3306877ad198e09 Mon Sep 17 00:00:00 2001 From: Kevin Schaich Date: Thu, 3 Mar 2022 14:17:04 -0500 Subject: [PATCH] Handle null values in inputs to UDFs --- src/h3_pyspark/__init__.py | 46 ++++++++++++++++++++++++++++++++++++- src/h3_pyspark/indexing.py | 3 ++- src/h3_pyspark/traversal.py | 2 ++ src/h3_pyspark/utils.py | 13 +++++++++++ tests/test_core.py | 12 ++++++++++ 5 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/h3_pyspark/__init__.py b/src/h3_pyspark/__init__.py index 237b24d..62f421c 100644 --- a/src/h3_pyspark/__init__.py +++ b/src/h3_pyspark/__init__.py @@ -2,7 +2,7 @@ from pyspark.sql import functions as F, types as T import json from inspect import getmembers, isfunction -from .utils import sanitize_types +from .utils import sanitize_types, handle_nulls import sys from shapely import geometry @@ -13,16 +13,19 @@ @F.udf(returnType=T.StringType()) +@handle_nulls def geo_to_h3(lat, lng, resolution): return sanitize_types(h3.geo_to_h3(lat, lng, resolution)) @F.udf(returnType=T.ArrayType(T.DoubleType())) +@handle_nulls def h3_to_geo(h): return sanitize_types(h3.h3_to_geo(h)) @F.udf(returnType=T.StringType()) +@handle_nulls def h3_to_geo_boundary(h, geo_json): # NOTE: this behavior differs from default # h3-pyspark return type will be a valid GeoJSON string if geo_json is set to True @@ -38,41 +41,49 @@ def h3_to_geo_boundary(h, geo_json): @F.udf(returnType=T.IntegerType()) +@handle_nulls def h3_get_resolution(h): return sanitize_types(h3.h3_get_resolution(h)) @F.udf(returnType=T.IntegerType()) +@handle_nulls def h3_get_base_cell(h): return sanitize_types(h3.h3_get_base_cell(h)) @F.udf(returnType=T.LongType()) +@handle_nulls def string_to_h3(h): return sanitize_types(h3.string_to_h3(h)) @F.udf(returnType=T.StringType()) +@handle_nulls def h3_to_string(h): return sanitize_types(h3.h3_to_string(h)) @F.udf(returnType=T.BooleanType()) +@handle_nulls def h3_is_valid(h): return sanitize_types(h3.h3_is_valid(h)) @F.udf(returnType=T.BooleanType()) +@handle_nulls def h3_is_res_class_III(h): return sanitize_types(h3.h3_is_res_class_III(h)) @F.udf(returnType=T.BooleanType()) +@handle_nulls def h3_is_pentagon(h): return sanitize_types(h3.h3_is_pentagon(h)) @F.udf(returnType=T.ArrayType(T.IntegerType())) +@handle_nulls def h3_get_faces(h): return sanitize_types(h3.h3_get_faces(h)) @@ -83,51 +94,61 @@ def h3_get_faces(h): @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def k_ring(origin, k): return sanitize_types(h3.k_ring(origin, k)) @F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType()))) +@handle_nulls def k_ring_distances(origin, k): return sanitize_types(h3.k_ring_distances(origin, k)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def hex_range(h, k): return sanitize_types(h3.hex_range(h, k)) @F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType()))) +@handle_nulls def hex_range_distances(h, k): return sanitize_types(h3.hex_range_distances(h, k)) @F.udf(returnType=T.MapType(T.StringType(), T.ArrayType(T.ArrayType(T.StringType())))) +@handle_nulls def hex_ranges(h, k): return sanitize_types(h3.hex_ranges(h, k)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def hex_ring(h, k): return sanitize_types(h3.hex_ring(h, k)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def h3_line(start, end): return sanitize_types(h3.h3_line(start, end)) @F.udf(returnType=T.IntegerType()) +@handle_nulls def h3_distance(h1, h2): return sanitize_types(h3.h3_distance(h1, h2)) @F.udf(returnType=T.ArrayType(T.IntegerType())) +@handle_nulls def experimental_h3_to_local_ij(origin, h): return sanitize_types(h3.experimental_h3_to_local_ij(origin, h)) @F.udf(returnType=T.StringType()) +@handle_nulls def experimental_local_ij_to_h3(origin, i, j): return sanitize_types(h3.experimental_local_ij_to_h3(origin, i, j)) @@ -138,26 +159,31 @@ def experimental_local_ij_to_h3(origin, i, j): @F.udf(returnType=T.StringType()) +@handle_nulls def h3_to_parent(h, parent_res): return sanitize_types(h3.h3_to_parent(h, parent_res)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def h3_to_children(h, child_res): return sanitize_types(h3.h3_to_children(h, child_res)) @F.udf(returnType=T.StringType()) +@handle_nulls def h3_to_center_child(h, child_res): return sanitize_types(h3.h3_to_center_child(h, child_res)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def compact(hexes): return sanitize_types(h3.compact(hexes)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def uncompact(hexes, res): return sanitize_types(h3.uncompact(hexes, res)) @@ -168,6 +194,7 @@ def uncompact(hexes, res): @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def polyfill(polygons, res, geo_json_conformant): # NOTE: this behavior differs from default # h3-pyspark expect `polygons` argument to be a valid GeoJSON string @@ -176,6 +203,7 @@ def polyfill(polygons, res, geo_json_conformant): @F.udf(returnType=T.StringType()) +@handle_nulls def h3_set_to_multi_polygon(hexes, geo_json): # NOTE: this behavior differs from default # h3-pyspark return type will be a valid GeoJSON string if geo_json is set to True @@ -191,41 +219,49 @@ def h3_set_to_multi_polygon(hexes, geo_json): @F.udf(returnType=T.BooleanType()) +@handle_nulls def h3_indexes_are_neighbors(origin, destination): return sanitize_types(h3.h3_indexes_are_neighbors(origin, destination)) @F.udf(returnType=T.StringType()) +@handle_nulls def get_h3_unidirectional_edge(origin, destination): return sanitize_types(h3.get_h3_unidirectional_edge(origin, destination)) @F.udf(returnType=T.BooleanType()) +@handle_nulls def h3_unidirectional_edge_is_valid(edge): return sanitize_types(h3.h3_unidirectional_edge_is_valid(edge)) @F.udf(returnType=T.StringType()) +@handle_nulls def get_origin_h3_index_from_unidirectional_edge(edge): return sanitize_types(h3.get_origin_h3_index_from_unidirectional_edge(edge)) @F.udf(returnType=T.StringType()) +@handle_nulls def get_destination_h3_index_from_unidirectional_edge(edge): return sanitize_types(h3.get_destination_h3_index_from_unidirectional_edge(edge)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def get_h3_indexes_from_unidirectional_edge(edge): return sanitize_types(h3.get_h3_indexes_from_unidirectional_edge(edge)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def get_h3_unidirectional_edges_from_hexagon(h): return sanitize_types(h3.get_h3_unidirectional_edges_from_hexagon(h)) @F.udf(returnType=T.ArrayType(T.ArrayType(T.DoubleType()))) +@handle_nulls def get_h3_unidirectional_edge_boundary(h, geo_json): return sanitize_types(h3.get_h3_unidirectional_edge_boundary(h, geo_json)) @@ -236,41 +272,49 @@ def get_h3_unidirectional_edge_boundary(h, geo_json): @F.udf(returnType=T.DoubleType()) +@handle_nulls def hex_area(res, unit): return sanitize_types(h3.hex_area(res, unit)) @F.udf(returnType=T.DoubleType()) +@handle_nulls def cell_area(h, unit): return sanitize_types(h3.cell_area(h, unit)) @F.udf(returnType=T.DoubleType()) +@handle_nulls def edge_length(res, unit): return sanitize_types(h3.edge_length(res, unit)) @F.udf(returnType=T.DoubleType()) +@handle_nulls def exact_edge_length(res, unit): return sanitize_types(h3.exact_edge_length(res, unit)) @F.udf(returnType=T.IntegerType()) +@handle_nulls def num_hexagons(res): return sanitize_types(h3.num_hexagons(res)) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def get_res0_indexes(): return sanitize_types(h3.get_res0_indexes()) @F.udf(returnType=T.ArrayType(T.StringType())) +@handle_nulls def get_pentagon_indexes(res): return sanitize_types(h3.get_pentagon_indexes(res)) @F.udf(returnType=T.DoubleType()) +@handle_nulls def point_dist(point1, point2, unit): return sanitize_types(h3.point_dist(point1, point2, unit)) diff --git a/src/h3_pyspark/indexing.py b/src/h3_pyspark/indexing.py index f8efc07..852cc80 100644 --- a/src/h3_pyspark/indexing.py +++ b/src/h3_pyspark/indexing.py @@ -12,7 +12,7 @@ MultiPolygon, ) from pyspark.sql import functions as F, types as T -from .utils import flatten, densify +from .utils import flatten, densify, handle_nulls def _index_point_object(point: Point, resolution: int): @@ -117,6 +117,7 @@ def _index_shape(shape: str, resolution: int): @F.udf(T.ArrayType(T.StringType())) +@handle_nulls def index_shape(geometry: Column, resolution: Column): """ Generate an H3 spatial index for an input GeoJSON geometry column. diff --git a/src/h3_pyspark/traversal.py b/src/h3_pyspark/traversal.py index b6c6c15..c5b6a33 100644 --- a/src/h3_pyspark/traversal.py +++ b/src/h3_pyspark/traversal.py @@ -2,6 +2,7 @@ from pyspark.sql import functions as F, types as T from pyspark.sql.column import Column from typing import List +from .utils import handle_nulls def _k_ring_distinct(cells: List[str], distance: int = 1): @@ -15,6 +16,7 @@ def _k_ring_distinct(cells: List[str], distance: int = 1): @F.udf(T.ArrayType(T.StringType())) +@handle_nulls def k_ring_distinct(cells: Column, distance: Column): """ Perform a k-ring operation on every input cell and return the distinct set of output cells. diff --git a/src/h3_pyspark/utils.py b/src/h3_pyspark/utils.py index 4019f6c..582e1fc 100644 --- a/src/h3_pyspark/utils.py +++ b/src/h3_pyspark/utils.py @@ -2,6 +2,19 @@ from shapely.geometry import LineString +def handle_nulls(function): + """ + Decorator to return null if any of the input arguments are null. + """ + + def inner(*args, **kwargs): + if any(arg is None for arg in args): + return None + return function(*args, **kwargs) + + return inner + + def flatten(t): return [item for sublist in t for item in sublist] diff --git a/tests/test_core.py b/tests/test_core.py index df3e928..a017051 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -92,6 +92,18 @@ def test_geo_to_h3(self): expected = sanitize_types(h3.geo_to_h3(*h3_test_args)) assert sort(actual) == sort(expected) + def test_geo_to_h3_single_null_input(self): + actual = df.withColumn("actual", h3_pyspark.geo_to_h3(F.lit(100), F.lit(None), F.lit(9))) + actual = actual.collect()[0]["actual"] + expected = None + assert actual == expected + + def test_geo_to_h3_all_null_inputs(self): + actual = df.withColumn("actual", h3_pyspark.geo_to_h3(F.lit(None), F.lit(None), F.lit(None))) + actual = actual.collect()[0]["actual"] + expected = None + assert actual == expected + def test_h3_to_geo(self): h3_test_args, h3_pyspark_test_args = get_test_args(h3.h3_to_geo)