Commit

update pre-commit (#171)
juanitorduz authored Jul 1, 2024
1 parent b759dff commit e04ca5c
Showing 5 changed files with 68 additions and 68 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -12,14 +12,14 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.4.7
+ rev: v0.5.0
hooks:
- id: ruff
args: ["--fix", "--show-source"]
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.10.0
+ rev: v1.10.1
hooks:
- id: mypy
args: [--ignore-missing-imports]
4 changes: 2 additions & 2 deletions pymc_bart/bart.py
@@ -125,7 +125,7 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
- split_prior: Optional[npt.NDArray[np.float_]] = None,
+ split_prior: Optional[npt.NDArray[np.float64]] = None,
split_rules: Optional[List[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
@@ -198,7 +198,7 @@ def get_moment(cls, rv, size, *rv_inputs):

def preprocess_xy(
X: TensorLike, Y: TensorLike
- ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+ ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
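The bart.py changes above swap the np.float_ alias for np.float64 in the type annotations: np.float_ was an alias for np.float64 and is gone in NumPy 2.0, so the annotations now name the concrete dtype directly. A minimal sketch of the same annotation style (illustrative only; the helper below is hypothetical, not library code):

    import numpy as np
    import numpy.typing as npt
    from typing import Tuple

    def as_float64_pair(x, y) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        # Coerce both inputs to float64 arrays, the same dtype the updated
        # preprocess_xy annotation promises for its return values.
        return np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)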
46 changes: 23 additions & 23 deletions pymc_bart/pgbart.py
@@ -313,7 +313,7 @@ def normalize(self, particles: List[ParticleTree]) -> float:
return wei / wei.sum()

def resample(
- self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
+ self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> List[ParticleTree]:
"""
Use systematic resample for all but the first particle
@@ -335,7 +335,7 @@ def resample(
return particles

def get_particle_tree(
- self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
+ self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> Tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
@@ -347,7 +347,7 @@ def get_particle_tree(

return new_particle, new_particle.tree

- def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[np.int_]:
+ def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
"""
Systematic resampling.
@@ -399,7 +399,7 @@ def __init__(self, shape: tuple) -> None:
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment

- def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
@@ -408,10 +408,10 @@ def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[
@njit
def _update(
count: int,
- mean: npt.NDArray[np.float_],
- m_2: npt.NDArray[np.float_],
- new_value: npt.NDArray[np.float_],
- ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
+ mean: npt.NDArray[np.float64],
+ m_2: npt.NDArray[np.float64],
+ new_value: npt.NDArray[np.float64],
+ ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
@@ -422,7 +422,7 @@ def _update(


class SampleSplittingVariable:
- def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
+ def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
"""
Sample splitting variables proportional to `alpha_vec`.
@@ -535,13 +535,13 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d


def draw_leaf_value(
- y_mu_pred: npt.NDArray[np.float_],
- x_mu: npt.NDArray[np.float_],
+ y_mu_pred: npt.NDArray[np.float64],
+ x_mu: npt.NDArray[np.float64],
m: int,
- norm: npt.NDArray[np.float_],
+ norm: npt.NDArray[np.float64],
shape: int,
response: str,
- ) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
+ ) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean = np.empty(shape)
@@ -559,7 +559,7 @@ def draw_leaf_value(


@njit
- def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
@@ -578,11 +578,11 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_

@njit
def fast_linear_fit(
- x: npt.NDArray[np.float_],
- y: npt.NDArray[np.float_],
+ x: npt.NDArray[np.float64],
+ y: npt.NDArray[np.float64],
m: int,
- norm: npt.NDArray[np.float_],
- ) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
+ norm: npt.NDArray[np.float64],
+ ) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)

@@ -666,17 +666,17 @@ def update(self):

@njit
def inverse_cdf(
- single_uniform: npt.NDArray[np.float_], normalized_weights: npt.NDArray[np.float_]
+ single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
) -> npt.NDArray[np.int_]:
"""
Inverse CDF algorithm for a finite distribution.
Parameters
----------
- single_uniform: npt.NDArray[np.float_]
+ single_uniform: npt.NDArray[np.float64]
Ordered points in [0,1]
- normalized_weights: npt.NDArray[np.float_])
+ normalized_weights: npt.NDArray[np.float64])
Normalized weights
Returns
@@ -699,7 +699,7 @@ def inverse_cdf(


@njit
- def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[np.float_]:
+ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
"""
Jitter duplicated values.
"""
@@ -715,7 +715,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[


@njit
- def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
+ def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
"""Check if all values in array are whole numbers"""
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)

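The resampling-related signatures above (resample, get_particle_tree, systematic, inverse_cdf) change only their dtype annotations. For context, here is a self-contained sketch of systematic resampling driven by an inverse CDF, written as an illustration of the general technique under the annotated types, not as the pymc-bart implementation:

    import numpy as np
    import numpy.typing as npt

    def systematic_resample(normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
        # A single uniform draw offsets an evenly spaced grid of points in [0, 1),
        # which gives lower-variance index selection than independent draws.
        n = len(normalized_weights)
        positions = (np.random.uniform() + np.arange(n)) / n
        # Invert the empirical CDF with one pass over the cumulative weights.
        cumulative = np.cumsum(normalized_weights)
        cumulative[-1] = 1.0  # guard against floating-point round-off
        indices = np.empty(n, dtype=np.int_)
        i = j = 0
        while i < n:
            if positions[i] < cumulative[j]:
                indices[i] = j
                i += 1
            else:
                j += 1
        return indices

With uniform weights, systematic_resample(np.full(4, 0.25)) returns array([0, 1, 2, 3]); heavier weights are selected proportionally more often.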
40 changes: 20 additions & 20 deletions pymc_bart/tree.py
@@ -27,7 +27,7 @@ class Node:
Attributes
----------
- value : npt.NDArray[np.float_]
+ value : npt.NDArray[np.float64]
idx_data_points : Optional[npt.NDArray[np.int_]]
idx_split_variable : int
linear_params: Optional[List[float]] = None
@@ -37,11 +37,11 @@ class Node:

def __init__(
self,
- value: npt.NDArray[np.float_] = np.array([-1.0]),
+ value: npt.NDArray[np.float64] = np.array([-1.0]),
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
- linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
+ linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
) -> None:
self.value = value
self.nvalue = nvalue
@@ -52,11 +52,11 @@ def __init__(
@classmethod
def new_leaf_node(
cls,
- value: npt.NDArray[np.float_],
+ value: npt.NDArray[np.float64],
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
- linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
+ linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
) -> "Node":
return cls(
value=value,
@@ -100,7 +100,7 @@ class Tree:
The dictionary's keys are integers that represent the nodes position.
The dictionary's values are objects of type Node that represent the split and leaf nodes
of the tree itself.
- output: Optional[npt.NDArray[np.float_]]
+ output: Optional[npt.NDArray[np.float64]]
Array of shape number of observations, shape
split_rules : List[SplitRule]
List of SplitRule objects, one per column in input data.
@@ -121,7 +121,7 @@ class Tree:
def __init__(
self,
tree_structure: Dict[int, Node],
- output: npt.NDArray[np.float_],
+ output: npt.NDArray[np.float64],
split_rules: List[SplitRule],
idx_leaf_nodes: Optional[List[int]] = None,
) -> None:
@@ -133,7 +133,7 @@ def __init__(
@classmethod
def new_tree(
cls,
- leaf_node_value: npt.NDArray[np.float_],
+ leaf_node_value: npt.NDArray[np.float64],
idx_data_points: Optional[npt.NDArray[np.int_]],
num_observations: int,
shape: int,
@@ -189,7 +189,7 @@ def grow_leaf_node(
self,
current_node: Node,
selected_predictor: int,
- split_value: npt.NDArray[np.float_],
+ split_value: npt.NDArray[np.float64],
index_leaf_node: int,
) -> None:
current_node.value = split_value
@@ -221,7 +221,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
if node.is_split_node():
yield node.idx_split_variable

- def _predict(self) -> npt.NDArray[np.float_]:
+ def _predict(self) -> npt.NDArray[np.float64]:
output = self.output

if self.idx_leaf_nodes is not None:
@@ -232,23 +232,23 @@ def _predict(self) -> npt.NDArray[np.float_]:

def predict(
self,
- x: npt.NDArray[np.float_],
+ x: npt.NDArray[np.float64],
excluded: Optional[List[int]] = None,
shape: int = 1,
- ) -> npt.NDArray[np.float_]:
+ ) -> npt.NDArray[np.float64]:
"""
Predict output of tree for an (un)observed point x.
Parameters
----------
- x : npt.NDArray[np.float_]
+ x : npt.NDArray[np.float64]
Unobserved point
excluded: Optional[List[int]]
Indexes of the variables to exclude when computing predictions
Returns
-------
- npt.NDArray[np.float_]
+ npt.NDArray[np.float64]
Value of the leaf value where the unobserved point lies.
"""
if excluded is None:
@@ -258,16 +258,16 @@ def predict(

def _traverse_tree(
self,
- X: npt.NDArray[np.float_],
+ X: npt.NDArray[np.float64],
excluded: Optional[List[int]] = None,
shape: Union[int, Tuple[int, ...]] = 1,
- ) -> npt.NDArray[np.float_]:
+ ) -> npt.NDArray[np.float64]:
"""
Traverse the tree starting from the root node given an (un)observed point.
Parameters
----------
- X : npt.NDArray[np.float_]
+ X : npt.NDArray[np.float64]
(Un)observed point(s)
node_index : int
Index of the node to start the traversal from
Expand All @@ -278,7 +278,7 @@ def _traverse_tree(
Returns
-------
- npt.NDArray[np.float_]
+ npt.NDArray[np.float64]
Leaf node value or mean of leaf node values
"""

@@ -327,14 +327,14 @@ def _traverse_tree(
return p_d

def _traverse_leaf_values(
- self, leaf_values: List[npt.NDArray[np.float_]], leaf_n_values: List[int], node_index: int
+ self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int
) -> None:
"""
Traverse the tree appending leaf values starting from a particular node.
Parameters
----------
- leaf_values : List[npt.NDArray[np.float_]]
+ leaf_values : List[npt.NDArray[np.float64]]
node_index : int
"""
node = self.get_node(node_index)
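Since the Node constructors in tree.py appear in the diff, a short usage sketch shows what the tightened annotations describe in practice (assuming only the Node.new_leaf_node signature visible above; the example values are made up):

    import numpy as np
    from pymc_bart.tree import Node

    # A plain NumPy float array already has dtype float64, so annotating the
    # leaf value as npt.NDArray[np.float64] names the dtype that was in use
    # all along under the old np.float_ alias.
    leaf = Node.new_leaf_node(value=np.array([0.5]), nvalue=10)
    assert leaf.value.dtype == np.float64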
(diff for the remaining changed file not shown)
