From 524e94f4c2b35665f8cc47f86f241f88a4cd2251 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Mon, 24 Apr 2023 21:18:44 -0400 Subject: [PATCH] fix import rewriting with aliases in multi-from --- pyupgrade/_plugins/imports.py | 10 +++++++--- tests/features/import_replaces_test.py | 7 +++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pyupgrade/_plugins/imports.py b/pyupgrade/_plugins/imports.py index f6e03597..823d634f 100644 --- a/pyupgrade/_plugins/imports.py +++ b/pyupgrade/_plugins/imports.py @@ -281,6 +281,7 @@ class FromImport(NamedTuple): mod_start: int mod_end: int names: tuple[int, ...] + ends: tuple[int, ...] end: int @classmethod @@ -310,11 +311,14 @@ def parse(cls, i: int, tokens: list[Token]) -> FromImport: for j in range(import_token + 1, end) if tokens[j].name == 'NAME' ] + ends_by_offset = {} for i in reversed(range(len(names))): if tokens[names[i]].src == 'as': + ends_by_offset[names[i - 1]] = names[i + 1] del names[i:i + 2] + ends = tuple(ends_by_offset.get(pos, pos) for pos in names) - return cls(start, mod_start, mod_end + 1, tuple(names), end) + return cls(start, mod_start, mod_end + 1, tuple(names), ends, end) def remove_self(self, tokens: list[Token]) -> None: del tokens[self.start:self.end] @@ -327,10 +331,10 @@ def remove_parts(self, tokens: list[Token], idxs: list[int]) -> None: if idx == 0: # look forward until next name and del del tokens[self.names[idx]:self.names[idx + 1]] else: # look backward for comma and del - j = end = self.names[idx] + j = self.names[idx] while tokens[j].src != ',': j -= 1 - del tokens[j:end + 1] + del tokens[j:self.ends[idx] + 1] def _alias_to_s(alias: ast.alias) -> str: diff --git a/tests/features/import_replaces_test.py b/tests/features/import_replaces_test.py index 386afdbf..07bc2288 100644 --- a/tests/features/import_replaces_test.py +++ b/tests/features/import_replaces_test.py @@ -305,6 +305,13 @@ def test_mock_noop_keep_mock(): 'from collections.abc import Callable\n', id='typing.Callable is rewritable in 3.10+ only', ), + pytest.param( + 'from typing import Optional, Sequence as S\n', + (3, 10), + 'from typing import Optional\n' + 'from collections.abc import Sequence as S\n', + id='aliasing in multi from import', + ), ), ) def test_import_replaces(s, min_version, expected):