Skip to content

Commit

Permalink
Fix(tokenizer): initialize self._col properly to avoid edge case (tobymao#1678)
Browse files Browse the repository at this point in the history

* Fix(tokenizer): maintain token start column correctly

* Revert changes

* Revert changes

* Fixups
  • Loading branch information
georgesittas authored and adrianisk committed Jun 21, 2023
1 parent dceada3 commit 98520ee
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 26 deletions.
18 changes: 9 additions & 9 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def reset(self) -> None:
self._start = 0
self._current = 0
self._line = 1
self._col = 1
self._col = 0
self._comments: t.List[str] = []

self._char = ""
Expand All @@ -810,13 +810,12 @@ def tokenize(self, sql: str) -> t.List[Token]:
self.reset()
self.sql = sql
self.size = len(sql)

try:
self._scan()
except Exception as e:
start = self._current - 50
end = self._current + 50
start = start if start > 0 else 0
end = end if end < self.size else self.size - 1
start = max(self._current - 50, 0)
end = min(self._current + 50, self.size - 1)
context = self.sql[start:end]
raise ValueError(f"Error tokenizing '{context}'") from e

Expand All @@ -841,17 +840,17 @@ def _scan(self, until: t.Optional[t.Callable] = None) -> None:
if until and until():
break

if self.tokens:
if self.tokens and self._comments:
self.tokens[-1].comments.extend(self._comments)

def _chars(self, size: int) -> str:
if size == 1:
return self._char

start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""

return self.sql[start:end] if end <= self.size else ""

def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
Expand All @@ -866,6 +865,7 @@ def _advance(self, i: int = 1, alnum: bool = False) -> None:
self._peek = "" if self._end else self.sql[self._current]

if alnum and self._char.isalnum():
# Here we use local variables instead of attributes for better performance
_col = self._col
_current = self._current
_end = self._end
Expand Down
10 changes: 6 additions & 4 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ def test_parse_into_error(self):
{
"description": "Invalid expression / Unexpected token",
"line": 1,
"col": 7,
"col": 6,
"start_context": "",
"highlight": "SELECT",
"end_context": " 1;",
"into_expression": exp.From,
}
]
with self.assertRaises(ParseError) as ctx:
parse_one("SELECT 1;", "sqlite", [exp.From])
parse_one("SELECT 1;", read="sqlite", into=[exp.From])

self.assertEqual(str(ctx.exception), expected_message)
self.assertEqual(ctx.exception.errors, expected_errors)

Expand All @@ -40,7 +41,7 @@ def test_parse_into_errors(self):
{
"description": "Invalid expression / Unexpected token",
"line": 1,
"col": 7,
"col": 6,
"start_context": "",
"highlight": "SELECT",
"end_context": " 1;",
Expand All @@ -49,7 +50,7 @@ def test_parse_into_errors(self):
{
"description": "Invalid expression / Unexpected token",
"line": 1,
"col": 7,
"col": 6,
"start_context": "",
"highlight": "SELECT",
"end_context": " 1;",
Expand All @@ -58,6 +59,7 @@ def test_parse_into_errors(self):
]
with self.assertRaises(ParseError) as ctx:
parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join])

self.assertEqual(str(ctx.exception), expected_message)
self.assertEqual(ctx.exception.errors, expected_errors)

Expand Down
17 changes: 13 additions & 4 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_comment_attachment(self):
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)

def test_token_line(self):
def test_token_line_col(self):
tokens = Tokenizer().tokenize(
"""SELECT /*
line break
Expand All @@ -30,10 +30,19 @@ def test_token_line(self):
x"""
)

self.assertEqual(tokens[0].line, 1)
self.assertEqual(tokens[0].col, 6)
self.assertEqual(tokens[1].line, 5)
self.assertEqual(tokens[1].col, 3)
self.assertEqual(tokens[-1].line, 6)
self.assertEqual(tokens[-1].col, 1)
self.assertEqual(tokens[2].line, 5)
self.assertEqual(tokens[2].col, 4)
self.assertEqual(tokens[3].line, 6)
self.assertEqual(tokens[3].col, 1)

tokens = Tokenizer().tokenize("SELECT .")

self.assertEqual(tokens[1].line, 1)
self.assertEqual(tokens[1].col, 8)

def test_command(self):
tokens = Tokenizer().tokenize("SHOW;")
Expand All @@ -51,7 +60,7 @@ def test_command(self):
self.assertEqual(tokens[3].token_type, TokenType.SEMICOLON)

def test_error_msg(self):
with self.assertRaisesRegex(ValueError, "Error tokenizing 'select.*"):
with self.assertRaisesRegex(ValueError, "Error tokenizing 'select /'"):
Tokenizer().tokenize("select /*")

def test_jinja(self):
Expand Down
21 changes: 12 additions & 9 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,14 @@ def test_pretty_line_breaks(self):
def test_error_level(self, logger):
invalid = "x + 1. ("
expected_messages = [
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 9.\n x + 1. \033[4m(\033[0m",
"Expecting ). Line 1, Col: 9.\n x + 1. \033[4m(\033[0m",
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
"Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
]
expected_errors = [
{
"description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
"line": 1,
"col": 9,
"col": 8,
"start_context": "x + 1. ",
"highlight": "(",
"end_context": "",
Expand All @@ -571,7 +571,7 @@ def test_error_level(self, logger):
{
"description": "Expecting )",
"line": 1,
"col": 9,
"col": 8,
"start_context": "x + 1. ",
"highlight": "(",
"end_context": "",
Expand All @@ -585,26 +585,28 @@ def test_error_level(self, logger):

with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.IMMEDIATE)

self.assertEqual(str(ctx.exception), expected_messages[0])
self.assertEqual(ctx.exception.errors[0], expected_errors[0])

with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.RAISE)

self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
self.assertEqual(ctx.exception.errors, expected_errors)

more_than_max_errors = "(((("
expected_messages = (
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n"
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"... and 2 more"
)
expected_errors = [
{
"description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>",
"line": 1,
"col": 5,
"col": 4,
"start_context": "(((",
"highlight": "(",
"end_context": "",
Expand All @@ -613,7 +615,7 @@ def test_error_level(self, logger):
{
"description": "Expecting )",
"line": 1,
"col": 5,
"col": 4,
"start_context": "(((",
"highlight": "(",
"end_context": "",
Expand All @@ -625,6 +627,7 @@ def test_error_level(self, logger):

with self.assertRaises(ParseError) as ctx:
transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)

self.assertEqual(str(ctx.exception), expected_messages)
self.assertEqual(ctx.exception.errors, expected_errors)

Expand Down

0 comments on commit 98520ee

Please sign in to comment.