diff --git a/netam/framework.py b/netam/framework.py index 51b65778..289decc7 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -372,7 +372,7 @@ def __init__( train_dataset, val_dataset, model, - optimizer_name="AdamW", + optimizer_name="Adam", batch_size=1024, learning_rate=0.1, min_learning_rate=1e-4, @@ -757,6 +757,7 @@ def __init__( train_dataset, val_dataset, model, + optimizer_name="Adam", batch_size=1024, learning_rate=0.1, min_learning_rate=1e-4, @@ -767,11 +768,12 @@ def __init__( train_dataset, val_dataset, model, - batch_size, - learning_rate, - min_learning_rate, - l2_regularization_coeff, - name, + optimizer_name=optimizer_name, + batch_size=batch_size, + learning_rate=learning_rate, + min_learning_rate=min_learning_rate, + l2_regularization_coeff=l2_regularization_coeff, + name=name, ) def loss_of_batch(self, batch):