Skip to content

Commit

Permalink
Random generator (#39)
Browse files Browse the repository at this point in the history
* Use random number generator class

* Bump version

* Style fix

* Style fix
  • Loading branch information
lgmoneda authored Nov 24, 2021
1 parent d2f595c commit dc8b3c9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "time-robust-forest"
version = "0.1.7"
version = "0.1.8"
description = "Explores time information to train a robust random forest"
readme = "README.md"
authors = [
Expand Down
16 changes: 11 additions & 5 deletions time_robust_forest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from numpy.random import default_rng
from sklearn import metrics
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from time_robust_forest.functions import (
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
min_impurity_decrease=0,
n_jobs=-1,
multi=True,
random_state=42,
):
self.min_leaf, self.max_depth = min_leaf, max_depth
self.time_column = time_column
Expand All @@ -65,6 +67,7 @@ def __init__(
self.multi = multi
self.bootstrapping = bootstrapping
self.period_criterion = period_criterion
self.random_state = random_state

def fit(self, X, y, sample_weight=None, verbose=False):
"""
Expand Down Expand Up @@ -107,7 +110,7 @@ def fit(self, X, y, sample_weight=None, verbose=False):
max_features=self.max_features,
criterion="std",
period_criterion=self.period_criterion,
random_state=i,
random_state=i + self.random_state,
)
for i in range(self.n_estimators)
]
Expand All @@ -127,7 +130,7 @@ def fit(self, X, y, sample_weight=None, verbose=False):
max_features=self.max_features,
period_criterion=self.period_criterion,
criterion="std",
random_state=i,
random_state=i + self.random_state,
)
for i in range(self.n_estimators)
)
Expand Down Expand Up @@ -231,6 +234,7 @@ def __init__(
min_impurity_decrease=0,
n_jobs=-1,
multi=True,
random_state=42,
):
self.min_leaf, self.max_depth = min_leaf, max_depth
self.time_column = time_column
Expand All @@ -243,6 +247,7 @@ def __init__(
self.criterion = criterion
self.period_criterion = period_criterion
self.min_impurity_decrease = min_impurity_decrease
self.random_state = random_state

def fit(self, X, y, sample_weight=None, verbose=False):
"""
Expand Down Expand Up @@ -288,7 +293,7 @@ def fit(self, X, y, sample_weight=None, verbose=False):
period_criterion=self.period_criterion,
min_impurity_decrease=self.min_impurity_decrease,
total_sample=self.total_sample,
random_state=i,
random_state=i + self.random_state,
)
for i in range(self.n_estimators)
]
Expand All @@ -310,7 +315,7 @@ def fit(self, X, y, sample_weight=None, verbose=False):
min_impurity_decrease=self.min_impurity_decrease,
total_sample=self.total_sample,
criterion=self.criterion,
random_state=i,
random_state=i + self.random_state,
)
for i in range(self.n_estimators)
)
Expand Down Expand Up @@ -479,6 +484,7 @@ def __init__(
self.period_criterion = period_criterion
self.min_impurity_decrease = min_impurity_decrease
self.total_sample = total_sample
self.rng = default_rng(random_state)

if sample_weight is not None:
self.sample_weight = sample_weight
Expand Down Expand Up @@ -532,7 +538,7 @@ def create_split(self):
considering this chosen set, perform the split and make the
recursive call to build sub tress using the result splits.
"""
variables_to_consider = np.random.choice(
variables_to_consider = self.rng.choice(
self.variables, self.max_n_variables, replace=False
)
for idx, variable in enumerate(self.variables):
Expand Down

0 comments on commit dc8b3c9

Please sign in to comment.