Skip to content

Commit

Permalink
[Fix] BaseModel.fit() 함수 코드 수정 (#30)
Browse files Browse the repository at this point in the history
* [fix] BaseModel.fit() 함수의 return값 수정

* [chore] checkpoint 저장용 directory 추가

* [fix] BaseModel.fit() 함수 내 checkpoint callback 추가

* [fix] BaseModel.fit() 함수 내 validation 성능지표 logging 추가
  • Loading branch information
22ema authored Jan 21, 2025
1 parent b0d0286 commit 95e33c4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions CATS/models/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from ..callbacks import History
from ..callbacks import History, ModelCheckpointTorch
from ..inputs import (DenseFeat, SparseFeat, VarLenSparseFeat,
build_input_features, create_embedding_matrix,
embedding_lookup, get_dense_inputs)
Expand Down Expand Up @@ -72,6 +72,9 @@ def __init__(
self._ckpt_saved_epoch = False # used for EarlyStopping in tf1.14

self.history = History()
self.model_checkpoint = ModelCheckpointTorch(
"./checkpoints/weights.e{epoch:02d}-auc{val_auc:.2f}.pt"
)

def fit(
self,
Expand Down Expand Up @@ -138,7 +141,10 @@ def fit(
steps_per_epoch = (sample_num - 1) // batch_size + 1

# configure callbacks
callbacks = (callbacks or []) + [self.history] # add history callback
callbacks = (callbacks or []) + [
self.history,
self.model_checkpoint,
] # add history callback
callbacks = CallbackList(callbacks)
callbacks.set_model(self)
callbacks.on_train_begin()
Expand Down Expand Up @@ -227,14 +233,23 @@ def fit(
for name in self.metrics:
eval_str += " - " + name + ": {0: .4f}".format(epoch_logs[name])

if do_validation:
for name in self.metrics:
eval_str += (
" - "
+ "val_"
+ name
+ ": {0: .4f}".format(epoch_logs["val_" + name])
)

logging.info(eval_str)
callbacks.on_epoch_end(epoch, epoch_logs)
if self.stop_training:
break

callbacks.on_train_end()

return self.historys
return self.history

def evaluate(
self,
Expand Down
Empty file added checkpoints/.gitkeep
Empty file.

0 comments on commit 95e33c4

Please sign in to comment.