Skip to content

Commit

Permalink
fix: progress bar display (#420)
Browse files Browse the repository at this point in the history
* fix: progress bar display
* edit PathStatus.ELBO_ARGMAX_AT_ZERO message for clearer explanation
  • Loading branch information
aphc14 authored Feb 18, 2025
1 parent 00a4ca3 commit ec46270
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions pymc_extras/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from pytensor.tensor import TensorConstant, TensorVariable
from rich.console import Console, Group
from rich.padding import Padding
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.table import Table
from rich.text import Text

Expand Down Expand Up @@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:

path_status_message = {
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
}

Expand Down Expand Up @@ -1521,12 +1522,20 @@ def multipath_pathfinder(
results = []
compute_start = time.time()
try:
with CustomProgress(
desc = f"Paths Complete: {{path_idx}}/{num_paths}"
progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=default_progress_theme),
disable=not progressbar,
) as progress:
task = progress.add_task("Fitting", total=num_paths)
for result in generator:
)
with progress:
task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
for path_idx, result in enumerate(generator, start=1):
try:
if isinstance(result, Exception):
raise result
Expand All @@ -1552,7 +1561,14 @@ def multipath_pathfinder(
lbfgs_status=LBFGSStatus.LBFGS_FAILED,
)
)
progress.update(task, advance=1)
finally:
# TODO: display LBFGS and Path Status in real time
progress.update(
task,
description=desc.format(path_idx=path_idx),
completed=path_idx,
refresh=True,
)
except (KeyboardInterrupt, StopIteration) as e:
# if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
if isinstance(e, StopIteration):
Expand Down

0 comments on commit ec46270

Please sign in to comment.