Skip to content

Commit

Permalink
remove reference np.float64
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 26, 2024
1 parent 7c6b462 commit dd6a6e8
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 73 deletions.
6 changes: 2 additions & 4 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[npt.NDArray[np.float64]] = None,
split_prior: Optional[npt.NDArray] = None,
split_rules: Optional[list[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
Expand Down Expand Up @@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
return mean


def preprocess_xy(
X: TensorLike, Y: TensorLike
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
Expand Down
50 changes: 25 additions & 25 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
return wei / wei.sum()

def resample(
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
) -> list[ParticleTree]:
"""
Use systematic resample for all but the first particle
Expand All @@ -347,7 +347,7 @@ def resample(
return particles

def get_particle_tree(
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
) -> tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
Expand All @@ -359,7 +359,7 @@ def get_particle_tree(

return new_particle, new_particle.tree

def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
"""
Systematic resampling.
Expand Down Expand Up @@ -411,7 +411,7 @@ def __init__(self, shape: tuple) -> None:
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment

def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
Expand All @@ -420,21 +420,21 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
@njit
def _update(
count: int,
mean: npt.NDArray[np.float64],
m_2: npt.NDArray[np.float64],
new_value: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
mean: npt.NDArray,
m_2: npt.NDArray,
new_value: npt.NDArray,
) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
m_2 += delta * delta2

std = (m_2 / count) ** 0.5
return mean.astype(np.float64), m_2.astype(np.float64), std.astype(np.float64)
return mean, m_2, std


class SampleSplittingVariable:
def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
def __init__(self, alpha_vec: npt.NDArray) -> None:
"""
Sample splitting variables proportional to `alpha_vec`.
Expand Down Expand Up @@ -547,16 +547,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d


def draw_leaf_value(
y_mu_pred: npt.NDArray[np.float64],
x_mu: npt.NDArray[np.float64],
y_mu_pred: npt.NDArray,
x_mu: npt.NDArray,
m: int,
norm: npt.NDArray[np.float64],
norm: npt.NDArray,
shape: int,
response: str,
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean: npt.NDArray[np.float64]
mu_mean: npt.NDArray
if y_mu_pred.size == 0:
return np.zeros(shape), linear_params

Expand All @@ -571,7 +571,7 @@ def draw_leaf_value(


@njit
def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
Expand All @@ -590,11 +590,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float

@njit
def fast_linear_fit(
x: npt.NDArray[np.float64],
y: npt.NDArray[np.float64],
x: npt.NDArray,
y: npt.NDArray,
m: int,
norm: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
norm: npt.NDArray,
) -> tuple[npt.NDArray, list[npt.NDArray]]:
n = len(x)
y = (y / m + np.expand_dims(norm, axis=1)).astype(np.float64)

Expand Down Expand Up @@ -678,17 +678,17 @@ def update(self):

@njit
def inverse_cdf(
single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
single_uniform: npt.NDArray, normalized_weights: npt.NDArray
) -> npt.NDArray[np.int_]:
"""
Inverse CDF algorithm for a finite distribution.
Parameters
----------
single_uniform: npt.NDArray[np.float64]
single_uniform: npt.NDArray
Ordered points in [0,1]
normalized_weights: npt.NDArray[np.float64])
normalized_weights: npt.NDArray)
Normalized weights
Returns
Expand All @@ -711,7 +711,7 @@ def inverse_cdf(


@njit
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
"""
Jitter duplicated values.
"""
Expand All @@ -727,7 +727,7 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray


@njit
def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
def are_whole_number(array: npt.NDArray) -> np.bool_:
"""Check if all values in array are whole numbers"""
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)

Expand Down
40 changes: 20 additions & 20 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Node:
Attributes
----------
value : npt.NDArray[np.float64]
value : npt.NDArray
idx_data_points : Optional[npt.NDArray[np.int_]]
idx_split_variable : int
linear_params: Optional[list[float]] = None
Expand All @@ -38,11 +38,11 @@ class Node:

def __init__(
self,
value: npt.NDArray[np.float64] = np.array([-1.0]),
value: npt.NDArray = np.array([-1.0]),
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
linear_params: Optional[list[npt.NDArray]] = None,
) -> None:
self.value = value
self.nvalue = nvalue
Expand All @@ -53,11 +53,11 @@ def __init__(
@classmethod
def new_leaf_node(
cls,
value: npt.NDArray[np.float64],
value: npt.NDArray,
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
linear_params: Optional[list[npt.NDArray]] = None,
) -> "Node":
return cls(
value=value,
Expand Down Expand Up @@ -101,7 +101,7 @@ class Tree:
The dictionary's keys are integers that represent the nodes position.
The dictionary's values are objects of type Node that represent the split and leaf nodes
of the tree itself.
output: Optional[npt.NDArray[np.float64]]
output: Optional[npt.NDArray]
Array of shape number of observations, shape
split_rules : list[SplitRule]
List of SplitRule objects, one per column in input data.
Expand All @@ -122,7 +122,7 @@ class Tree:
def __init__(
self,
tree_structure: dict[int, Node],
output: npt.NDArray[np.float64],
output: npt.NDArray,
split_rules: list[SplitRule],
idx_leaf_nodes: Optional[list[int]] = None,
) -> None:
Expand All @@ -134,7 +134,7 @@ def __init__(
@classmethod
def new_tree(
cls,
leaf_node_value: npt.NDArray[np.float64],
leaf_node_value: npt.NDArray,
idx_data_points: Optional[npt.NDArray[np.int_]],
num_observations: int,
shape: int,
Expand Down Expand Up @@ -190,7 +190,7 @@ def grow_leaf_node(
self,
current_node: Node,
selected_predictor: int,
split_value: npt.NDArray[np.float64],
split_value: npt.NDArray,
index_leaf_node: int,
) -> None:
current_node.value = split_value
Expand Down Expand Up @@ -222,7 +222,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
if node.is_split_node():
yield node.idx_split_variable

def _predict(self) -> npt.NDArray[np.float64]:
def _predict(self) -> npt.NDArray:
output = self.output

if self.idx_leaf_nodes is not None:
Expand All @@ -233,23 +233,23 @@ def _predict(self) -> npt.NDArray[np.float64]:

def predict(
self,
x: npt.NDArray[np.float64],
x: npt.NDArray,
excluded: Optional[list[int]] = None,
shape: int = 1,
) -> npt.NDArray[np.float64]:
) -> npt.NDArray:
"""
Predict output of tree for an (un)observed point x.
Parameters
----------
x : npt.NDArray[np.float64]
x : npt.NDArray
Unobserved point
excluded: Optional[list[int]]
Indexes of the variables to exclude when computing predictions
Returns
-------
npt.NDArray[np.float64]
npt.NDArray
Value of the leaf value where the unobserved point lies.
"""
if excluded is None:
Expand All @@ -259,16 +259,16 @@ def predict(

def _traverse_tree(
self,
X: npt.NDArray[np.float64],
X: npt.NDArray,
excluded: Optional[list[int]] = None,
shape: Union[int, tuple[int, ...]] = 1,
) -> npt.NDArray[np.float64]:
) -> npt.NDArray:
"""
Traverse the tree starting from the root node given an (un)observed point.
Parameters
----------
X : npt.NDArray[np.float64]
X : npt.NDArray
(Un)observed point(s)
node_index : int
Index of the node to start the traversal from
Expand All @@ -279,7 +279,7 @@ def _traverse_tree(
Returns
-------
npt.NDArray[np.float64]
npt.NDArray
Leaf node value or mean of leaf node values
"""

Expand Down Expand Up @@ -338,14 +338,14 @@ def _traverse_tree(
return p_d

def _traverse_leaf_values(
self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int
self, leaf_values: list[npt.NDArray], leaf_n_values: list[int], node_index: int
) -> None:
"""
Traverse the tree appending leaf values starting from a particular node.
Parameters
----------
leaf_values : list[npt.NDArray[np.float64]]
leaf_values : list[npt.NDArray]
node_index : int
"""
node = self.get_node(node_index)
Expand Down
Loading

0 comments on commit dd6a6e8

Please sign in to comment.