Skip to content

Commit

Permalink
[models] mv log attributes to base nn class
Browse files Browse the repository at this point in the history
  • Loading branch information
luigibonati committed Jul 26, 2022
1 parent 14a5933 commit f6acd54
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 12 deletions.
6 changes: 1 addition & 5 deletions mlcvs/lda/deep_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,7 @@ def __init__(self, layers, activation="relu", device = None, **kwargs):
# lorentzian regularization
self.lorentzian_reg = 0

# training logs
self.epochs = 0
self.loss_train = []
self.loss_valid = []
self.log_header = True


def set_regularization(self, sw_reg=0.05, lorentzian_reg=None):
"""
Expand Down
35 changes: 33 additions & 2 deletions mlcvs/models/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def __init__(
self.name_ = "NN_CV"
self.feature_names = ["x" + str(i) for i in range(self.n_features)]

# training logs
self.epochs = 0
self.loss_train = []
self.loss_valid = []
self.log_header = True

def _init_weights(self, module):
if isinstance(module, torch.nn.Linear):
module.weight.data.normal_(mean=0.0, std=1.0)
Expand Down Expand Up @@ -167,7 +173,7 @@ def forward_nn(self, x: torch.tensor) -> (torch.tensor):

def forward(self, x: torch.tensor) -> (torch.tensor):
"""
Compute model output.
Compute model output.
Parameters
----------
Expand Down Expand Up @@ -208,7 +214,7 @@ def predict(self, X):
Returns
-------
s : array-like of shape (n_samples, n_classes-1)
Linear projection of inputs.
Model outputs.
"""

return self.forward(X)
Expand Down Expand Up @@ -294,6 +300,31 @@ def set_LRScheduler(self ,optimizer, patience=5, min_lr=1e-6, factor=0.9, log=Fa
self.lrscheduler_ = LRScheduler(optimizer, patience=patience, min_lr=min_lr, factor=factor, log=log
)

# Fit function
def train_epoch(self,loader):
"""
Auxiliary function for training an epoch.
Parameters
----------
loader: DataLoader
training set
"""
for data in loader:
# =================get data===================
X = data[0].to(self.device_)
y = data[1].to(self.device_)
# =================forward====================
H = self.forward_nn(X)
# =================lda loss===================
loss = self.loss_function(H, y, save_params=False)
# =================backprop===================
self.opt_.zero_grad()
loss.backward()
self.opt_.step()
# ===================log======================
self.epochs += 1

# Input / output standardization

def standardize_inputs(self, x: torch.Tensor, print_values=False):
Expand Down
6 changes: 1 addition & 5 deletions mlcvs/tica/deep_tica.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,8 @@ def __init__(self, layers, activation="relu", gaussian_random_initialization=Fal
# lorentzian regularization
self.reg_cholesky = 0

# training logs
self.epochs = 0
self.loss_train = []
self.loss_valid = []
# (additional) training logs
self.evals_train = []
self.log_header = True

def set_regularization(self, cholesky_reg=1e-6):
"""
Expand Down

0 comments on commit f6acd54

Please sign in to comment.