diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 956fafdf93..1c5df09b2e 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -797,7 +797,7 @@ def reset(self) -> None: self._start = 0 self._current = 0 self._line = 1 - self._col = 1 + self._col = 0 self._comments: t.List[str] = [] self._char = "" @@ -810,13 +810,12 @@ def tokenize(self, sql: str) -> t.List[Token]: self.reset() self.sql = sql self.size = len(sql) + try: self._scan() except Exception as e: - start = self._current - 50 - end = self._current + 50 - start = start if start > 0 else 0 - end = end if end < self.size else self.size - 1 + start = max(self._current - 50, 0) + end = min(self._current + 50, self.size - 1) context = self.sql[start:end] raise ValueError(f"Error tokenizing '{context}'") from e @@ -841,17 +840,17 @@ def _scan(self, until: t.Optional[t.Callable] = None) -> None: if until and until(): break - if self.tokens: + if self.tokens and self._comments: self.tokens[-1].comments.extend(self._comments) def _chars(self, size: int) -> str: if size == 1: return self._char + start = self._current - 1 end = start + size - if end <= self.size: - return self.sql[start:end] - return "" + + return self.sql[start:end] if end <= self.size else "" def _advance(self, i: int = 1, alnum: bool = False) -> None: if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: @@ -866,6 +865,7 @@ def _advance(self, i: int = 1, alnum: bool = False) -> None: self._peek = "" if self._end else self.sql[self._current] if alnum and self._char.isalnum(): + # Here we use local variables instead of attributes for better performance _col = self._col _current = self._current _end = self._end diff --git a/tests/test_parser.py b/tests/test_parser.py index e811e96a44..11df56690e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -22,7 +22,7 @@ def test_parse_into_error(self): { "description": "Invalid expression / Unexpected token", "line": 1, - "col": 7, + "col": 6, "start_context": "", "highlight": "SELECT", "end_context": " 1;", @@ -30,7 +30,8 @@ def test_parse_into_error(self): } ] with self.assertRaises(ParseError) as ctx: - parse_one("SELECT 1;", "sqlite", [exp.From]) + parse_one("SELECT 1;", read="sqlite", into=[exp.From]) + self.assertEqual(str(ctx.exception), expected_message) self.assertEqual(ctx.exception.errors, expected_errors) @@ -40,7 +41,7 @@ def test_parse_into_errors(self): { "description": "Invalid expression / Unexpected token", "line": 1, - "col": 7, + "col": 6, "start_context": "", "highlight": "SELECT", "end_context": " 1;", @@ -49,7 +50,7 @@ def test_parse_into_errors(self): { "description": "Invalid expression / Unexpected token", "line": 1, - "col": 7, + "col": 6, "start_context": "", "highlight": "SELECT", "end_context": " 1;", @@ -58,6 +59,7 @@ def test_parse_into_errors(self): ] with self.assertRaises(ParseError) as ctx: parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join]) + self.assertEqual(str(ctx.exception), expected_message) self.assertEqual(ctx.exception.errors, expected_errors) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index f70d70e8a7..c09eab48db 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -20,7 +20,7 @@ def test_comment_attachment(self): for sql, comment in sql_comment: self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment) - def test_token_line(self): + def test_token_line_col(self): tokens = Tokenizer().tokenize( """SELECT /* line break @@ -30,10 +30,19 @@ def test_token_line(self): x""" ) + self.assertEqual(tokens[0].line, 1) + self.assertEqual(tokens[0].col, 6) self.assertEqual(tokens[1].line, 5) self.assertEqual(tokens[1].col, 3) - self.assertEqual(tokens[-1].line, 6) - self.assertEqual(tokens[-1].col, 1) + self.assertEqual(tokens[2].line, 5) + self.assertEqual(tokens[2].col, 4) + self.assertEqual(tokens[3].line, 6) + self.assertEqual(tokens[3].col, 1) + + tokens = Tokenizer().tokenize("SELECT .") + + self.assertEqual(tokens[1].line, 1) + self.assertEqual(tokens[1].col, 8) def test_command(self): tokens = Tokenizer().tokenize("SHOW;") @@ -51,7 +60,7 @@ def test_command(self): self.assertEqual(tokens[3].token_type, TokenType.SEMICOLON) def test_error_msg(self): - with self.assertRaisesRegex(ValueError, "Error tokenizing 'select.*"): + with self.assertRaisesRegex(ValueError, "Error tokenizing 'select /'"): Tokenizer().tokenize("select /*") def test_jinja(self): diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 701f8ef4ea..1085b092bf 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -555,14 +555,14 @@ def test_pretty_line_breaks(self): def test_error_level(self, logger): invalid = "x + 1. (" expected_messages = [ - "Required keyword: 'expressions' missing for . Line 1, Col: 9.\n x + 1. \033[4m(\033[0m", - "Expecting ). Line 1, Col: 9.\n x + 1. \033[4m(\033[0m", + "Required keyword: 'expressions' missing for . Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", + "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", ] expected_errors = [ { "description": "Required keyword: 'expressions' missing for ", "line": 1, - "col": 9, + "col": 8, "start_context": "x + 1. ", "highlight": "(", "end_context": "", @@ -571,7 +571,7 @@ def test_error_level(self, logger): { "description": "Expecting )", "line": 1, - "col": 9, + "col": 8, "start_context": "x + 1. ", "highlight": "(", "end_context": "", @@ -585,26 +585,28 @@ def test_error_level(self, logger): with self.assertRaises(ParseError) as ctx: transpile(invalid, error_level=ErrorLevel.IMMEDIATE) + self.assertEqual(str(ctx.exception), expected_messages[0]) self.assertEqual(ctx.exception.errors[0], expected_errors[0]) with self.assertRaises(ParseError) as ctx: transpile(invalid, error_level=ErrorLevel.RAISE) + self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages)) self.assertEqual(ctx.exception.errors, expected_errors) more_than_max_errors = "((((" expected_messages = ( - "Required keyword: 'this' missing for . Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" - "Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" - "Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" + "Required keyword: 'this' missing for . Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "... and 2 more" ) expected_errors = [ { "description": "Required keyword: 'this' missing for ", "line": 1, - "col": 5, + "col": 4, "start_context": "(((", "highlight": "(", "end_context": "", @@ -613,7 +615,7 @@ def test_error_level(self, logger): { "description": "Expecting )", "line": 1, - "col": 5, + "col": 4, "start_context": "(((", "highlight": "(", "end_context": "", @@ -625,6 +627,7 @@ def test_error_level(self, logger): with self.assertRaises(ParseError) as ctx: transpile(more_than_max_errors, error_level=ErrorLevel.RAISE) + self.assertEqual(str(ctx.exception), expected_messages) self.assertEqual(ctx.exception.errors, expected_errors)