Fix deprecation warnings for SQLFluff 2.0.0+
stristr committed Jan 10, 2024
1 parent a71765b commit 915e280
Showing 3 changed files with 62 additions and 20 deletions.
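Background for the changes below: SQLFluff 2.0 deprecated the sqlfluff.core.rules.doc_decorators module, so flags such as fix compatibility moved onto the rule class itself, and this plugin also moves its hook implementations out of rules.py into a dedicated plugin module that defers importing the rules. A condensed before/after sketch of the rule-declaration change (the rule body here is illustrative, not the real rule):

# SQLFluff 1.x style, which now emits deprecation warnings:
#
#   from sqlfluff.core.rules.doc_decorators import document_fix_compatible
#
#   @document_fix_compatible
#   class Rule_SPARKSQLCAST_L001(BaseRule):
#       ...
#
# SQLFluff 2.0+ style, as applied throughout this commit:
from sqlfluff.core.rules import BaseRule
from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler


class Rule_SPARKSQLCAST_L001(BaseRule):
    """Illustrative skeleton only; the real rule lives in rules.py below."""

    groups = ("all",)
    crawl_behaviour = SegmentSeekerCrawler({"function"})
    is_fix_compatible = True  # replaces the @document_fix_compatible decorator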
sql/setup.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@
install_requires="sqlfluff==2.3.2",
entry_points={
"sqlfluff": [
"{plugin_logical_name} = {plugin_root_module}.rules".format(
"{plugin_logical_name} = {plugin_root_module}.plugin".format(
plugin_logical_name=PLUGIN_LOGICAL_NAME,
plugin_root_module=PLUGIN_ROOT_MODULE,
)
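The entry point above is how SQLFluff discovers the hook implementations, so it has to name the module that defines them; this commit repoints it from rules to the new plugin module. A hedged sketch of the surrounding setup() call (the constant values and distribution name are assumed, not taken from this repository):

from setuptools import find_packages, setup

# Assumed values for illustration; the real constants are defined earlier in setup.py.
PLUGIN_LOGICAL_NAME = "sparksql_upgrade"
PLUGIN_ROOT_MODULE = "sparksql_upgrade"

setup(
    name="sqlfluff-plugin-sparksql-upgrade",  # assumed distribution name
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires="sqlfluff==2.3.2",
    entry_points={
        "sqlfluff": [
            # Must point at the module containing the @hookimpl functions,
            # which this commit moves from rules.py to plugin.py.
            "{plugin_logical_name} = {plugin_root_module}.plugin".format(
                plugin_logical_name=PLUGIN_LOGICAL_NAME,
                plugin_root_module=PLUGIN_ROOT_MODULE,
            )
        ]
    },
)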
sql/src/sparksql_upgrade/plugin.py (48 changes: 48 additions & 0 deletions)
@@ -0,0 +1,48 @@
"""Custom Spark SQL upgrade rules."""

import os.path
from typing import List


from sqlfluff.core.config import ConfigLoader
from sqlfluff.core.plugin import hookimpl
from sqlfluff.core.rules import BaseRule


@hookimpl
def get_rules() -> List[BaseRule]:
"""Get plugin rules."""
from .rules import (
Rule_SPARKSQLCAST_L001,
Rule_RESERVEDROPERTIES_L002,
Rule_NOCHARS_L003,
Rule_FORMATSTRONEINDEX_L004,
Rule_SPARKSQL_L004,
Rule_SPARKSQL_L005,
)

return [
Rule_SPARKSQLCAST_L001,
Rule_RESERVEDROPERTIES_L002,
Rule_NOCHARS_L003,
Rule_FORMATSTRONEINDEX_L004,
Rule_SPARKSQL_L004,
Rule_SPARKSQL_L005,
]


@hookimpl
def load_default_config() -> dict:
"""Loads the default configuration for the plugin."""
return ConfigLoader.get_global().load_config_file(
file_dir=os.path.dirname(__file__),
file_name="plugin_default_config.cfg",
)


@hookimpl
def get_configs_info() -> dict:
"""Get rule config validations and descriptions."""
return {
"forbidden_columns": {"definition": "A list of column to forbid"},
}
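Once the package is installed, SQLFluff loads these hooks automatically and the custom rules behave like built-in ones. A minimal usage sketch via SQLFluff's simple Python API (the query and printed fields are illustrative; whether a given custom rule fires also depends on the loaded config):

import sqlfluff

# A string-to-INT cast is the pattern Rule_SPARKSQLCAST_L001 targets under
# Spark 3.0 semantics; the dialect must be set to sparksql for these rules.
violations = sqlfluff.lint(
    "SELECT CAST(my_string_col AS INT) FROM my_table",
    dialect="sparksql",
)
for violation in violations:
    print(violation["code"], violation["description"])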
sql/src/sparksql_upgrade/rules.py (32 changes: 13 additions & 19 deletions)
@@ -19,11 +19,6 @@
RuleContext,
)
from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler
from sqlfluff.core.rules.doc_decorators import (
document_configuration,
document_fix_compatible,
document_groups,
)
from sqlfluff.utils.functional import FunctionalContext, sp


@@ -57,9 +52,6 @@ def get_configs_info() -> dict:
}


@document_groups
@document_fix_compatible
@document_configuration
class Rule_SPARKSQLCAST_L001(BaseRule):
"""Spark 3.0 cast as int on strings will fail.
@@ -86,6 +78,7 @@ class Rule_SPARKSQLCAST_L001(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check integer casts."""
@@ -99,6 +92,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
raw_function_name = function_name_id_seg.raw.upper().strip()
function_name = raw_function_name.upper().strip()
bracketed_segments = children.first(sp.is_type("bracketed"))
if not bracketed_segments:
return None
bracketed = bracketed_segments[0]

# Is this a cast function call
@@ -126,19 +121,17 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
return None


@document_groups
@document_fix_compatible
@document_configuration
class Rule_FORMATSTRONEINDEX_L004(BaseRule):
"""Spark 3.3 Format strings are one indexed.
Previously on JDK8 format strings were still one indexed, but zero was treated as one.
One JDK17 an exception was thrown.
On JDK17 an exception was thrown.
"""

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check for invalid format strs"""
@@ -152,6 +145,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
raw_function_name = function_name_id_seg.raw.upper().strip()
function_name = raw_function_name.upper().strip()
bracketed_segments = children.first(sp.is_type("bracketed"))
if not bracketed_segments:
return None
bracketed = bracketed_segments[0]

# Is this a cast function call
@@ -179,9 +174,6 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
return None


@document_groups
@document_fix_compatible
@document_configuration
class Rule_NOCHARS_L003(BaseRule):
"""Spark 3.0 No longer supports CHAR type in non-Hive tables.
@@ -191,6 +183,7 @@ class Rule_NOCHARS_L003(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"primitive_type"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check for char types."""
@@ -208,9 +201,6 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
return None


@document_groups
@document_fix_compatible
@document_configuration
class Rule_RESERVEDROPERTIES_L002(BaseRule):
"""Spark 3.0 Reserves some table properties
@@ -224,6 +214,7 @@
# TODO -- Also look at SET calls once we fix SET DBPROPS in SQLFLUFF grammar.
crawl_behaviour = SegmentSeekerCrawler({"property_name_identifier"})
reserved = {"provider", "location", "owner"}
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check for reserved properties being configured."""
@@ -379,6 +370,7 @@ class Rule_SPARKSQL_L004(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
functional_context = FunctionalContext(context)
@@ -391,6 +383,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
raw_function_name = function_name_id_seg.raw.upper().strip()
function_name = raw_function_name.upper().strip()
bracketed_segments = children.first(sp.is_type("bracketed"))
if not bracketed_segments:
return None
bracketed = bracketed_segments[0]

if function_name == "EXTRACT":
@@ -438,7 +432,7 @@ class Rule_SPARKSQL_L005(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
# is_fix_compatible = True
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
functional_context = FunctionalContext(context)
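Besides dropping the decorators, the other repeated addition in rules.py is a guard before indexing the bracketed segment: children.first(...) returns an empty Segments sequence when a function call has no parenthesised argument list, so the previous bracketed_segments[0] could raise IndexError. A stripped-down sketch of the pattern (rule name and body are illustrative, mirroring the guard added to the real rules):

from typing import Optional

from sqlfluff.core.rules import BaseRule, LintResult, RuleContext
from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler
from sqlfluff.utils.functional import FunctionalContext, sp


class Rule_EXAMPLE_L999(BaseRule):
    """Illustrative only; shows the guard, not a real check."""

    groups = ("all",)
    crawl_behaviour = SegmentSeekerCrawler({"function"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        children = FunctionalContext(context).segment.children()
        bracketed_segments = children.first(sp.is_type("bracketed"))
        if not bracketed_segments:
            # No argument list to inspect, so skip rather than index
            # into an empty sequence.
            return None
        bracketed = bracketed_segments[0]
        # ... rule-specific checks on `bracketed` would go here ...
        return None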
