Skip to content

Commit

Permalink
Split PipProgress into 2 classes
Browse files Browse the repository at this point in the history
1 for sequential downloads and 1 for parallel downloads.
  • Loading branch information
NeilBotelho committed Nov 26, 2023
1 parent 5b7e6d7 commit 390fd51
Showing 1 changed file with 109 additions and 39 deletions.
148 changes: 109 additions & 39 deletions src/pip/_internal/cli/progress_bars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
from logging import Logger
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -41,6 +42,10 @@


class RenderableLine:
"""
A wrapper for a single row, renderable by `Console` methods.
"""

def __init__(self, line_items: List[Union[Text, ProgressBar]]):
self.line_items = line_items

Expand All @@ -57,6 +62,10 @@ def __rich_console__(


class RenderableLines:
"""
A wrapper for multiple rows, renderable by `Console` methods.
"""

def __init__(self, lines: Iterable[RenderableLine]):
self.lines = lines

Expand All @@ -71,6 +80,21 @@ def __rich_console__(


class PipProgress(Progress):
"""
Custom Progress bar for sequential downloads.
"""

def __init__(
self,
refresh_per_second: int,
progress_disabled: bool = False,
logger: Optional[Logger] = None,
) -> None:
super().__init__(refresh_per_second=refresh_per_second)
self.progress_disabled = progress_disabled
self.log_download_description = True
self.logger = logger

@classmethod
def get_default_columns(cls) -> Tuple[ProgressColumn, ...]:
"""
Expand Down Expand Up @@ -100,53 +124,37 @@ def get_indefinite_columns(cls) -> Tuple[ProgressColumn, ...]:
TimeElapsedColumn(),
)

@classmethod
def get_description_columns(cls) -> Tuple[ProgressColumn, ...]:
"""
Get the columns to use for the log message, i.e. the task description
"""
# These columns will be the "Downloading"/"Using cached" message
# This message needs to be columns because,logging this message interferes
# with parallel progress bars, and if we want the message to remain next
# to the progress bar even when there are multiple tasks, then it needs
# to be a part of the progress bar
indentation = get_indentation()
if indentation:
return (
TextColumn(" " * get_indentation()),
TextColumn("{task.description}"),
)
return (TextColumn("{task.description}"),)

def get_renderable(self) -> RenderableType:
"""
Get the renderable representation of the progress bars of all tasks
Get the renderable representation of the progress of all tasks
"""
renderables: List[RenderableLine] = []
for task in self.tasks:
if task.visible:
renderables.extend(self.make_task_group(task))
if not task.visible:
continue
task_renderable = [x for x in self.make_task_group(task) if x is not None]
renderables.extend(task_renderable)
return RenderableLines(renderables)

def make_task_group(self, task: Task) -> Iterable[RenderableLine]:
def make_task_group(self, task: Task) -> Iterable[Optional[RenderableLine]]:
"""
Create a representation for a task, including both the description line
and the progress line.
Create a representation for a task, i.e. it's progress bar.
Parameters:
- task (Task): The task for which to generate the representation.
Returns:
- Iterable[RenderableLine]: An iterable of renderable lines containing the
description and (optionally) progress lines,
- Optional[Group]: text representation of a Progress Column,
"""
columns = self.columns if task.total else self.get_indefinite_columns()
description_row = self.make_task_row(self.get_description_columns(), task)
# Only print description if download isn't large enough
if task.total is not None and not task.total > (40 * 1000):
return (description_row,)

hide_progress = task.fields["hide_progress"]
if self.progress_disabled or hide_progress:
return (None,)
columns = (
self.columns if task.total is not None else self.get_indefinite_columns()
)
progress_row = self.make_task_row(columns, task)
return (description_row, progress_row)
return (progress_row,)

def make_task_row(
self, columns: Tuple[Union[str, ProgressColumn], ...], task: Task
Expand All @@ -167,8 +175,6 @@ def merge_text_objects(
) -> List[Union[Text, ProgressBar]]:
"""
Merge adjacent Text objects in the given row into a single Text object.
This is required to prevent newlines from being rendered between
Text objects
"""
merged_row: List[Union[Text, ProgressBar]] = []
markup_to_merge: List[str] = []
Expand All @@ -186,6 +192,68 @@ def merge_text_objects(
merged_row.append(Text.from_markup(merged_markup))
return merged_row

def add_task(
self,
description: str,
start: bool = True,
total: Optional[float] = 100.0,
completed: int = 0,
visible: bool = True,
**fields: Any,
) -> TaskID:
"""
Reimplementation of Progress.add_task with description logging
"""
if visible and self.log_download_description and self.logger:
indentation = " " * get_indentation()
log_statement = f"{indentation}{description}"
self.logger.info(log_statement)
return super().add_task(
description=description, total=total, visible=visible, **fields
)


class PipParallelProgress(PipProgress):
def __init__(self, refresh_per_second: int, progress_disabled: bool = True):
super().__init__(refresh_per_second=refresh_per_second)
# Overrides behaviour of logging description on add_task from PipProgress
self.log_download_description = False

@classmethod
def get_description_columns(cls) -> Tuple[ProgressColumn, ...]:
"""
Get the columns to use for the log message, i.e. the task description
"""
# These columns will be the "Downloading"/"Using cached" message
# This message needs to be columns because,logging this message interferes
# with parallel progress bars, and if we want the message to remain next
# to the progress bar even when there are multiple tasks, then it needs
# to be a part of the progress bar
indentation = get_indentation()
if indentation:
return (
TextColumn(" " * get_indentation()),
TextColumn("{task.description}"),
)
return (TextColumn("{task.description}"),)

def make_task_group(self, task: Task) -> Iterable[Optional[RenderableLine]]:
"""
Create a representation for a task, including both the description row
and the progress row.
Parameters:
- task (Task): The task for which to generate the representation.
Returns:
- Iterable[Optional[RenderableLine]]: An Iterable containing the
description and progress rows,
"""
progress_row = super().make_task_group(task)

description_row = self.make_task_row(self.get_description_columns(), task)
return (description_row, *progress_row)

def sort_tasks(self) -> None:
"""
Sort tasks
Expand All @@ -196,7 +264,7 @@ def sort_tasks(self) -> None:
tasks = []
for task_id in self._tasks:
task = self._tasks[task_id]
if task.finished and len(self._tasks) > 3:
if task.finished and len(self._tasks) > 1:
# Remove and log the finished task if there are too many active
# tasks to reduce the number of things to be rendered
# If there are too many actice tasks on screen rich renders the
Expand All @@ -205,8 +273,10 @@ def sort_tasks(self) -> None:
# If we remove every task on completion, it adds an extra newline
# for sequential downloads due to self.live on __exit__
if task.visible:
task_group = RenderableLines(self.make_task_group(task))
self.console.print(task_group)
task_group = [
x for x in self.make_task_group(task) if x is not None
]
self.console.print(RenderableLines(task_group))
else:
tasks.append((task_id, self._tasks[task_id]))
# Sorting by finished ensures that all active downloads remain together
Expand All @@ -226,8 +296,8 @@ def update(
**fields: Any,
) -> None:
"""
A copy of Progress' implementation of update, with sorting of self.tasks
when a task is completed
A copy of Progress' implementation of update, with sorting of
self.tasks when a task is completed
"""
with self._lock:
task = self._tasks[task_id]
Expand Down

0 comments on commit 390fd51

Please sign in to comment.