diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index b60acaf..6a0d7b6 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -25,7 +25,7 @@ ) __all__ = ["BART", "PGBART"] -__version__ = "0.5.6" +__version__ = "0.5.7" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 197ae5e..32d545d 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -157,7 +157,7 @@ def __init__( for idx, rule in enumerate(self.split_rules): if rule is ContinuousSplitRule: - self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.std(self.X[:, idx])) + self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx])) init_mean = self.bart.Y.mean() self.num_observations = self.X.shape[0] @@ -700,7 +700,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[ if are_whole_number(array): seen = [] for idx, num in enumerate(array): - if num in seen: + if num in seen and not np.isnan(num): array[idx] = num + np.random.normal(0, std / 12) else: seen.append(num) @@ -711,8 +711,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[ @njit def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_: """Check if all values in array are whole numbers""" - new_array = np.mod(array, 1) - return np.all(new_array == 0) + return np.all(np.mod(array[~np.isnan(array)], 1) == 0) def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin