diff --git a/sql/src/sparksql_upgrade/rules.py b/sql/src/sparksql_upgrade/rules.py index 28f32b2..bd3eeb5 100644 --- a/sql/src/sparksql_upgrade/rules.py +++ b/sql/src/sparksql_upgrade/rules.py @@ -447,10 +447,11 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]: bracketed_segments = children.first(sp.is_type("bracketed")) if function_name == "APPROX_PERCENTILE" or function_name == "PERCENTILE_APPROX": - print("Found approx function!") + print(f"Found approx function! {function_name}") expression_count = 0 expression_segment = None + # Find "middle" of the approx_percentile(bloop) (e.g. bloop) for segment in bracketed_segments.children().iterate_segments( sp.is_type("expression") ): @@ -460,7 +461,12 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]: if expression_segment is not None: expression_child = expression_segment.children().first() - if expression_child[0].type == "function": + # cast can either be a keyword or a function depending on if were iterating on + # parsed on updated code. + if expression_child[0].type == "keyword": + if expression_child.child[0].raw == "cast": + return None + elif expression_child[0].type == "function": function_name_id_seg = ( expression_child.children() .first(sp.is_type("function_name"))