diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 8f79d966..8fc6f799 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -60,6 +60,7 @@ from pytensor.tensor import TensorConstant, TensorVariable from rich.console import Console, Group from rich.padding import Padding +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.table import Table from rich.text import Text @@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]: path_status_message = { PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.", - PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", + PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.", PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", } @@ -1521,12 +1522,20 @@ def multipath_pathfinder( results = [] compute_start = time.time() try: - with CustomProgress( + desc = f"Paths Complete: {{path_idx}}/{num_paths}" + progress = CustomProgress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), console=Console(theme=default_progress_theme), disable=not progressbar, - ) as progress: - task = progress.add_task("Fitting", total=num_paths) - for result in generator: + ) + with progress: + task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths) + for path_idx, result in enumerate(generator, start=1): try: if isinstance(result, Exception): raise result @@ -1552,7 +1561,14 @@ def multipath_pathfinder( lbfgs_status=LBFGSStatus.LBFGS_FAILED, ) ) - progress.update(task, advance=1) + finally: + # TODO: display LBFGS and Path Status in real time + progress.update( + task, + description=desc.format(path_idx=path_idx), + completed=path_idx, + refresh=True, + ) except (KeyboardInterrupt, StopIteration) as e: # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData. if isinstance(e, StopIteration):