Skip to content

Commit

Permalink
remove X argument from plots
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 26, 2024
1 parent d72c963 commit 38bf6b0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 25 deletions.
38 changes: 15 additions & 23 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
else:
shape = bartrv.eval().shape[0]

n_vars = X.shape[1]

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
X = X.to_numpy()
else:
labels = np.arange(n_vars).astype(str)

n_vars = X.shape[1]
r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))
preds = np.zeros((n_vars, samples, bartrv.eval().shape[0]))
Expand Down Expand Up @@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

vi_results = {
"indices": indices,
"labels": labels[indices],
"r2_mean": r2_mean,
"r2_hdi": r2_hdi,
"preds": preds,
Expand All @@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

def plot_variable_importance(
vi_results: dict,
X: npt.NDArray[np.float64],
labels=None,
figsize=None,
plot_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -1008,19 +1012,13 @@ def plot_variable_importance(
if figsize is None:
figsize = (8, 3)

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
X = X.to_numpy()

if ax is None:
_, ax = plt.subplots(1, 1, figsize=figsize)

if labels is None:
labels = np.arange(n_vars).astype(str)
else:
labels = np.asarray(labels)
labels = vi_results["labels"]

new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]

r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])

Expand Down Expand Up @@ -1048,7 +1046,7 @@ def plot_variable_importance(
)
ax.set_xticks(
ticks,
new_labels,
labels,
rotation=plot_kwargs.get("rotation", 0),
)
ax.set_ylabel("R²", rotation=0, labelpad=12)
Expand All @@ -1060,9 +1058,9 @@ def plot_variable_importance(

def plot_scatter_submodels(
vi_results: dict,
X: npt.NDArray[np.float64],
func: Optional[Callable] = None,
grid: str = "long",
labels=None,
figsize: Optional[Tuple[float, float]] = None,
plot_kwargs: Optional[Dict[str, Any]] = None,
axes: Optional[plt.Axes] = None,
Expand All @@ -1074,14 +1072,14 @@ def plot_scatter_submodels(
----------
vi_results: Dictionary
Dictionary computed with `compute_variable_importance`
X : npt.NDArray[np.float64]
The covariate matrix.
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
grid : str or tuple
How to arrange the subplots. Defaults to "long", one subplot below the other.
Other options are "wide", one subplot next to each other or a tuple indicating the number
of rows and columns.
labels : Optional[List[str]]
List of the names of the covariates.
plot_kwargs : dict
Additional keyword arguments for the plot. Defaults to None.
Valid keys are:
Expand All @@ -1097,23 +1095,17 @@ def plot_scatter_submodels(
indices = vi_results["indices"]
preds = vi_results["preds"]
preds_all = vi_results["preds_all"]
n_vars = len(indices)

if axes is None:
_, axes = _get_axes(grid, len(indices), True, True, figsize)

if plot_kwargs is None:
plot_kwargs = {}

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns

if labels is None:
labels = np.arange(n_vars).astype(str)
else:
labels = np.asarray(labels)
labels = vi_results["labels"]

new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]

if func is not None:
preds = func(preds)
Expand All @@ -1122,7 +1114,7 @@ def plot_scatter_submodels(
min_ = min(np.min(preds), np.min(preds_all))
max_ = max(np.max(preds), np.max(preds_all))

for pred, x_label, ax in zip(preds, new_labels, axes.ravel()):
for pred, x_label, ax in zip(preds, labels, axes.ravel()):
ax.plot(
pred,
preds_all,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def test_vi(self, kwargs):
vi_results = pmb.compute_variable_importance(
self.idata, bartrv=self.mu, X=self.X, samples=samples
)
pmb.plot_variable_importance(vi_results, X=self.X, **kwargs)
pmb.plot_scatter_submodels(vi_results)
pmb.plot_variable_importance(vi_results, **kwargs)
pmb.plot_scatter_submodels(vi_results, **kwargs)

def test_pdp_pandas_labels(self):
pd = pytest.importorskip("pandas")
Expand Down

0 comments on commit 38bf6b0

Please sign in to comment.