Skip to content

Commit

Permalink
lint updates (#199)
Browse files Browse the repository at this point in the history
* lint updates

* use built-in types
  • Loading branch information
aloctavodia authored Nov 28, 2024
1 parent 40f1220 commit 9ec4de8
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.0
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
Expand Down
14 changes: 7 additions & 7 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import warnings
from multiprocessing import Manager
from typing import List, Optional, Tuple
from typing import Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -39,8 +39,8 @@ class BARTRV(RandomVariable):
name: str = "BART"
signature = "(m,n),(m),(),(),() -> (m)"
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = List[List[List[Tree]]]
_print_name: tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = list[list[list[Tree]]]

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
idx = dist_params[0].ndim - 2
Expand Down Expand Up @@ -92,10 +92,10 @@ class BART(Distribution):
beta : float
Controls the prior probability over the number of leaves of the trees.
Should be positive.
split_prior : Optional[List[float]], default None.
split_prior : Optional[list[float]], default None.
List of positive numbers, one per column in input data.
Defaults to None, all covariates have the same prior probability to be selected.
split_rules : Optional[List[SplitRule]], default None
split_rules : Optional[list[SplitRule]], default None
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
Expand Down Expand Up @@ -126,7 +126,7 @@ def __new__(
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[npt.NDArray[np.float64]] = None,
split_rules: Optional[List[SplitRule]] = None,
split_rules: Optional[list[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
):
Expand Down Expand Up @@ -198,7 +198,7 @@ def get_moment(cls, rv, size, *rv_inputs):

def preprocess_xy(
X: TensorLike, Y: TensorLike
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
Expand Down
36 changes: 18 additions & 18 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union
from typing import Optional, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -43,7 +43,7 @@ class ParticleTree:

def __init__(self, tree: Tree):
self.tree: Tree = tree.copy()
self.expansion_nodes: List[int] = [0]
self.expansion_nodes: list[int] = [0]
self.log_weight: float = 0

def copy(self) -> "ParticleTree":
Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__( # noqa: PLR0915
self,
vars=None, # pylint: disable=redefined-builtin
num_particles: int = 10,
batch: Tuple[float, float] = (0.1, 0.1),
batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
):
model = modelcontext(model)
Expand Down Expand Up @@ -310,7 +310,7 @@ def astep(self, _):
stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
return self.sum_trees, [stats]

def normalize(self, particles: List[ParticleTree]) -> float:
def normalize(self, particles: list[ParticleTree]) -> float:
"""
Use softmax to get normalized_weights.
"""
Expand All @@ -321,16 +321,16 @@ def normalize(self, particles: List[ParticleTree]) -> float:
return wei / wei.sum()

def resample(
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> List[ParticleTree]:
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> list[ParticleTree]:
"""
Use systematic resample for all but the first particle
Ensure particles are copied only if needed.
"""
new_indices = self.systematic(normalized_weights) + 1
seen: List[int] = []
new_particles: List[ParticleTree] = []
seen: list[int] = []
new_particles: list[ParticleTree] = []
for idx in new_indices:
if idx in seen:
new_particles.append(particles[idx].copy())
Expand All @@ -343,8 +343,8 @@ def resample(
return particles

def get_particle_tree(
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> Tuple[ParticleTree, Tree]:
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
"""
Expand All @@ -367,12 +367,12 @@ def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray
single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
return inverse_cdf(single_uniform, normalized_weights)

def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]:
def init_particles(self, tree_id: int, odim: int) -> list[ParticleTree]:
"""Initialize particles."""
p0: ParticleTree = self.all_particles[odim][tree_id]
# The old tree does not grow so we update the weight only once
self.update_weight(p0, odim)
particles: List[ParticleTree] = [p0]
particles: list[ParticleTree] = [p0]

particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
return particles
Expand Down Expand Up @@ -419,7 +419,7 @@ def _update(
mean: npt.NDArray[np.float64],
m_2: npt.NDArray[np.float64],
new_value: npt.NDArray[np.float64],
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
Expand All @@ -439,15 +439,15 @@ def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
"""
self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

def rvs(self) -> Union[int, Tuple[int, float]]:
def rvs(self) -> Union[int, tuple[int, float]]:
rnd: float = np.random.random()
for i, val in self.enu:
if rnd <= val:
return i
return self.enu[-1]


def compute_prior_probability(alpha: int, beta: int) -> List[float]:
def compute_prior_probability(alpha: int, beta: int) -> list[float]:
"""
Calculate the probability of the node being a leaf node (1 - p(being split node)).
Expand All @@ -460,7 +460,7 @@ def compute_prior_probability(alpha: int, beta: int) -> List[float]:
-------
list with probabilities for leaf nodes
"""
prior_leaf_prob: List[float] = [0]
prior_leaf_prob: list[float] = [0]
depth = 0
while prior_leaf_prob[-1] < 0.9999:
prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
Expand Down Expand Up @@ -549,7 +549,7 @@ def draw_leaf_value(
norm: npt.NDArray[np.float64],
shape: int,
response: str,
) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean = np.empty(shape)
Expand Down Expand Up @@ -590,7 +590,7 @@ def fast_linear_fit(
y: npt.NDArray[np.float64],
m: int,
norm: npt.NDArray[np.float64],
) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)

Expand Down
41 changes: 21 additions & 20 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Generator
from functools import lru_cache
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Optional, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -30,7 +31,7 @@ class Node:
value : npt.NDArray[np.float64]
idx_data_points : Optional[npt.NDArray[np.int_]]
idx_split_variable : int
linear_params: Optional[List[float]] = None
linear_params: Optional[list[float]] = None
"""

__slots__ = "value", "nvalue", "idx_split_variable", "idx_data_points", "linear_params"
Expand All @@ -41,7 +42,7 @@ def __init__(
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
) -> None:
self.value = value
self.nvalue = nvalue
Expand All @@ -56,7 +57,7 @@ def new_leaf_node(
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
) -> "Node":
return cls(
value=value,
Expand Down Expand Up @@ -94,19 +95,19 @@ class Tree:
Attributes
----------
tree_structure : Dict[int, Node]
tree_structure : dict[int, Node]
A dictionary that represents the nodes stored in breadth-first order, based in the array
method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
The dictionary's keys are integers that represent the nodes position.
The dictionary's values are objects of type Node that represent the split and leaf nodes
of the tree itself.
output: Optional[npt.NDArray[np.float64]]
Array of shape number of observations, shape
split_rules : List[SplitRule]
split_rules : list[SplitRule]
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
idx_leaf_nodes : Optional[List[int]], by default None.
idx_leaf_nodes : Optional[list[int]], by default None.
Array with the index of the leaf nodes of the tree.
Parameters
Expand All @@ -120,10 +121,10 @@ class Tree:

def __init__(
self,
tree_structure: Dict[int, Node],
tree_structure: dict[int, Node],
output: npt.NDArray[np.float64],
split_rules: List[SplitRule],
idx_leaf_nodes: Optional[List[int]] = None,
split_rules: list[SplitRule],
idx_leaf_nodes: Optional[list[int]] = None,
) -> None:
self.tree_structure = tree_structure
self.idx_leaf_nodes = idx_leaf_nodes
Expand All @@ -137,7 +138,7 @@ def new_tree(
idx_data_points: Optional[npt.NDArray[np.int_]],
num_observations: int,
shape: int,
split_rules: List[SplitRule],
split_rules: list[SplitRule],
) -> "Tree":
return cls(
tree_structure={
Expand All @@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None:
self.set_node(index, node)

def copy(self) -> "Tree":
tree: Dict[int, Node] = {
tree: dict[int, Node] = {
k: Node(
value=v.value,
nvalue=v.nvalue,
Expand Down Expand Up @@ -199,7 +200,7 @@ def grow_leaf_node(
self.idx_leaf_nodes.remove(index_leaf_node)

def trim(self) -> "Tree":
tree: Dict[int, Node] = {
tree: dict[int, Node] = {
k: Node(
value=v.value,
nvalue=v.nvalue,
Expand Down Expand Up @@ -233,7 +234,7 @@ def _predict(self) -> npt.NDArray[np.float64]:
def predict(
self,
x: npt.NDArray[np.float64],
excluded: Optional[List[int]] = None,
excluded: Optional[list[int]] = None,
shape: int = 1,
) -> npt.NDArray[np.float64]:
"""
Expand All @@ -243,7 +244,7 @@ def predict(
----------
x : npt.NDArray[np.float64]
Unobserved point
excluded: Optional[List[int]]
excluded: Optional[list[int]]
Indexes of the variables to exclude when computing predictions
Returns
Expand All @@ -259,8 +260,8 @@ def predict(
def _traverse_tree(
self,
X: npt.NDArray[np.float64],
excluded: Optional[List[int]] = None,
shape: Union[int, Tuple[int, ...]] = 1,
excluded: Optional[list[int]] = None,
shape: Union[int, tuple[int, ...]] = 1,
) -> npt.NDArray[np.float64]:
"""
Traverse the tree starting from the root node given an (un)observed point.
Expand All @@ -273,7 +274,7 @@ def _traverse_tree(
Index of the node to start the traversal from
split_variable : int
Index of the variable used to split the node
excluded: Optional[List[int]]
excluded: Optional[list[int]]
Indexes of the variables to exclude when computing predictions
Returns
Expand Down Expand Up @@ -327,14 +328,14 @@ def _traverse_tree(
return p_d

def _traverse_leaf_values(
self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int
self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int
) -> None:
"""
Traverse the tree appending leaf values starting from a particular node.
Parameters
----------
leaf_values : List[npt.NDArray[np.float64]]
leaf_values : list[npt.NDArray[np.float64]]
node_index : int
"""
node = self.get_node(node_index)
Expand Down
Loading

0 comments on commit 9ec4de8

Please sign in to comment.