From c5a20fa6a749ce8711685864629a7463670b08c6 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 20 Dec 2024 12:24:07 -0300 Subject: [PATCH 1/2] fix bug labels variable importance, add reference line --- pymc_bart/__init__.py | 4 +--- pymc_bart/utils.py | 34 +++++++++++++++------------------- pyproject.toml | 2 ++ 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index eee1881..440f7f2 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -19,7 +19,6 @@ from pymc_bart.utils import ( compute_variable_importance, plot_convergence, - plot_dependence, plot_ice, plot_pdp, plot_scatter_submodels, @@ -35,14 +34,13 @@ "SubsetSplitRule", "compute_variable_importance", "plot_convergence", - "plot_dependence", "plot_ice", "plot_pdp", "plot_scatter_submodels", "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.8.0" +__version__ = "0.8.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e10a511..2ca69db 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -81,7 +81,7 @@ def plot_convergence( kind: str = "ecdf", figsize: Optional[tuple[float, float]] = None, ax=None, -) -> list[plt.Axes]: +) -> plt.Axes: """ Plot convergence diagnostics. @@ -137,22 +137,6 @@ def plot_convergence( return ax -def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument - """ - Partial dependence or individual conditional expectation plot. - """ - if kind == "pdp": - warnings.warn( - "This function has been deprecated. Use plot_pdp instead.", - FutureWarning, - ) - elif kind == "ice": - warnings.warn( - "This function has been deprecated. Use plot_ice instead.", - FutureWarning, - ) - - def plot_ice( bartrv: Variable, X: npt.NDArray[np.float64], @@ -307,6 +291,7 @@ def plot_pdp( var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, samples: int = 200, + ref_line: bool = True, random_seed: Optional[int] = None, sharey: bool = True, smooth: bool = True, @@ -347,6 +332,8 @@ def plot_pdp( Arbitrary function to apply to the predictions. Defaults to the identity function. samples : int Number of posterior samples used in the predictions. Defaults to 200 + ref_line : bool + If True a reference line is plotted at the mean of the partial dependence. Defaults to True. random_seed : Optional[int], by default None. Seed used to sample from the posterior. Defaults to None. sharey : bool @@ -402,6 +389,7 @@ def identity(x): count = 0 fake_X = _create_pdp_data(X, xs_interval, xs_values) + null_pd = [] for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) @@ -413,6 +401,7 @@ def identity(x): new_x = fake_X[:, var] for s_i in range(shape): p_di = func(p_d[:, :, s_i]) + null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) y_means = p_di.mean(0)[idx_uni] @@ -442,6 +431,11 @@ def identity(x): count += 1 + if ref_line: + ref_val = sum(null_pd) / len(null_pd) + for ax_ in np.ravel(axes): + ax_.axhline(ref_val, color="0.7", linestyle="--") + fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15) return axes @@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 indices = least_important_vars[::-1] - labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]) + labels = np.array( + ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + ) vi_results = { "indices": np.asarray(indices), - "labels": labels[indices], + "labels": labels, "r2_mean": r2_mean, "r2_hdi": r2_hdi, "preds": preds, diff --git a/pyproject.toml b/pyproject.toml index bc94137..f8f3e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ line-length = 100 select = ["E", "F", "I", "PL", "UP", "W"] ignore = [ "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. + "PLR0913", #Too many arguments in function definition + ] [tool.ruff.lint.pylint] From d4e2beee159922e0d188f080a6186c44c66f85e1 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 20 Dec 2024 12:26:16 -0300 Subject: [PATCH 2/2] revert change --- pymc_bart/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 2ca69db..d9738dd 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -81,7 +81,7 @@ def plot_convergence( kind: str = "ecdf", figsize: Optional[tuple[float, float]] = None, ax=None, -) -> plt.Axes: +) -> list[plt.Axes]: """ Plot convergence diagnostics.