diff --git a/CHANGES.md b/CHANGES.md index b1fe25ef625..780a00247ce 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,9 @@ +- Support formatting ranges of lines with the new `--line-ranges` command-line option + (#4020). + ### Stable style - Fix crash on formatting bytes strings that look like docstrings (#4003) diff --git a/docs/usage_and_configuration/the_basics.md b/docs/usage_and_configuration/the_basics.md index f25dbb13d4d..dbd8c7ba434 100644 --- a/docs/usage_and_configuration/the_basics.md +++ b/docs/usage_and_configuration/the_basics.md @@ -175,6 +175,23 @@ All done! ✨ 🍰 ✨ 1 file would be reformatted. ``` +### `--line-ranges` + +When specified, _Black_ will try its best to only format these lines. + +This option can be specified multiple times, and a union of the lines will be formatted. +Each range must be specified as two integers connected by a `-`: `-`. The +`` and `` integer indices are 1-based and inclusive on both ends. + +_Black_ may still format lines outside of the ranges for multi-line statements. +Formatting more than one file or any ipynb files with this option is not supported. This +option cannot be specified in the `pyproject.toml` config. + +Example: `black --line-ranges=1-10 --line-ranges=21-30 test.py` will format lines from +`1` to `10` and `21` to `30`. + +This option is mainly for editor integrations, such as "Format Selection". + #### `--color` / `--no-color` Show (or do not show) colored diff. Only applies when `--diff` is given. diff --git a/src/black/__init__.py b/src/black/__init__.py index c11a66b7bc8..5aca3316df0 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -13,6 +13,7 @@ from pathlib import Path from typing import ( Any, + Collection, Dict, Generator, Iterator, @@ -77,6 +78,7 @@ from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out from black.parsing import InvalidInput # noqa F401 from black.parsing import lib2to3_parse, parse_ast, stringify_ast +from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges from black.report import Changed, NothingChanged, Report from black.trans import iter_fexpr_spans from blib2to3.pgen2 import token @@ -163,6 +165,12 @@ def read_pyproject_toml( "extend-exclude", "Config key extend-exclude must be a string" ) + line_ranges = config.get("line_ranges") + if line_ranges is not None: + raise click.BadOptionUsage( + "line-ranges", "Cannot use line-ranges in the pyproject.toml file." + ) + default_map: Dict[str, Any] = {} if ctx.default_map: default_map.update(ctx.default_map) @@ -304,6 +312,19 @@ def validate_regex( is_flag=True, help="Don't write the files back, just output a diff for each file on stdout.", ) +@click.option( + "--line-ranges", + multiple=True, + metavar="START-END", + help=( + "When specified, _Black_ will try its best to only format these lines. This" + " option can be specified multiple times, and a union of the lines will be" + " formatted. Each range must be specified as two integers connected by a `-`:" + " `-`. The `` and `` integer indices are 1-based and" + " inclusive on both ends." + ), + default=(), +) @click.option( "--color/--no-color", is_flag=True, @@ -443,6 +464,7 @@ def main( # noqa: C901 target_version: List[TargetVersion], check: bool, diff: bool, + line_ranges: Sequence[str], color: bool, fast: bool, pyi: bool, @@ -544,6 +566,18 @@ def main( # noqa: C901 python_cell_magics=set(python_cell_magics), ) + lines: List[Tuple[int, int]] = [] + if line_ranges: + if ipynb: + err("Cannot use --line-ranges with ipynb files.") + ctx.exit(1) + + try: + lines = parse_line_ranges(line_ranges) + except ValueError as e: + err(str(e)) + ctx.exit(1) + if code is not None: # Run in quiet mode by default with -c; the extra output isn't useful. # You can still pass -v to get verbose output. @@ -553,7 +587,12 @@ def main( # noqa: C901 if code is not None: reformat_code( - content=code, fast=fast, write_back=write_back, mode=mode, report=report + content=code, + fast=fast, + write_back=write_back, + mode=mode, + report=report, + lines=lines, ) else: assert root is not None # root is only None if code is not None @@ -588,10 +627,14 @@ def main( # noqa: C901 write_back=write_back, mode=mode, report=report, + lines=lines, ) else: from black.concurrency import reformat_many + if lines: + err("Cannot use --line-ranges to format multiple files.") + ctx.exit(1) reformat_many( sources=sources, fast=fast, @@ -714,7 +757,13 @@ def path_empty( def reformat_code( - content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report + content: str, + fast: bool, + write_back: WriteBack, + mode: Mode, + report: Report, + *, + lines: Collection[Tuple[int, int]] = (), ) -> None: """ Reformat and print out `content` without spawning child processes. @@ -727,7 +776,7 @@ def reformat_code( try: changed = Changed.NO if format_stdin_to_stdout( - content=content, fast=fast, write_back=write_back, mode=mode + content=content, fast=fast, write_back=write_back, mode=mode, lines=lines ): changed = Changed.YES report.done(path, changed) @@ -741,7 +790,13 @@ def reformat_code( # not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 @mypyc_attr(patchable=True) def reformat_one( - src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" + src: Path, + fast: bool, + write_back: WriteBack, + mode: Mode, + report: "Report", + *, + lines: Collection[Tuple[int, int]] = (), ) -> None: """Reformat a single file under `src` without spawning child processes. @@ -766,7 +821,9 @@ def reformat_one( mode = replace(mode, is_pyi=True) elif src.suffix == ".ipynb": mode = replace(mode, is_ipynb=True) - if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): + if format_stdin_to_stdout( + fast=fast, write_back=write_back, mode=mode, lines=lines + ): changed = Changed.YES else: cache = Cache.read(mode) @@ -774,7 +831,7 @@ def reformat_one( if not cache.is_changed(src): changed = Changed.CACHED if changed is not Changed.CACHED and format_file_in_place( - src, fast=fast, write_back=write_back, mode=mode + src, fast=fast, write_back=write_back, mode=mode, lines=lines ): changed = Changed.YES if (write_back is WriteBack.YES and changed is not Changed.CACHED) or ( @@ -794,6 +851,8 @@ def format_file_in_place( mode: Mode, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy + *, + lines: Collection[Tuple[int, int]] = (), ) -> bool: """Format file under `src` path. Return True if changed. @@ -813,7 +872,9 @@ def format_file_in_place( header = buf.readline() src_contents, encoding, newline = decode_bytes(buf.read()) try: - dst_contents = format_file_contents(src_contents, fast=fast, mode=mode) + dst_contents = format_file_contents( + src_contents, fast=fast, mode=mode, lines=lines + ) except NothingChanged: return False except JSONDecodeError: @@ -858,6 +919,7 @@ def format_stdin_to_stdout( content: Optional[str] = None, write_back: WriteBack = WriteBack.NO, mode: Mode, + lines: Collection[Tuple[int, int]] = (), ) -> bool: """Format file on stdin. Return True if changed. @@ -876,7 +938,7 @@ def format_stdin_to_stdout( dst = src try: - dst = format_file_contents(src, fast=fast, mode=mode) + dst = format_file_contents(src, fast=fast, mode=mode, lines=lines) return True except NothingChanged: @@ -904,7 +966,11 @@ def format_stdin_to_stdout( def check_stability_and_equivalence( - src_contents: str, dst_contents: str, *, mode: Mode + src_contents: str, + dst_contents: str, + *, + mode: Mode, + lines: Collection[Tuple[int, int]] = (), ) -> None: """Perform stability and equivalence checks. @@ -913,10 +979,16 @@ def check_stability_and_equivalence( content differently. """ assert_equivalent(src_contents, dst_contents) - assert_stable(src_contents, dst_contents, mode=mode) + assert_stable(src_contents, dst_contents, mode=mode, lines=lines) -def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: +def format_file_contents( + src_contents: str, + *, + fast: bool, + mode: Mode, + lines: Collection[Tuple[int, int]] = (), +) -> FileContent: """Reformat contents of a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is @@ -926,13 +998,15 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo if mode.is_ipynb: dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode) else: - dst_contents = format_str(src_contents, mode=mode) + dst_contents = format_str(src_contents, mode=mode, lines=lines) if src_contents == dst_contents: raise NothingChanged if not fast and not mode.is_ipynb: # Jupyter notebooks will already have been checked above. - check_stability_and_equivalence(src_contents, dst_contents, mode=mode) + check_stability_and_equivalence( + src_contents, dst_contents, mode=mode, lines=lines + ) return dst_contents @@ -1043,7 +1117,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon raise NothingChanged -def format_str(src_contents: str, *, mode: Mode) -> str: +def format_str( + src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = () +) -> str: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are @@ -1073,16 +1149,20 @@ def f( hey """ - dst_contents = _format_str_once(src_contents, mode=mode) + dst_contents = _format_str_once(src_contents, mode=mode, lines=lines) # Forced second pass to work around optional trailing commas (becoming # forced trailing commas on pass 2) interacting differently with optional # parentheses. Admittedly ugly. if src_contents != dst_contents: - return _format_str_once(dst_contents, mode=mode) + if lines: + lines = adjusted_lines(lines, src_contents, dst_contents) + return _format_str_once(dst_contents, mode=mode, lines=lines) return dst_contents -def _format_str_once(src_contents: str, *, mode: Mode) -> str: +def _format_str_once( + src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = () +) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_blocks: List[LinesBlock] = [] if mode.target_versions: @@ -1097,7 +1177,11 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str: if supports_feature(versions, feature) } normalize_fmt_off(src_node, mode) - lines = LineGenerator(mode=mode, features=context_manager_features) + if lines: + # This should be called after normalize_fmt_off. + convert_unchanged_lines(src_node, lines) + + line_generator = LineGenerator(mode=mode, features=context_manager_features) elt = EmptyLineTracker(mode=mode) split_line_features = { feature @@ -1105,7 +1189,7 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str: if supports_feature(versions, feature) } block: Optional[LinesBlock] = None - for current_line in lines.visit(src_node): + for current_line in line_generator.visit(src_node): block = elt.maybe_empty_lines(current_line) dst_blocks.append(block) for line in transform_line( @@ -1373,12 +1457,16 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, mode: Mode) -> None: +def assert_stable( + src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = () +) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" # We shouldn't call format_str() here, because that formats the string # twice and may hide a bug where we bounce back and forth between two # versions. - newdst = _format_str_once(dst, mode=mode) + if lines: + lines = adjusted_lines(lines, src, dst) + newdst = _format_str_once(dst, mode=mode, lines=lines) if dst != newdst: log = dump_to_file( str(mode), diff --git a/src/black/nodes.py b/src/black/nodes.py index fff8e05a118..9251b0defb0 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -935,3 +935,31 @@ def is_part_of_annotation(leaf: Leaf) -> bool: return True ancestor = ancestor.parent return False + + +def first_leaf(node: LN) -> Optional[Leaf]: + """Returns the first leaf of the ancestor node.""" + if isinstance(node, Leaf): + return node + elif not node.children: + return None + else: + return first_leaf(node.children[0]) + + +def last_leaf(node: LN) -> Optional[Leaf]: + """Returns the last leaf of the ancestor node.""" + if isinstance(node, Leaf): + return node + elif not node.children: + return None + else: + return last_leaf(node.children[-1]) + + +def furthest_ancestor_with_last_leaf(leaf: Leaf) -> LN: + """Returns the furthest ancestor that has this leaf node as the last leaf.""" + node: LN = leaf + while node.parent and node.parent.children and node is node.parent.children[-1]: + node = node.parent + return node diff --git a/src/black/ranges.py b/src/black/ranges.py new file mode 100644 index 00000000000..b0c312e6274 --- /dev/null +++ b/src/black/ranges.py @@ -0,0 +1,496 @@ +"""Functions related to Black's formatting by line ranges feature.""" + +import difflib +from dataclasses import dataclass +from typing import Collection, Iterator, List, Sequence, Set, Tuple, Union + +from black.nodes import ( + LN, + STANDALONE_COMMENT, + Leaf, + Node, + Visitor, + first_leaf, + furthest_ancestor_with_last_leaf, + last_leaf, + syms, +) +from blib2to3.pgen2.token import ASYNC, NEWLINE + + +def parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]: + lines: List[Tuple[int, int]] = [] + for lines_str in line_ranges: + parts = lines_str.split("-") + if len(parts) != 2: + raise ValueError( + "Incorrect --line-ranges format, expect 'START-END', found" + f" {lines_str!r}" + ) + try: + start = int(parts[0]) + end = int(parts[1]) + except ValueError: + raise ValueError( + "Incorrect --line-ranges value, expect integer ranges, found" + f" {lines_str!r}" + ) from None + else: + lines.append((start, end)) + return lines + + +def is_valid_line_range(lines: Tuple[int, int]) -> bool: + """Returns whether the line range is valid.""" + return not lines or lines[0] <= lines[1] + + +def adjusted_lines( + lines: Collection[Tuple[int, int]], + original_source: str, + modified_source: str, +) -> List[Tuple[int, int]]: + """Returns the adjusted line ranges based on edits from the original code. + + This computes the new line ranges by diffing original_source and + modified_source, and adjust each range based on how the range overlaps with + the diffs. + + Note the diff can contain lines outside of the original line ranges. This can + happen when the formatting has to be done in adjacent to maintain consistent + local results. For example: + + 1. def my_func(arg1, arg2, + 2. arg3,): + 3. pass + + If it restricts to line 2-2, it can't simply reformat line 2, it also has + to reformat line 1: + + 1. def my_func( + 2. arg1, + 3. arg2, + 4. arg3, + 5. ): + 6. pass + + In this case, we will expand the line ranges to also include the whole diff + block. + + Args: + lines: a collection of line ranges. + original_source: the original source. + modified_source: the modified source. + """ + lines_mappings = _calculate_lines_mappings(original_source, modified_source) + + new_lines = [] + # Keep an index of the current search. Since the lines and lines_mappings are + # sorted, this makes the search complexity linear. + current_mapping_index = 0 + for start, end in sorted(lines): + start_mapping_index = _find_lines_mapping_index( + start, + lines_mappings, + current_mapping_index, + ) + end_mapping_index = _find_lines_mapping_index( + end, + lines_mappings, + start_mapping_index, + ) + current_mapping_index = start_mapping_index + if start_mapping_index >= len(lines_mappings) or end_mapping_index >= len( + lines_mappings + ): + # Protect against invalid inputs. + continue + start_mapping = lines_mappings[start_mapping_index] + end_mapping = lines_mappings[end_mapping_index] + if start_mapping.is_changed_block: + # When the line falls into a changed block, expands to the whole block. + new_start = start_mapping.modified_start + else: + new_start = ( + start - start_mapping.original_start + start_mapping.modified_start + ) + if end_mapping.is_changed_block: + # When the line falls into a changed block, expands to the whole block. + new_end = end_mapping.modified_end + else: + new_end = end - end_mapping.original_start + end_mapping.modified_start + new_range = (new_start, new_end) + if is_valid_line_range(new_range): + new_lines.append(new_range) + return new_lines + + +def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None: + """Converts unchanged lines to STANDALONE_COMMENT. + + The idea is similar to how `# fmt: on/off` is implemented. It also converts the + nodes between those markers as a single `STANDALONE_COMMENT` leaf node with + the unformatted code as its value. `STANDALONE_COMMENT` is a "fake" token + that will be formatted as-is with its prefix normalized. + + Here we perform two passes: + + 1. Visit the top-level statements, and convert them to a single + `STANDALONE_COMMENT` when unchanged. This speeds up formatting when some + of the top-level statements aren't changed. + 2. Convert unchanged "unwrapped lines" to `STANDALONE_COMMENT` nodes line by + line. "unwrapped lines" are divided by the `NEWLINE` token. e.g. a + multi-line statement is *one* "unwrapped line" that ends with `NEWLINE`, + even though this statement itself can span multiple lines, and the + tokenizer only sees the last '\n' as the `NEWLINE` token. + + NOTE: During pass (2), comment prefixes and indentations are ALWAYS + normalized even when the lines aren't changed. This is fixable by moving + more formatting to pass (1). However, it's hard to get it correct when + incorrect indentations are used. So we defer this to future optimizations. + """ + lines_set: Set[int] = set() + for start, end in lines: + lines_set.update(range(start, end + 1)) + visitor = _TopLevelStatementsVisitor(lines_set) + _ = list(visitor.visit(src_node)) # Consume all results. + _convert_unchanged_line_by_line(src_node, lines_set) + + +def _contains_standalone_comment(node: LN) -> bool: + if isinstance(node, Leaf): + return node.type == STANDALONE_COMMENT + else: + for child in node.children: + if _contains_standalone_comment(child): + return True + return False + + +class _TopLevelStatementsVisitor(Visitor[None]): + """ + A node visitor that converts unchanged top-level statements to + STANDALONE_COMMENT. + + This is used in addition to _convert_unchanged_lines_by_flatterning, to + speed up formatting when there are unchanged top-level + classes/functions/statements. + """ + + def __init__(self, lines_set: Set[int]): + self._lines_set = lines_set + + def visit_simple_stmt(self, node: Node) -> Iterator[None]: + # This is only called for top-level statements, since `visit_suite` + # won't visit its children nodes. + yield from [] + newline_leaf = last_leaf(node) + if not newline_leaf: + return + assert ( + newline_leaf.type == NEWLINE + ), f"Unexpectedly found leaf.type={newline_leaf.type}" + # We need to find the furthest ancestor with the NEWLINE as the last + # leaf, since a `suite` can simply be a `simple_stmt` when it puts + # its body on the same line. Example: `if cond: pass`. + ancestor = furthest_ancestor_with_last_leaf(newline_leaf) + if not _get_line_range(ancestor).intersection(self._lines_set): + _convert_node_to_standalone_comment(ancestor) + + def visit_suite(self, node: Node) -> Iterator[None]: + yield from [] + # If there is a STANDALONE_COMMENT node, it means parts of the node tree + # have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't + # be simply converted by calling str(node). So we just don't convert + # here. + if _contains_standalone_comment(node): + return + # Find the semantic parent of this suite. For `async_stmt` and + # `async_funcdef`, the ASYNC token is defined on a separate level by the + # grammar. + semantic_parent = node.parent + if semantic_parent is not None: + if ( + semantic_parent.prev_sibling is not None + and semantic_parent.prev_sibling.type == ASYNC + ): + semantic_parent = semantic_parent.parent + if semantic_parent is not None and not _get_line_range( + semantic_parent + ).intersection(self._lines_set): + _convert_node_to_standalone_comment(semantic_parent) + + +def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None: + """Converts unchanged to STANDALONE_COMMENT line by line.""" + for leaf in node.leaves(): + if leaf.type != NEWLINE: + # We only consider "unwrapped lines", which are divided by the NEWLINE + # token. + continue + if leaf.parent and leaf.parent.type == syms.match_stmt: + # The `suite` node is defined as: + # match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT + # Here we need to check `subject_expr`. The `case_block+` will be + # checked by their own NEWLINEs. + nodes_to_ignore: List[LN] = [] + prev_sibling = leaf.prev_sibling + while prev_sibling: + nodes_to_ignore.insert(0, prev_sibling) + prev_sibling = prev_sibling.prev_sibling + if not _get_line_range(nodes_to_ignore).intersection(lines_set): + _convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf) + elif leaf.parent and leaf.parent.type == syms.suite: + # The `suite` node is defined as: + # suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT + # We will check `simple_stmt` and `stmt+` separately against the lines set + parent_sibling = leaf.parent.prev_sibling + nodes_to_ignore = [] + while parent_sibling and not parent_sibling.type == syms.suite: + # NOTE: Multiple suite nodes can exist as siblings in e.g. `if_stmt`. + nodes_to_ignore.insert(0, parent_sibling) + parent_sibling = parent_sibling.prev_sibling + # Special case for `async_stmt` and `async_funcdef` where the ASYNC + # token is on the grandparent node. + grandparent = leaf.parent.parent + if ( + grandparent is not None + and grandparent.prev_sibling is not None + and grandparent.prev_sibling.type == ASYNC + ): + nodes_to_ignore.insert(0, grandparent.prev_sibling) + if not _get_line_range(nodes_to_ignore).intersection(lines_set): + _convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf) + else: + ancestor = furthest_ancestor_with_last_leaf(leaf) + # Consider multiple decorators as a whole block, as their + # newlines have different behaviors than the rest of the grammar. + if ( + ancestor.type == syms.decorator + and ancestor.parent + and ancestor.parent.type == syms.decorators + ): + ancestor = ancestor.parent + if not _get_line_range(ancestor).intersection(lines_set): + _convert_node_to_standalone_comment(ancestor) + + +def _convert_node_to_standalone_comment(node: LN) -> None: + """Convert node to STANDALONE_COMMENT by modifying the tree inline.""" + parent = node.parent + if not parent: + return + first = first_leaf(node) + last = last_leaf(node) + if not first or not last: + return + if first is last: + # This can happen on the following edge cases: + # 1. A block of `# fmt: off/on` code except the `# fmt: on` is placed + # on the end of the last line instead of on a new line. + # 2. A single backslash on its own line followed by a comment line. + # Ideally we don't want to format them when not requested, but fixing + # isn't easy. These cases are also badly formatted code, so it isn't + # too bad we reformat them. + return + # The prefix contains comments and indentation whitespaces. They are + # reformatted accordingly to the correct indentation level. + # This also means the indentation will be changed on the unchanged lines, and + # this is actually required to not break incremental reformatting. + prefix = first.prefix + first.prefix = "" + index = node.remove() + if index is not None: + # Remove the '\n', as STANDALONE_COMMENT will have '\n' appended when + # genearting the formatted code. + value = str(node)[:-1] + parent.insert_child( + index, + Leaf( + STANDALONE_COMMENT, + value, + prefix=prefix, + fmt_pass_converted_first_leaf=first, + ), + ) + + +def _convert_nodes_to_standalone_comment(nodes: Sequence[LN], *, newline: Leaf) -> None: + """Convert nodes to STANDALONE_COMMENT by modifying the tree inline.""" + if not nodes: + return + parent = nodes[0].parent + first = first_leaf(nodes[0]) + if not parent or not first: + return + prefix = first.prefix + first.prefix = "" + value = "".join(str(node) for node in nodes) + # The prefix comment on the NEWLINE leaf is the trailing comment of the statement. + if newline.prefix: + value += newline.prefix + newline.prefix = "" + index = nodes[0].remove() + for node in nodes[1:]: + node.remove() + if index is not None: + parent.insert_child( + index, + Leaf( + STANDALONE_COMMENT, + value, + prefix=prefix, + fmt_pass_converted_first_leaf=first, + ), + ) + + +def _leaf_line_end(leaf: Leaf) -> int: + """Returns the line number of the leaf node's last line.""" + if leaf.type == NEWLINE: + return leaf.lineno + else: + # Leaf nodes like multiline strings can occupy multiple lines. + return leaf.lineno + str(leaf).count("\n") + + +def _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]: + """Returns the line range of this node or list of nodes.""" + if isinstance(node_or_nodes, list): + nodes = node_or_nodes + if not nodes: + return set() + first = first_leaf(nodes[0]) + last = last_leaf(nodes[-1]) + if first and last: + line_start = first.lineno + line_end = _leaf_line_end(last) + return set(range(line_start, line_end + 1)) + else: + return set() + else: + node = node_or_nodes + if isinstance(node, Leaf): + return set(range(node.lineno, _leaf_line_end(node) + 1)) + else: + first = first_leaf(node) + last = last_leaf(node) + if first and last: + return set(range(first.lineno, _leaf_line_end(last) + 1)) + else: + return set() + + +@dataclass +class _LinesMapping: + """1-based lines mapping from original source to modified source. + + Lines [original_start, original_end] from original source + are mapped to [modified_start, modified_end]. + + The ranges are inclusive on both ends. + """ + + original_start: int + original_end: int + modified_start: int + modified_end: int + # Whether this range corresponds to a changed block, or an unchanged block. + is_changed_block: bool + + +def _calculate_lines_mappings( + original_source: str, + modified_source: str, +) -> Sequence[_LinesMapping]: + """Returns a sequence of _LinesMapping by diffing the sources. + + For example, given the following diff: + import re + - def func(arg1, + - arg2, arg3): + + def func(arg1, arg2, arg3): + pass + It returns the following mappings: + original -> modified + (1, 1) -> (1, 1), is_changed_block=False (the "import re" line) + (2, 3) -> (2, 2), is_changed_block=True (the diff) + (4, 4) -> (3, 3), is_changed_block=False (the "pass" line) + + You can think of this visually as if it brings up a side-by-side diff, and tries + to map the line ranges from the left side to the right side: + + (1, 1)->(1, 1) 1. import re 1. import re + (2, 3)->(2, 2) 2. def func(arg1, 2. def func(arg1, arg2, arg3): + 3. arg2, arg3): + (4, 4)->(3, 3) 4. pass 3. pass + + Args: + original_source: the original source. + modified_source: the modified source. + """ + matcher = difflib.SequenceMatcher( + None, + original_source.splitlines(keepends=True), + modified_source.splitlines(keepends=True), + ) + matching_blocks = matcher.get_matching_blocks() + lines_mappings: List[_LinesMapping] = [] + # matching_blocks is a sequence of "same block of code ranges", see + # https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks + # Each block corresponds to a _LinesMapping with is_changed_block=False, + # and the ranges between two blocks corresponds to a _LinesMapping with + # is_changed_block=True, + # NOTE: matching_blocks is 0-based, but _LinesMapping is 1-based. + for i, block in enumerate(matching_blocks): + if i == 0: + if block.a != 0 or block.b != 0: + lines_mappings.append( + _LinesMapping( + original_start=1, + original_end=block.a, + modified_start=1, + modified_end=block.b, + is_changed_block=False, + ) + ) + else: + previous_block = matching_blocks[i - 1] + lines_mappings.append( + _LinesMapping( + original_start=previous_block.a + previous_block.size + 1, + original_end=block.a, + modified_start=previous_block.b + previous_block.size + 1, + modified_end=block.b, + is_changed_block=True, + ) + ) + if i < len(matching_blocks) - 1: + lines_mappings.append( + _LinesMapping( + original_start=block.a + 1, + original_end=block.a + block.size, + modified_start=block.b + 1, + modified_end=block.b + block.size, + is_changed_block=False, + ) + ) + return lines_mappings + + +def _find_lines_mapping_index( + original_line: int, + lines_mappings: Sequence[_LinesMapping], + start_index: int, +) -> int: + """Returns the original index of the lines mappings for the original line.""" + index = start_index + while index < len(lines_mappings): + mapping = lines_mappings[index] + if ( + mapping.original_start <= original_line + and original_line <= mapping.original_end + ): + return index + index += 1 + return index diff --git a/tests/data/cases/line_ranges_basic.py b/tests/data/cases/line_ranges_basic.py new file mode 100644 index 00000000000..9f0fb2da70e --- /dev/null +++ b/tests/data/cases/line_ranges_basic.py @@ -0,0 +1,107 @@ +# flags: --line-ranges=5-6 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. +def foo1(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass +def foo2(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass +def foo3(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass +def foo4(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass + +# Adding some unformated code covering a wide range of syntaxes. + +if True: + # Incorrectly indented prefix comments. + pass + +import typing +from typing import ( + Any , + ) +class MyClass( object): # Trailing comment with extra leading space. + #NOTE: The following indentation is incorrect: + @decor( 1 * 3 ) + def my_func( arg): + pass + +try: # Trailing comment with extra leading space. + for i in range(10): # Trailing comment with extra leading space. + while condition: + if something: + then_something( ) + elif something_else: + then_something_else( ) +except ValueError as e: + unformatted( ) +finally: + unformatted( ) + +async def test_async_unformatted( ): # Trailing comment with extra leading space. + async for i in some_iter( unformatted ): # Trailing comment with extra leading space. + await asyncio.sleep( 1 ) + async with some_context( unformatted ): + print( "unformatted" ) + + +# output +# flags: --line-ranges=5-6 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. +def foo1(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass +def foo2( + parameter_1, + parameter_2, + parameter_3, + parameter_4, + parameter_5, + parameter_6, + parameter_7, +): + pass + + +def foo3( + parameter_1, + parameter_2, + parameter_3, + parameter_4, + parameter_5, + parameter_6, + parameter_7, +): + pass + + +def foo4(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass + +# Adding some unformated code covering a wide range of syntaxes. + +if True: + # Incorrectly indented prefix comments. + pass + +import typing +from typing import ( + Any , + ) +class MyClass( object): # Trailing comment with extra leading space. + #NOTE: The following indentation is incorrect: + @decor( 1 * 3 ) + def my_func( arg): + pass + +try: # Trailing comment with extra leading space. + for i in range(10): # Trailing comment with extra leading space. + while condition: + if something: + then_something( ) + elif something_else: + then_something_else( ) +except ValueError as e: + unformatted( ) +finally: + unformatted( ) + +async def test_async_unformatted( ): # Trailing comment with extra leading space. + async for i in some_iter( unformatted ): # Trailing comment with extra leading space. + await asyncio.sleep( 1 ) + async with some_context( unformatted ): + print( "unformatted" ) diff --git a/tests/data/cases/line_ranges_fmt_off.py b/tests/data/cases/line_ranges_fmt_off.py new file mode 100644 index 00000000000..b51cef58fe5 --- /dev/null +++ b/tests/data/cases/line_ranges_fmt_off.py @@ -0,0 +1,49 @@ +# flags: --line-ranges=7-7 --line-ranges=17-23 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# fmt: off +import os +def myfunc( ): # Intentionally unformatted. + pass +# fmt: on + + +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: off +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: on + + +def myfunc( ): # This will be reformatted. + print( {"this will be reformatted"} ) + +# output + +# flags: --line-ranges=7-7 --line-ranges=17-23 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# fmt: off +import os +def myfunc( ): # Intentionally unformatted. + pass +# fmt: on + + +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: off +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: on + + +def myfunc(): # This will be reformatted. + print({"this will be reformatted"}) diff --git a/tests/data/cases/line_ranges_fmt_off_decorator.py b/tests/data/cases/line_ranges_fmt_off_decorator.py new file mode 100644 index 00000000000..14aa1dda02d --- /dev/null +++ b/tests/data/cases/line_ranges_fmt_off_decorator.py @@ -0,0 +1,27 @@ +# flags: --line-ranges=12-12 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# Regression test for an edge case involving decorators and fmt: off/on. +class MyClass: + + # fmt: off + @decorator ( ) + # fmt: on + def method(): + print ( "str" ) + +# output + +# flags: --line-ranges=12-12 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# Regression test for an edge case involving decorators and fmt: off/on. +class MyClass: + + # fmt: off + @decorator ( ) + # fmt: on + def method(): + print("str") diff --git a/tests/data/cases/line_ranges_fmt_off_overlap.py b/tests/data/cases/line_ranges_fmt_off_overlap.py new file mode 100644 index 00000000000..0391d17a843 --- /dev/null +++ b/tests/data/cases/line_ranges_fmt_off_overlap.py @@ -0,0 +1,37 @@ +# flags: --line-ranges=11-17 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + + +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: off +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: on + + +def myfunc( ): # This will be reformatted. + print( {"this will be reformatted"} ) + +# output + +# flags: --line-ranges=11-17 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + + +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: off +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +def myfunc( ): # This will not be reformatted. + print( {"also won't be reformatted"} ) +# fmt: on + + +def myfunc(): # This will be reformatted. + print({"this will be reformatted"}) diff --git a/tests/data/cases/line_ranges_imports.py b/tests/data/cases/line_ranges_imports.py new file mode 100644 index 00000000000..76b18ffecb3 --- /dev/null +++ b/tests/data/cases/line_ranges_imports.py @@ -0,0 +1,9 @@ +# flags: --line-ranges=8-8 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# This test ensures no empty lines are added around import lines. +# It caused an issue before https://github.com/psf/black/pull/3610 is merged. +import os +import re +import sys diff --git a/tests/data/cases/line_ranges_indentation.py b/tests/data/cases/line_ranges_indentation.py new file mode 100644 index 00000000000..82d3ad69a5e --- /dev/null +++ b/tests/data/cases/line_ranges_indentation.py @@ -0,0 +1,27 @@ +# flags: --line-ranges=5-5 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. +if cond1: + print("first") + if cond2: + print("second") + else: + print("else") + +if another_cond: + print("will not be changed") + +# output + +# flags: --line-ranges=5-5 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. +if cond1: + print("first") + if cond2: + print("second") + else: + print("else") + +if another_cond: + print("will not be changed") diff --git a/tests/data/cases/line_ranges_two_passes.py b/tests/data/cases/line_ranges_two_passes.py new file mode 100644 index 00000000000..aeed3260b8e --- /dev/null +++ b/tests/data/cases/line_ranges_two_passes.py @@ -0,0 +1,27 @@ +# flags: --line-ranges=9-11 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# This is a specific case for Black's two-pass formatting behavior in `format_str`. +# The second pass must respect the line ranges before the first pass. + + +def restrict_to_this_line(arg1, + arg2, + arg3): + print ( "This should not be formatted." ) + print ( "Note that in the second pass, the original line range 9-11 will cover these print lines.") + +# output + +# flags: --line-ranges=9-11 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. + +# This is a specific case for Black's two-pass formatting behavior in `format_str`. +# The second pass must respect the line ranges before the first pass. + + +def restrict_to_this_line(arg1, arg2, arg3): + print ( "This should not be formatted." ) + print ( "Note that in the second pass, the original line range 9-11 will cover these print lines.") diff --git a/tests/data/cases/line_ranges_unwrapping.py b/tests/data/cases/line_ranges_unwrapping.py new file mode 100644 index 00000000000..cd7751b9417 --- /dev/null +++ b/tests/data/cases/line_ranges_unwrapping.py @@ -0,0 +1,25 @@ +# flags: --line-ranges=5-5 --line-ranges=9-9 --line-ranges=13-13 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. +alist = [ + 1, 2 +] + +adict = { + "key" : "value" +} + +func_call ( + arg = value +) + +# output + +# flags: --line-ranges=5-5 --line-ranges=9-9 --line-ranges=13-13 +# NOTE: If you need to modify this file, pay special attention to the --line-ranges= +# flag above as it's formatting specifically these lines. +alist = [1, 2] + +adict = {"key": "value"} + +func_call(arg=value) diff --git a/tests/data/invalid_line_ranges.toml b/tests/data/invalid_line_ranges.toml new file mode 100644 index 00000000000..791573f2625 --- /dev/null +++ b/tests/data/invalid_line_ranges.toml @@ -0,0 +1,2 @@ +[tool.black] +line-ranges = "1-1" diff --git a/tests/data/line_ranges_formatted/basic.py b/tests/data/line_ranges_formatted/basic.py new file mode 100644 index 00000000000..b419b1f16ae --- /dev/null +++ b/tests/data/line_ranges_formatted/basic.py @@ -0,0 +1,50 @@ +"""Module doc.""" + +from typing import ( + Callable, + Literal, +) + + +# fmt: off +class Unformatted: + def should_also_work(self): + pass +# fmt: on + + +a = [1, 2] # fmt: skip + + +# This should cover as many syntaxes as possible. +class Foo: + """Class doc.""" + + def __init__(self) -> None: + pass + + @add_logging + @memoize.memoize(max_items=2) + def plus_one( + self, + number: int, + ) -> int: + return number + 1 + + async def async_plus_one(self, number: int) -> int: + await asyncio.sleep(1) + async with some_context(): + return number + 1 + + +try: + for i in range(10): + while condition: + if something: + then_something() + elif something_else: + then_something_else() +except ValueError as e: + handle(e) +finally: + done() diff --git a/tests/data/line_ranges_formatted/pattern_matching.py b/tests/data/line_ranges_formatted/pattern_matching.py new file mode 100644 index 00000000000..cd98efdd504 --- /dev/null +++ b/tests/data/line_ranges_formatted/pattern_matching.py @@ -0,0 +1,25 @@ +# flags: --minimum-version=3.10 + + +def pattern_matching(): + match status: + case 1: + return "1" + case [single]: + return "single" + case [ + action, + obj, + ]: + return "act on obj" + case Point(x=0): + return "class pattern" + case {"text": message}: + return "mapping" + case { + "text": message, + "format": _, + }: + return "mapping" + case _: + return "fallback" diff --git a/tests/test_black.py b/tests/test_black.py index c7196098e14..c9819742425 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -8,6 +8,7 @@ import os import re import sys +import textwrap import types from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, redirect_stderr @@ -1269,7 +1270,7 @@ def test_reformat_one_with_stdin_filename(self) -> None: report=report, ) fsts.assert_called_once_with( - fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE + fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE, lines=() ) # __BLACK_STDIN_FILENAME__ should have been stripped report.done.assert_called_with(expected, black.Changed.YES) @@ -1295,6 +1296,7 @@ def test_reformat_one_with_stdin_filename_pyi(self) -> None: fast=True, write_back=black.WriteBack.YES, mode=replace(DEFAULT_MODE, is_pyi=True), + lines=(), ) # __BLACK_STDIN_FILENAME__ should have been stripped report.done.assert_called_with(expected, black.Changed.YES) @@ -1320,6 +1322,7 @@ def test_reformat_one_with_stdin_filename_ipynb(self) -> None: fast=True, write_back=black.WriteBack.YES, mode=replace(DEFAULT_MODE, is_ipynb=True), + lines=(), ) # __BLACK_STDIN_FILENAME__ should have been stripped report.done.assert_called_with(expected, black.Changed.YES) @@ -1941,6 +1944,88 @@ def test_equivalency_ast_parse_failure_includes_error(self) -> None: err.match("invalid character") err.match(r"\(, line 1\)") + def test_line_ranges_with_code_option(self) -> None: + code = textwrap.dedent("""\ + if a == b: + print ( "OK" ) + """) + args = ["--line-ranges=1-1", "--code", code] + result = CliRunner().invoke(black.main, args) + + expected = textwrap.dedent("""\ + if a == b: + print ( "OK" ) + """) + self.compare_results(result, expected, expected_exit_code=0) + + def test_line_ranges_with_stdin(self) -> None: + code = textwrap.dedent("""\ + if a == b: + print ( "OK" ) + """) + runner = BlackRunner() + result = runner.invoke( + black.main, ["--line-ranges=1-1", "-"], input=BytesIO(code.encode("utf-8")) + ) + + expected = textwrap.dedent("""\ + if a == b: + print ( "OK" ) + """) + self.compare_results(result, expected, expected_exit_code=0) + + def test_line_ranges_with_source(self) -> None: + with TemporaryDirectory() as workspace: + test_file = Path(workspace) / "test.py" + test_file.write_text( + textwrap.dedent("""\ + if a == b: + print ( "OK" ) + """), + encoding="utf-8", + ) + args = ["--line-ranges=1-1", str(test_file)] + result = CliRunner().invoke(black.main, args) + assert not result.exit_code + + formatted = test_file.read_text(encoding="utf-8") + expected = textwrap.dedent("""\ + if a == b: + print ( "OK" ) + """) + assert expected == formatted + + def test_line_ranges_with_multiple_sources(self) -> None: + with TemporaryDirectory() as workspace: + test1_file = Path(workspace) / "test1.py" + test1_file.write_text("", encoding="utf-8") + test2_file = Path(workspace) / "test2.py" + test2_file.write_text("", encoding="utf-8") + args = ["--line-ranges=1-1", str(test1_file), str(test2_file)] + result = CliRunner().invoke(black.main, args) + assert result.exit_code == 1 + assert "Cannot use --line-ranges to format multiple files" in result.output + + def test_line_ranges_with_ipynb(self) -> None: + with TemporaryDirectory() as workspace: + test_file = Path(workspace) / "test.ipynb" + test_file.write_text("{}", encoding="utf-8") + args = ["--line-ranges=1-1", "--ipynb", str(test_file)] + result = CliRunner().invoke(black.main, args) + assert "Cannot use --line-ranges with ipynb files" in result.output + assert result.exit_code == 1 + + def test_line_ranges_in_pyproject_toml(self) -> None: + config = THIS_DIR / "data" / "invalid_line_ranges.toml" + result = BlackRunner().invoke( + black.main, ["--code", "print()", "--config", str(config)] + ) + assert result.exit_code == 2 + assert result.stderr_bytes is not None + assert ( + b"Cannot use line-ranges in the pyproject.toml file." in result.stderr_bytes + ) + class TestCaching: def test_get_cache_dir( diff --git a/tests/test_format.py b/tests/test_format.py index 4e863c6c54b..6c2eca8c618 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -29,13 +29,19 @@ def check_file(subdir: str, filename: str, *, data: bool = True) -> None: args.mode, fast=args.fast, minimum_version=args.minimum_version, + lines=args.lines, ) if args.minimum_version is not None: major, minor = args.minimum_version target_version = TargetVersion[f"PY{major}{minor}"] mode = replace(args.mode, target_versions={target_version}) assert_format( - source, expected, mode, fast=args.fast, minimum_version=args.minimum_version + source, + expected, + mode, + fast=args.fast, + minimum_version=args.minimum_version, + lines=args.lines, ) @@ -45,6 +51,24 @@ def test_simple_format(filename: str) -> None: check_file("cases", filename) +@pytest.mark.parametrize("filename", all_data_cases("line_ranges_formatted")) +def test_line_ranges_line_by_line(filename: str) -> None: + args, source, expected = read_data_with_mode("line_ranges_formatted", filename) + assert ( + source == expected + ), "Test cases in line_ranges_formatted must already be formatted." + line_count = len(source.splitlines()) + for line in range(1, line_count + 1): + assert_format( + source, + expected, + args.mode, + fast=args.fast, + minimum_version=args.minimum_version, + lines=[(line, line)], + ) + + # =============== # # Unusual cases # =============== # diff --git a/tests/test_ranges.py b/tests/test_ranges.py new file mode 100644 index 00000000000..d9fa9171a7f --- /dev/null +++ b/tests/test_ranges.py @@ -0,0 +1,185 @@ +"""Test the black.ranges module.""" + +from typing import List, Tuple + +import pytest + +from black.ranges import adjusted_lines + + +@pytest.mark.parametrize( + "lines", + [[(1, 1)], [(1, 3)], [(1, 1), (3, 4)]], +) +def test_no_diff(lines: List[Tuple[int, int]]) -> None: + source = """\ +import re + +def func(): +pass +""" + assert lines == adjusted_lines(lines, source, source) + + +@pytest.mark.parametrize( + "lines", + [ + [(1, 0)], + [(-8, 0)], + [(-8, 8)], + [(1, 100)], + [(2, 1)], + [(0, 8), (3, 1)], + ], +) +def test_invalid_lines(lines: List[Tuple[int, int]]) -> None: + original_source = """\ +import re +def foo(arg): +'''This is the foo function. + +This is foo function's +docstring with more descriptive texts. +''' + +def func(arg1, +arg2, arg3): +pass +""" + modified_source = """\ +import re +def foo(arg): +'''This is the foo function. + +This is foo function's +docstring with more descriptive texts. +''' + +def func(arg1, arg2, arg3): +pass +""" + assert not adjusted_lines(lines, original_source, modified_source) + + +@pytest.mark.parametrize( + "lines,adjusted", + [ + ( + [(1, 1)], + [(1, 1)], + ), + ( + [(1, 2)], + [(1, 1)], + ), + ( + [(1, 6)], + [(1, 2)], + ), + ( + [(6, 6)], + [], + ), + ], +) +def test_removals( + lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]] +) -> None: + original_source = """\ +1. first line +2. second line +3. third line +4. fourth line +5. fifth line +6. sixth line +""" + modified_source = """\ +2. second line +5. fifth line +""" + assert adjusted == adjusted_lines(lines, original_source, modified_source) + + +@pytest.mark.parametrize( + "lines,adjusted", + [ + ( + [(1, 1)], + [(2, 2)], + ), + ( + [(1, 2)], + [(2, 5)], + ), + ( + [(2, 2)], + [(5, 5)], + ), + ], +) +def test_additions( + lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]] +) -> None: + original_source = """\ +1. first line +2. second line +""" + modified_source = """\ +this is added +1. first line +this is added +this is added +2. second line +this is added +""" + assert adjusted == adjusted_lines(lines, original_source, modified_source) + + +@pytest.mark.parametrize( + "lines,adjusted", + [ + ( + [(1, 11)], + [(1, 10)], + ), + ( + [(1, 12)], + [(1, 11)], + ), + ( + [(10, 10)], + [(9, 9)], + ), + ([(1, 1), (9, 10)], [(1, 1), (9, 9)]), + ([(9, 10), (1, 1)], [(1, 1), (9, 9)]), + ], +) +def test_diffs(lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]) -> None: + original_source = """\ + 1. import re + 2. def foo(arg): + 3. '''This is the foo function. + 4. + 5. This is foo function's + 6. docstring with more descriptive texts. + 7. ''' + 8. + 9. def func(arg1, +10. arg2, arg3): +11. pass +12. # last line +""" + modified_source = """\ + 1. import re # changed + 2. def foo(arg): + 3. '''This is the foo function. + 4. + 5. This is foo function's + 6. docstring with more descriptive texts. + 7. ''' + 8. + 9. def func(arg1, arg2, arg3): +11. pass +12. # last line changed +""" + assert adjusted == adjusted_lines(lines, original_source, modified_source) diff --git a/tests/util.py b/tests/util.py index a31ae0992c2..c8699d335ab 100644 --- a/tests/util.py +++ b/tests/util.py @@ -8,13 +8,14 @@ from dataclasses import dataclass, field, replace from functools import partial from pathlib import Path -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Collection, Iterator, List, Optional, Tuple import black from black.const import DEFAULT_LINE_LENGTH from black.debug import DebugVisitor from black.mode import TargetVersion from black.output import diff, err, out +from black.ranges import parse_line_ranges from . import conftest @@ -44,6 +45,7 @@ class TestCaseArgs: mode: black.Mode = field(default_factory=black.Mode) fast: bool = False minimum_version: Optional[Tuple[int, int]] = None + lines: Collection[Tuple[int, int]] = () def _assert_format_equal(expected: str, actual: str) -> None: @@ -93,6 +95,7 @@ def assert_format( *, fast: bool = False, minimum_version: Optional[Tuple[int, int]] = None, + lines: Collection[Tuple[int, int]] = (), ) -> None: """Convenience function to check that Black formats as expected. @@ -101,7 +104,7 @@ def assert_format( separate from TargetVerson Mode configuration. """ _assert_format_inner( - source, expected, mode, fast=fast, minimum_version=minimum_version + source, expected, mode, fast=fast, minimum_version=minimum_version, lines=lines ) # For both preview and non-preview tests, ensure that Black doesn't crash on @@ -113,6 +116,7 @@ def assert_format( replace(mode, preview=not mode.preview), fast=fast, minimum_version=minimum_version, + lines=lines, ) except Exception as e: text = "non-preview" if mode.preview else "preview" @@ -129,6 +133,7 @@ def assert_format( replace(mode, preview=False, line_length=1), fast=fast, minimum_version=minimum_version, + lines=lines, ) except Exception as e: raise FormatFailure( @@ -143,8 +148,9 @@ def _assert_format_inner( *, fast: bool = False, minimum_version: Optional[Tuple[int, int]] = None, + lines: Collection[Tuple[int, int]] = (), ) -> None: - actual = black.format_str(source, mode=mode) + actual = black.format_str(source, mode=mode, lines=lines) if expected is not None: _assert_format_equal(expected, actual) # It's not useful to run safety checks if we're expecting no changes anyway. The @@ -156,7 +162,7 @@ def _assert_format_inner( # when checking modern code on older versions. if minimum_version is None or sys.version_info >= minimum_version: black.assert_equivalent(source, actual) - black.assert_stable(source, actual, mode=mode) + black.assert_stable(source, actual, mode=mode, lines=lines) def dump_to_stderr(*output: str) -> str: @@ -239,6 +245,7 @@ def get_flags_parser() -> argparse.ArgumentParser: " version works correctly." ), ) + parser.add_argument("--line-ranges", action="append") return parser @@ -254,7 +261,13 @@ def parse_mode(flags_line: str) -> TestCaseArgs: magic_trailing_comma=not args.skip_magic_trailing_comma, preview=args.preview, ) - return TestCaseArgs(mode=mode, fast=args.fast, minimum_version=args.minimum_version) + if args.line_ranges: + lines = parse_line_ranges(args.line_ranges) + else: + lines = [] + return TestCaseArgs( + mode=mode, fast=args.fast, minimum_version=args.minimum_version, lines=lines + ) def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]: @@ -267,6 +280,12 @@ def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]: for line in lines: if not _input and line.startswith("# flags: "): mode = parse_mode(line[len("# flags: ") :]) + if mode.lines: + # Retain the `# flags: ` line when using --line-ranges=. This requires + # the `# output` section to also include this line, but retaining the + # line is important to make the line ranges match what you see in the + # test file. + result.append(line) continue line = line.replace(EMPTY_LINE, "") if line.rstrip() == "# output":