Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with labels in variable importance, add reference line, remove deprecation warning #207

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from pymc_bart.utils import (
compute_variable_importance,
plot_convergence,
plot_dependence,
plot_ice,
plot_pdp,
plot_scatter_submodels,
Expand All @@ -35,14 +34,13 @@
"SubsetSplitRule",
"compute_variable_importance",
"plot_convergence",
"plot_dependence",
"plot_ice",
"plot_pdp",
"plot_scatter_submodels",
"plot_variable_importance",
"plot_variable_inclusion",
]
__version__ = "0.8.0"
__version__ = "0.8.1"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
32 changes: 14 additions & 18 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,6 @@ def plot_convergence(
return ax


def plot_dependence(*args, kind="pdp", **kwargs):  # pylint: disable=unused-argument
    """
    Deprecated stub for partial dependence / individual conditional expectation plots.

    This function draws nothing. It only emits a ``FutureWarning`` that names the
    replacement function for the requested ``kind``.

    Parameters
    ----------
    *args
        Ignored. Accepted so that existing call sites do not raise.
    kind : str
        ``"pdp"`` points the caller to ``plot_pdp`` and ``"ice"`` points the caller to
        ``plot_ice``. Any other value emits no warning. Defaults to ``"pdp"``.
    **kwargs
        Ignored. Accepted so that existing call sites do not raise.

    Returns
    -------
    None
    """
    if kind == "pdp":
        replacement = "plot_pdp"
    elif kind == "ice":
        replacement = "plot_ice"
    else:
        # Unknown kinds have always been silently ignored; keep that behavior.
        return

    warnings.warn(
        f"This function has been deprecated. Use {replacement} instead.",
        FutureWarning,
        # Attribute the warning to the caller's line instead of this module, so
        # users can locate the call they need to migrate.
        stacklevel=2,
    )


def plot_ice(
bartrv: Variable,
X: npt.NDArray[np.float64],
Expand Down Expand Up @@ -307,6 +291,7 @@ def plot_pdp(
var_discrete: Optional[list[int]] = None,
func: Optional[Callable] = None,
samples: int = 200,
ref_line: bool = True,
random_seed: Optional[int] = None,
sharey: bool = True,
smooth: bool = True,
Expand Down Expand Up @@ -347,6 +332,8 @@ def plot_pdp(
Arbitrary function to apply to the predictions. Defaults to the identity function.
samples : int
Number of posterior samples used in the predictions. Defaults to 200.
ref_line : bool
If True, a reference line is plotted at the mean of the partial dependence. Defaults to True.
random_seed : Optional[int], by default None.
Seed used to sample from the posterior. Defaults to None.
sharey : bool
Expand Down Expand Up @@ -402,6 +389,7 @@ def identity(x):

count = 0
fake_X = _create_pdp_data(X, xs_interval, xs_values)
null_pd = []
for var in range(len(var_idx)):
excluded = indices[:]
excluded.remove(var)
Expand All @@ -413,6 +401,7 @@ def identity(x):
new_x = fake_X[:, var]
for s_i in range(shape):
p_di = func(p_d[:, :, s_i])
null_pd.append(p_di.mean())
if var in var_discrete:
_, idx_uni = np.unique(new_x, return_index=True)
y_means = p_di.mean(0)[idx_uni]
Expand Down Expand Up @@ -442,6 +431,11 @@ def identity(x):

count += 1

if ref_line:
ref_val = sum(null_pd) / len(null_pd)
for ax_ in np.ravel(axes):
ax_.axhline(ref_val, color="0.7", linestyle="--")

fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15)

return axes
Expand Down Expand Up @@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

indices = least_important_vars[::-1]

labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)])
labels = np.array(
["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
)

vi_results = {
"indices": np.asarray(indices),
"labels": labels[indices],
"labels": labels,
"r2_mean": r2_mean,
"r2_hdi": r2_hdi,
"preds": preds,
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ line-length = 100
select = ["E", "F", "I", "PL", "UP", "W"]
ignore = [
"PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons.
"PLR0913", #Too many arguments in function definition

]

[tool.ruff.lint.pylint]
Expand Down
Loading