
DeepTDA cannot be loaded from checkpoint #103

Closed
luigibonati opened this issue Nov 19, 2023 · 2 comments

@luigibonati (Owner)

Loading a DeepTDA CV from a checkpoint does not work.

Minimal (non-)working example:

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from mlcolvar.cvs import DeepTDA

checkpoint = ModelCheckpoint(save_top_k=1, monitor="valid_loss")

trainer = pl.Trainer(callbacks=[checkpoint], enable_checkpointing=True)
trainer.fit(model, datamodule)  # model: a DeepTDA instance; datamodule: the corresponding data module

best_model = DeepTDA.load_from_checkpoint(checkpoint.best_model_path)

which raises an error during initialization:

File ~/software/mambaforge/envs/mlcolvar/lib/python3.10/site-packages/mlcolvar/cvs/supervised/deeptda.py:67, in DeepTDA.__init__(self, n_states, n_cvs, target_centers, target_sigmas, layers, options, **kwargs)
     35 def __init__(
     36     self,
     37     n_states: int,
   (...)
     43     **kwargs,
     44 ):
     45     """
     46     Define Deep Targeted Discriminant Analysis (Deep-TDA) CV composed by a neural network module.
     47     By default a module standardizing the inputs is also used.
   (...)
     64         Set 'block_name' = None or False to turn off that block
     65     """
---> 67     super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)
     69     # =======   LOSS  =======
     70     self.loss_fn = TDALoss(
     71         n_states=n_states,
     72         target_centers=target_centers,
     73         target_sigmas=target_sigmas,
     74     )

TypeError: mlcolvar.cvs.cv.BaseCV.__init__() got multiple values for keyword argument 'in_features'
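The mechanism behind the TypeError can be reproduced without mlcolvar or Lightning. The sketch below uses hypothetical stand-in classes: the subclass derives in_features/out_features from layers and passes them explicitly, while the kwargs restored from the checkpoint already contain the same keys, so Python sees the keyword twice.

```python
# Hypothetical stand-ins for BaseCV / DeepTDA, only to illustrate the error.

class BaseCV:
    def __init__(self, in_features, out_features, **kwargs):
        self.in_features = in_features
        self.out_features = out_features


class DeepTDALike(BaseCV):
    def __init__(self, layers, **kwargs):
        # On reload, kwargs may already contain 'in_features'/'out_features',
        # duplicating the values derived from `layers` below.
        super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)


try:
    # Simulates load_from_checkpoint passing back all saved init arguments.
    DeepTDALike(layers=[4, 8, 1], in_features=4, out_features=1)
except TypeError as e:
    print(e)  # ... got multiple values for keyword argument 'in_features'
```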

We should also check the other CVs and add regression tests for this feature (as of now, only RegressionCV is tested for checkpointing, in this notebook: https://mlcolvar.readthedocs.io/en/stable/notebooks/tutorials/intro_3_loss_optim.html#Model-checkpointing).

@luigibonati (Owner, Author)

@andrrizzi we looked into it. The problem is that, when loading from a checkpoint, the call

super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)

fails because kwargs also contains in_features and out_features (restored from the saved init arguments).
If we delete those keys before calling the parent __init__, it works:

        kwargs.pop("in_features", None)
        kwargs.pop("out_features", None)
        super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)

What I don't like is that we would need to do this in every class that inherits from BaseCV.
Do you have any suggestion?

@andrrizzi (Collaborator)

If all the inherited CVs explicitly pass in_features/out_features to BaseCV.__init__ based on some other init argument, an alternative might be to modify BaseCV.__init__ to call self.save_parameters(ignore=['in_features', 'out_features']). I'm not sure, but I seem to remember that only saved parameters are then restored from the checkpoint.

If only a handful of CVs do this, then we might instead add that save_parameters(ignore=...) bit individually in their __init__.
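A minimal sketch of the mechanism behind this suggestion, mocked without Lightning (class and method names are illustrative, not the actual mlcolvar API): if the derived keys are excluded when the init arguments are saved, they never reappear in kwargs on reload, so the subclass can keep deriving them from layers.

```python
# Hypothetical mock: only saved (non-ignored) parameters would be written to
# the checkpoint and passed back to __init__ on reload.

class MockBaseCV:
    def __init__(self, in_features, out_features, **kwargs):
        self.hparams = {}
        # Exclude the derived sizes so they cannot clash with the values
        # the subclass computes from `layers` on reload.
        self.save_parameters(
            dict(in_features=in_features, out_features=out_features, **kwargs),
            ignore=("in_features", "out_features"),
        )

    def save_parameters(self, params, ignore=()):
        """Record init arguments, skipping the ignored keys."""
        self.hparams.update({k: v for k, v in params.items() if k not in ignore})


cv = MockBaseCV(in_features=4, out_features=1, layers=[4, 8, 1])
print(cv.hparams)  # {'layers': [4, 8, 1]}
```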

luigibonati added a commit that referenced this issue Dec 22, 2023
Fix #103: CVs cannot be loaded from checkpoint