Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deprecation warnings for SQLFluff 2.0.0+ #95

Merged
merged 1 commit into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sql/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
install_requires="sqlfluff==2.3.2",
entry_points={
"sqlfluff": [
"{plugin_logical_name} = {plugin_root_module}.rules".format(
"{plugin_logical_name} = {plugin_root_module}.plugin".format(
plugin_logical_name=PLUGIN_LOGICAL_NAME,
plugin_root_module=PLUGIN_ROOT_MODULE,
)
Expand Down
48 changes: 48 additions & 0 deletions sql/src/sparksql_upgrade/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Custom Spark SQL upgrade rules."""

import os.path
from typing import List


from sqlfluff.core.config import ConfigLoader
from sqlfluff.core.plugin import hookimpl
from sqlfluff.core.rules import BaseRule


@hookimpl
def get_rules() -> List[BaseRule]:
"""Get plugin rules."""
from .rules import (
Rule_SPARKSQLCAST_L001,
Rule_RESERVEDROPERTIES_L002,
Rule_NOCHARS_L003,
Rule_FORMATSTRONEINDEX_L004,
Rule_SPARKSQL_L004,
Rule_SPARKSQL_L005,
)

return [
Rule_SPARKSQLCAST_L001,
Rule_RESERVEDROPERTIES_L002,
Rule_NOCHARS_L003,
Rule_FORMATSTRONEINDEX_L004,
Rule_SPARKSQL_L004,
Rule_SPARKSQL_L005,
]


@hookimpl
def load_default_config() -> dict:
"""Loads the default configuration for the plugin."""
return ConfigLoader.get_global().load_config_file(
file_dir=os.path.dirname(__file__),
file_name="plugin_default_config.cfg",
)


@hookimpl
def get_configs_info() -> dict:
"""Get rule config validations and descriptions."""
return {
"forbidden_columns": {"definition": "A list of column to forbid"},
}
32 changes: 13 additions & 19 deletions sql/src/sparksql_upgrade/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
RuleContext,
)
from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler
from sqlfluff.core.rules.doc_decorators import (
document_configuration,
document_fix_compatible,
document_groups,
)
from sqlfluff.utils.functional import FunctionalContext, sp


Expand Down Expand Up @@ -57,9 +52,6 @@ def get_configs_info() -> dict:
}


@document_groups
@document_fix_compatible
@document_configuration
class Rule_SPARKSQLCAST_L001(BaseRule):
"""Spark 3.0 cast as int on strings will fail.

Expand All @@ -86,6 +78,7 @@ class Rule_SPARKSQLCAST_L001(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check integer casts."""
Expand All @@ -99,6 +92,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
raw_function_name = function_name_id_seg.raw.upper().strip()
function_name = raw_function_name.upper().strip()
bracketed_segments = children.first(sp.is_type("bracketed"))
if not bracketed_segments:
return None
bracketed = bracketed_segments[0]

# Is this a cast function call
Expand Down Expand Up @@ -126,19 +121,17 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
return None


@document_groups
@document_fix_compatible
@document_configuration
class Rule_FORMATSTRONEINDEX_L004(BaseRule):
"""Spark 3.3 Format strings are one indexed.


Previously on JDK8 format strings were still one indexed, but zero was treated as one.
One JDK17 an exception was thrown.
On JDK17 an exception was thrown.
"""

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check for invalid format strs"""
Expand All @@ -152,6 +145,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
raw_function_name = function_name_id_seg.raw.upper().strip()
function_name = raw_function_name.upper().strip()
bracketed_segments = children.first(sp.is_type("bracketed"))
if not bracketed_segments:
return None
bracketed = bracketed_segments[0]

# Is this a cast function call
Expand Down Expand Up @@ -179,9 +174,6 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
return None


@document_groups
@document_fix_compatible
@document_configuration
class Rule_NOCHARS_L003(BaseRule):
"""Spark 3.0 No longer supports CHAR type in non-Hive tables.

Expand All @@ -191,6 +183,7 @@ class Rule_NOCHARS_L003(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"primitive_type"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check for char types."""
Expand All @@ -208,9 +201,6 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
return None


@document_groups
@document_fix_compatible
@document_configuration
class Rule_RESERVEDROPERTIES_L002(BaseRule):
"""Spark 3.0 Reserves some table properties

Expand All @@ -224,6 +214,7 @@ class Rule_RESERVEDROPERTIES_L002(BaseRule):
# TODO -- Also look at SET calls once we fix SET DBPROPS in SQLFLUFF grammar.
crawl_behaviour = SegmentSeekerCrawler({"property_name_identifier"})
reserved = {"provider", "location", "owner"}
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
"""Check for reserved properties being configured."""
Expand Down Expand Up @@ -379,6 +370,7 @@ class Rule_SPARKSQL_L004(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
functional_context = FunctionalContext(context)
Expand All @@ -391,6 +383,8 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]:
raw_function_name = function_name_id_seg.raw.upper().strip()
function_name = raw_function_name.upper().strip()
bracketed_segments = children.first(sp.is_type("bracketed"))
if not bracketed_segments:
return None
bracketed = bracketed_segments[0]

if function_name == "EXTRACT":
Expand Down Expand Up @@ -438,7 +432,7 @@ class Rule_SPARKSQL_L005(BaseRule):

groups = ("all",)
crawl_behaviour = SegmentSeekerCrawler({"function"})
# is_fix_compatible = True
is_fix_compatible = True

def _eval(self, context: RuleContext) -> Optional[LintResult]:
functional_context = FunctionalContext(context)
Expand Down