diff --git a/sql/setup.py b/sql/setup.py
index 8f2e925..ca7b687 100644
--- a/sql/setup.py
+++ b/sql/setup.py
@@ -23,7 +23,7 @@
     install_requires="sqlfluff==2.3.2",
     entry_points={
         "sqlfluff": [
-            "{plugin_logical_name} = {plugin_root_module}.rules".format(
+            "{plugin_logical_name} = {plugin_root_module}.plugin".format(
                 plugin_logical_name=PLUGIN_LOGICAL_NAME,
                 plugin_root_module=PLUGIN_ROOT_MODULE,
             )
diff --git a/sql/src/sparksql_upgrade/plugin.py b/sql/src/sparksql_upgrade/plugin.py
new file mode 100644
index 0000000..871b562
--- /dev/null
+++ b/sql/src/sparksql_upgrade/plugin.py
@@ -0,0 +1,48 @@
+"""Custom Spark SQL upgrade rules."""
+
+import os.path
+from typing import List
+
+
+from sqlfluff.core.config import ConfigLoader
+from sqlfluff.core.plugin import hookimpl
+from sqlfluff.core.rules import BaseRule
+
+
+@hookimpl
+def get_rules() -> List[BaseRule]:
+    """Get plugin rules."""
+    from .rules import (
+        Rule_SPARKSQLCAST_L001,
+        Rule_RESERVEDROPERTIES_L002,
+        Rule_NOCHARS_L003,
+        Rule_FORMATSTRONEINDEX_L004,
+        Rule_SPARKSQL_L004,
+        Rule_SPARKSQL_L005,
+    )
+
+    return [
+        Rule_SPARKSQLCAST_L001,
+        Rule_RESERVEDROPERTIES_L002,
+        Rule_NOCHARS_L003,
+        Rule_FORMATSTRONEINDEX_L004,
+        Rule_SPARKSQL_L004,
+        Rule_SPARKSQL_L005,
+    ]
+
+
+@hookimpl
+def load_default_config() -> dict:
+    """Loads the default configuration for the plugin."""
+    return ConfigLoader.get_global().load_config_file(
+        file_dir=os.path.dirname(__file__),
+        file_name="plugin_default_config.cfg",
+    )
+
+
+@hookimpl
+def get_configs_info() -> dict:
+    """Get rule config validations and descriptions."""
+    return {
+        "forbidden_columns": {"definition": "A list of column to forbid"},
+    }
diff --git a/sql/src/sparksql_upgrade/rules.py b/sql/src/sparksql_upgrade/rules.py
index fd0b91f..28f32b2 100644
--- a/sql/src/sparksql_upgrade/rules.py
+++ b/sql/src/sparksql_upgrade/rules.py
@@ -19,11 +19,6 @@
     RuleContext,
 )
 from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler
-from sqlfluff.core.rules.doc_decorators import (
-    document_configuration,
-    document_fix_compatible,
-    document_groups,
-)
 from sqlfluff.utils.functional import FunctionalContext, sp
 
 
@@ -57,9 +52,6 @@ def get_configs_info() -> dict:
     }
 
 
-@document_groups
-@document_fix_compatible
-@document_configuration
 class Rule_SPARKSQLCAST_L001(BaseRule):
     """Spark 3.0 cast as int on strings will fail.
 
@@ -86,6 +78,7 @@ class Rule_SPARKSQLCAST_L001(BaseRule):
 
     groups = ("all",)
     crawl_behaviour = SegmentSeekerCrawler({"function"})
+    is_fix_compatible = True
 
     def _eval(self, context: RuleContext) -> Optional[LintResult]:
         """Check integer casts."""
@@ -99,6 +92,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
         raw_function_name = function_name_id_seg.raw.upper().strip()
         function_name = raw_function_name.upper().strip()
         bracketed_segments = children.first(sp.is_type("bracketed"))
+        if not bracketed_segments:
+            return None
         bracketed = bracketed_segments[0]
 
         # Is this a cast function call
@@ -126,19 +121,17 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
         return None
 
 
-@document_groups
-@document_fix_compatible
-@document_configuration
 class Rule_FORMATSTRONEINDEX_L004(BaseRule):
     """Spark 3.3 Format strings are one indexed.
 
     Previously on JDK8 format strings were still one indexed, but zero
     was treated as one.
-    One JDK17 an exception was thrown.
+    On JDK17 an exception was thrown.
     """
 
     groups = ("all",)
     crawl_behaviour = SegmentSeekerCrawler({"function"})
+    is_fix_compatible = True
 
     def _eval(self, context: RuleContext) -> Optional[LintResult]:
         """Check for invalid format strs"""
@@ -152,6 +145,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
         raw_function_name = function_name_id_seg.raw.upper().strip()
         function_name = raw_function_name.upper().strip()
         bracketed_segments = children.first(sp.is_type("bracketed"))
+        if not bracketed_segments:
+            return None
         bracketed = bracketed_segments[0]
 
         # Is this a cast function call
@@ -179,9 +174,6 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
         return None
 
 
-@document_groups
-@document_fix_compatible
-@document_configuration
 class Rule_NOCHARS_L003(BaseRule):
     """Spark 3.0 No longer supports CHAR type in non-Hive tables.
 
@@ -191,6 +183,7 @@ class Rule_NOCHARS_L003(BaseRule):
 
     groups = ("all",)
     crawl_behaviour = SegmentSeekerCrawler({"primitive_type"})
+    is_fix_compatible = True
 
     def _eval(self, context: RuleContext) -> Optional[LintResult]:
         """Check for char types."""
@@ -208,9 +201,6 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
         return None
 
 
-@document_groups
-@document_fix_compatible
-@document_configuration
 class Rule_RESERVEDROPERTIES_L002(BaseRule):
     """Spark 3.0 Reserves some table properties
 
@@ -224,6 +214,7 @@ class Rule_RESERVEDROPERTIES_L002(BaseRule):
     # TODO -- Also look at SET calls once we fix SET DBPROPS in SQLFLUFF grammar.
     crawl_behaviour = SegmentSeekerCrawler({"property_name_identifier"})
     reserved = {"provider", "location", "owner"}
+    is_fix_compatible = True
 
     def _eval(self, context: RuleContext) -> Optional[LintResult]:
         """Check for reserved properties being configured."""
@@ -379,6 +370,7 @@ class Rule_SPARKSQL_L004(BaseRule):
 
     groups = ("all",)
     crawl_behaviour = SegmentSeekerCrawler({"function"})
+    is_fix_compatible = True
 
     def _eval(self, context: RuleContext) -> Optional[LintResult]:
         functional_context = FunctionalContext(context)
@@ -391,6 +383,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
         raw_function_name = function_name_id_seg.raw.upper().strip()
         function_name = raw_function_name.upper().strip()
         bracketed_segments = children.first(sp.is_type("bracketed"))
+        if not bracketed_segments:
+            return None
         bracketed = bracketed_segments[0]
 
         if function_name == "EXTRACT":
@@ -438,7 +432,7 @@ class Rule_SPARKSQL_L005(BaseRule):
 
     groups = ("all",)
     crawl_behaviour = SegmentSeekerCrawler({"function"})
-    # is_fix_compatible = True
+    is_fix_compatible = True
 
     def _eval(self, context: RuleContext) -> Optional[LintResult]:
         functional_context = FunctionalContext(context)