Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Aug 22, 2024
2 parents 950cb70 + 9fbca4b commit 794726f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.4
rev: v0.6.1
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.0
rev: v1.11.1
hooks:
- id: mypy
args: [--ignore-missing-imports]
Expand Down
2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"plot_pdp",
"plot_variable_importance",
]
__version__ = "0.5.14"
__version__ = "0.6.0"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
13 changes: 9 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def __init__( # noqa: PLR0915
else:
self.X = self.bart.X

if isinstance(self.bart.Y, Variable):
self.Y = self.bart.Y.eval()
else:
self.Y = self.bart.Y

self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.response = self.bart.response
Expand Down Expand Up @@ -166,26 +171,26 @@ def __init__( # noqa: PLR0915
if rule is ContinuousSplitRule:
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))

init_mean = self.bart.Y.mean()
init_mean = self.Y.mean()
self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
self.available_predictors = list(range(self.num_variates))

# if data is binary
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))

y_unique = np.unique(self.bart.Y)
y_unique = np.unique(self.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
self.leaf_sd *= 3 / self.m**0.5
else:
self.leaf_sd *= self.bart.Y.std() / self.m**0.5
self.leaf_sd *= self.Y.std() / self.m**0.5

self.running_sd = [
RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
]

self.sum_trees = np.full(
(self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
(self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean
).astype(config.floatX)
self.sum_trees_noi = self.sum_trees - init_mean
self.a_tree = Tree.new_tree(
Expand Down

0 comments on commit 794726f

Please sign in to comment.