diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a224cd0..1a21dcf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,14 +12,14 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.1
+    rev: v0.6.3
     hooks:
       - id: ruff
        args: ["--fix", "--output-format=full"]
      - id: ruff-format
        args: ["--line-length=100"]
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.1
+    rev: v1.11.2
    hooks:
      - id: mypy
        args: [--ignore-missing-imports]
diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py
index 9eee3b4..2937d7a 100644
--- a/pymc_bart/utils.py
+++ b/pymc_bart/utils.py
@@ -8,10 +8,11 @@
 import numpy as np
 import numpy.typing as npt
 import pytensor.tensor as pt
+from numba import jit
 from pytensor.tensor.variable import Variable
 from scipy.interpolate import griddata
 from scipy.signal import savgol_filter
-from scipy.stats import norm, pearsonr
+from scipy.stats import norm
 
 from .tree import Tree
 
@@ -700,8 +701,9 @@ def plot_variable_importance(  # noqa: PLR0915
     method: str = "VI",
     figsize: Optional[Tuple[float, float]] = None,
     xlabel_angle: float = 0,
-    samples: int = 100,
+    samples: int = 50,
     random_seed: Optional[int] = None,
+    plot_kwargs: Optional[Dict[str, Any]] = None,
     ax: Optional[plt.Axes] = None,
 ) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
     """
@@ -733,6 +735,14 @@ def plot_variable_importance(  # noqa: PLR0915
-        Number of predictions used to compute correlation for subsets of variables. Defaults to 100
+        Number of predictions used to compute correlation for subsets of variables. Defaults to 50.
     random_seed : Optional[int]
         random_seed used to sample from the posterior. Defaults to None.
+    plot_kwargs : dict
+        Additional keyword arguments for the plot. Defaults to None.
+        Valid keys are:
+        - color_r2: matplotlib valid color for error bars
+        - marker_r2: matplotlib valid marker for the mean R squared
+        - marker_fc_r2: matplotlib valid marker face color for the mean R squared
+        - ls_ref: matplotlib valid linestyle for the reference line
+        - color_ref: matplotlib valid color for the reference line
     ax : axes
         Matplotlib axes.
 
@@ -745,6 +755,9 @@ def plot_variable_importance(  # noqa: PLR0915
 
     all_trees = bartrv.owner.op.all_trees
 
+    if plot_kwargs is None:
+        plot_kwargs = {}
+
     if bartrv.ndim == 1:  # type: ignore
         shape = 1
     else:
@@ -773,6 +786,10 @@ def plot_variable_importance(  # noqa: PLR0915
         all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
     )
 
+    r_2_ref = np.array(
+        [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)]
+    )
+
     if method == "VI":
         idxs = np.argsort(
             idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
@@ -794,10 +811,7 @@ def plot_variable_importance(  # noqa: PLR0915
                 shape=shape,
             )
             r_2 = np.array(
-                [
-                    pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
-                    for j in range(samples)
-                ]
+                [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
             )
             r2_mean[idx] = np.mean(r_2)
             r2_hdi[idx] = az.hdi(r_2)
@@ -833,10 +847,7 @@ def plot_variable_importance(  # noqa: PLR0915
             # Calculate Pearson correlation for each sample and find the mean
             r_2 = np.zeros(samples)
             for j in range(samples):
-                r_2[j] = (
-                    (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0])
-                    ** 2
-                )
+                r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j])
             mean_r_2 = np.mean(r_2, dtype=float)
             # Identify the least important combination of variables
             # based on the maximum mean squared Pearson correlation
@@ -872,9 +883,21 @@ def plot_variable_importance(  # noqa: PLR0915
         ticks,
         r2_mean,
         np.array((r2_yerr_min, r2_yerr_max)),
-        color="C0",
+        color=plot_kwargs.get("color_r2", "k"),
+        fmt=plot_kwargs.get("marker_r2", "o"),
+        mfc=plot_kwargs.get("marker_fc_r2", "white"),
+    )
+    ax.axhline(
+        np.mean(r_2_ref),
+        ls=plot_kwargs.get("ls_ref", "--"),
+        color=plot_kwargs.get("color_ref", "grey"),
+    )
+    ax.fill_between(
+        [-0.5, n_vars - 0.5],
+        *az.hdi(r_2_ref),
+        alpha=0.1,
+        color=plot_kwargs.get("color_ref", "grey"),
     )
-    ax.axhline(r2_mean[-1], ls="--", color="0.5")
     ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
     ax.set_ylabel("R²", rotation=0, labelpad=12)
     ax.set_ylim(0, 1)
@@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include):
     else:
         sequences = [()]
     return sequences
+
+
+@jit(nopython=True)
+def pearsonr2(A, B):
+    """Compute the squared Pearson correlation coefficient"""
+    A = A.flatten()
+    B = B.flatten()
+    am = A - np.mean(A)
+    bm = B - np.mean(B)
+    return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
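For reference, a minimal usage sketch of the new `plot_kwargs` parameter. The surrounding objects are hypothetical placeholders, not defined in this diff: `idata` stands for the fitted model's InferenceData, `mu` for the BART random variable, and `X` for the covariate matrix. The keys and defaults follow the diff above.

```python
import pymc_bart as pmb

# Hypothetical inputs from an earlier model fit: `idata`, `mu`, `X`.
idxs, axes = pmb.plot_variable_importance(
    idata,
    mu,
    X,
    samples=50,  # new default; previously 100
    plot_kwargs={
        "color_r2": "C0",        # error-bar color (default "k")
        "marker_r2": "s",        # marker for the mean R squared (default "o")
        "marker_fc_r2": "none",  # marker face color (default "white")
        "ls_ref": ":",           # reference-line style (default "--")
        "color_ref": "C3",       # reference line and HDI band color (default "grey")
    },
)
```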
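The reference line also changes meaning: instead of `r2_mean[-1]`, the plot now draws the mean of `r_2_ref` (with its HDI as a shaded band), the squared correlation between successive posterior-predictive samples of the full model, which serves as a baseline for how much R² fluctuates from predictive sampling alone. The numba-jitted `pearsonr2` replaces `scipy.stats.pearsonr(...)[0] ** 2`; since r² = (Σ a_c b_c)² / (Σ a_c² · Σ b_c²) for centered vectors, the two should agree up to floating-point error. A quick sanity-check sketch (not part of the PR):

```python
import numpy as np
from scipy.stats import pearsonr

from pymc_bart.utils import pearsonr2

rng = np.random.default_rng(0)
a = rng.normal(size=(10, 25))  # 2-D inputs are fine: pearsonr2 flattens internally
b = a + rng.normal(scale=0.5, size=(10, 25))

expected = pearsonr(a.flatten(), b.flatten())[0] ** 2
assert np.isclose(pearsonr2(a, b), expected)
```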