
save_hyperparameters() with pandas.DataFrame fails #8626

Closed
gunthergl opened this issue Jul 29, 2021 · 4 comments
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · won't fix (This will not be worked on)

Comments

@gunthergl

🐛 Bug

I already reported this issue at pandas-dev/pandas#42748 and yaml/pyyaml#540, but both say it is not their problem.

After upgrading pandas from 1.2.5 to the next version, 1.3.0, self.save_hyperparameters() now raises an error because yaml.dump(<pd.DataFrame>), which worked previously, no longer does.

See the traceback below; the error is raised inside save_hparams_to_yaml(hparams_file, self.hparams) when yaml.dump() is called.

I see the following possibilities:

  1. Exclude DataFrame parameters completely from saving
  2. Catch DataFrames and dump them as plain dicts: yaml.dump(<pd.DataFrame>.to_dict()) (see the sketch after this list)
  3. Fix whatever got broken in pandas and/or pyyaml
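
A minimal sketch of possibility 2, outside of Lightning: DataFrame.to_dict() produces plain Python containers that PyYAML can serialize.

import pandas as pd
import yaml

df = pd.DataFrame({'A': [1, 2, 3]})
# yaml.dump(df) crashes with pandas 1.3.0 + PyYAML 5.4.1,
# but the plain dict dumps fine:
print(yaml.dump(df.to_dict()))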

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

import pandas as pd

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, hparamB):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # hparamB (here a pd.DataFrame) is collected into self.hparams and
        # later written to hparams.yaml by the logger
        self.save_hyperparameters()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel(hparamB=pd.DataFrame({'A': [1, 2, 3]}))  # DataFrame as hyperparameter triggers the crash
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()

Expected behavior

self.save_hyperparameters() completes without raising an error, as it did with pandas 1.2.5.

Environment

  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.0
  • PyTorch Version (e.g., 1.8): 1.9.0+cu111
  • Python version: 3.8.10
  • OS (e.g., Linux): Windows
  • CUDA/cuDNN version: cu111
  • GPU models and configuration: GTX 1650
  • How you installed PyTorch (conda, pip, source): pip
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

Complete package list


(rmme_II) XXX\tests\testutil>pip list --format=freeze
absl-py==0.13.0
aiohttp==3.7.4.post0
alabaster==0.7.12
antlr4-python3-runtime==4.8
anytree==2.8.0
appdirs==1.4.4
astroid==2.6.5
async-timeout==3.0.1
atomicwrites==1.4.0
attrs==21.2.0
Babel==2.9.1
bokeh==2.3.3
brotlipy==0.7.0
CacheControl==0.12.6
cachetools==4.2.2
cachy==0.3.0
ccc==0.2.1
certifi==2021.5.30
cffi==1.14.6
chardet==4.0.0
charset-normalizer==2.0.3
cleo==0.8.1
clikit==0.6.2
cmake==3.21.0
colorama==0.4.4
colour==0.1.5
crashtest==0.3.1
cryptography==3.4.7
cycler==0.10.0
datatable==1.0.0
distlib==0.3.2
docutils==0.16
dtreeviz==1.3
filelock==3.0.12
flake8==3.9.2
FlowIO==0.9.11
FlowKit==0.4.0
FlowUtils==0.9.3
fsspec==2021.7.0
future==0.18.2
google-auth==1.33.1
google-auth-oauthlib==0.4.4
googledrivedownloader==0.4
graphviz==0.17
grpcio==1.39.0
html5lib==1.1
hydra-core==1.1.0
idna==3.2
imagesize==1.2.0
importlib-metadata==3.10.0
importlib-resources==5.2.0
iniconfig==1.1.1
isodate==0.6.0
isort==5.9.2
Jinja2==3.0.1
joblib==1.0.1
jsonschema==3.2.0
keyring==21.2.1
kiwisolver==1.3.1
lazy-object-proxy==1.6.0
llvmlite==0.34.0
lockfile==0.12.2
lxml==4.6.3
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.4.2
mccabe==0.6.1
msgpack==1.0.2
MulticoreTSNE==0.1
multidict==5.1.0
mypy==0.910
mypy-extensions==0.4.3
networkx==2.6.2
numba==0.51.2
numpy==1.21.1
oauthlib==3.1.1
omegaconf==2.1.0
packaging==21.0
pandas==1.3.0
pastel==0.2.1
patsy==0.5.1
pexpect==4.8.0
Pillow==8.3.1
pip==21.1.3
pkginfo==1.7.1
pluggy==0.13.1
poethepoet==0.10.0
poetry==1.1.6
poetry-core==1.0.3
protobuf==3.17.3
ptyprocess==0.7.0
py==1.10.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.7.0
pycparser==2.20
pyDeprecate==0.3.1
pydotplus==2.0.2
pyflakes==2.3.1
Pygments==2.9.0
pylev==1.3.0
pylint==2.9.5
pyOpenSSL==20.0.1
pyparsing==2.4.7
pyrsistent==0.14.11
PySocks==1.7.1
pytest==6.2.4
python-dateutil==2.8.2
python-louvain==0.15
pytorch-lightning==1.4.0
pytz==2021.1
pywin32-ctypes==0.2.0
PyYAML==5.4.1
rdflib==6.0.0
requests==2.26.0
requests-oauthlib==1.3.0
requests-toolbelt==0.9.1
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.6.1
seaborn==0.11.1
setuptools==52.0.0.post20210125
shap==0.36.0
shellingham==1.3.1
six==1.16.0
sklearn==0.0
slicer==0.0.7
snowballstemmer==2.1.0
Sphinx==3.5.4
sphinx-rtd-theme==0.5.2
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
statsmodels==0.12.2
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
threadpoolctl==2.2.0
toml==0.10.2
tomlkit==0.7.2
torch==1.9.0+cu111
torch-cluster==1.5.9
torch-geometric==1.7.2
torch-scatter==2.0.7
torch-sparse==0.6.10
torchmetrics==0.4.1
tornado==6.1
tqdm==4.61.2
typing-extensions==3.10.0.0
umap-learn==0.4.6
unittest-xml-reporting==3.0.4
urllib3==1.26.6
virtualenv==20.4.6
webencodings==0.5.1
Werkzeug==2.0.1
wheel==0.36.2
win-inet-pton==1.1.0
wincertstore==0.2
wrapt==1.12.1
yarl==1.6.3
zipp==3.5.0

Traceback

Traceback (most recent call last):
  File "XXXX/01_MWE/hyperparameter/02_ptl_pandasissue_save_hyperparameter/main.py", line 65, in <module>
    run()
  File "XXXX/01_MWE/hyperparameter/02_ptl_pandasissue_save_hyperparameter/main.py", line 61, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 553, in fit
    self._run(model)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 912, in _run
    self._pre_dispatch()
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 941, in _pre_dispatch
    self._log_hyperparams()
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 970, in _log_hyperparams
    self.logger.save()
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\utilities\distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\loggers\tensorboard.py", line 249, in save
    save_hparams_to_yaml(hparams_file, self.hparams)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\pytorch_lightning\core\saving.py", line 405, in save_hparams_to_yaml
    yaml.dump(v)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\__init__.py", line 290, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\__init__.py", line 278, in dump_all
    dumper.represent(data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 27, in represent
    node = self.represent_data(data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 342, in represent_object
    return self.represent_mapping(
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 346, in represent_object
    return self.represent_sequence(tag+function_name, args)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 286, in represent_tuple
    return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "XXX\.conda\envs\rmme_II\lib\site-packages\yaml\representer.py", line 331, in represent_object
    if function.__name__ == '__newobj__':
AttributeError: 'functools.partial' object has no attribute '__name__'

Process finished with exit code 1

@gunthergl added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Jul 29, 2021
@tchaton
Contributor

tchaton commented Jul 29, 2021

Yes, I believe this shouldn't be supported. You can ignore it by passing its name to save_hyperparameters.
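
For example, applied to the reproduction above (the ignore argument of save_hyperparameters is what is meant here):

import torch
from pytorch_lightning import LightningModule

class BoringModel(LightningModule):
    def __init__(self, hparamB):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # skip the DataFrame so it never reaches yaml.dump()
        self.save_hyperparameters(ignore=["hparamB"])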

@awaelchli
Contributor

Like Thomas, I am also not sure what can be done here. If an object passed into the constructor is not serializable, we have to explicitly ignore it (with the ignore argument of save_hyperparameters); see the sketch below. This is quite common, for example, when we pass in a backbone (nn.Module) that we also don't want to include as a hyperparameter.
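
A minimal sketch of that backbone case (Classifier and its arguments are illustrative names, not from this issue):

import torch.nn as nn
from pytorch_lightning import LightningModule

class Classifier(LightningModule):
    def __init__(self, backbone: nn.Module, lr: float = 1e-3):
        super().__init__()
        # keep lr in hparams, but don't try to serialize the backbone module
        self.save_hyperparameters(ignore=["backbone"])
        self.backbone = backbone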

In your use case, did you actually intend to save the data frame or was it just accidental in the sense that you called save_hyperparameters() unconditionally out of habit?

@gunthergl
Author

In my use case it was mainly out of habit. Actually, I do not really need it anymore since I moved to hydra for hyperparameter settings.
In principle I am totally fine with removing the hyperparameter, maybe even removing the hyperparameter saving altogether, but if I remember correctly I used those hyperparameters when training, stopping, and then resuming training (roughly as sketched below). But I am not sure about that.
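
A rough sketch of that resume workflow, with a hypothetical checkpoint path ("example.ckpt" is a placeholder): the hyperparameters stored by save_hyperparameters() are read back from the checkpoint to re-create the module.

from pytorch_lightning import Trainer

# "example.ckpt" is a placeholder path to a checkpoint from a previous run;
# train_data / val_data as in the reproduction above
model = BoringModel.load_from_checkpoint("example.ckpt")  # restores hparams
trainer = Trainer(max_epochs=2, resume_from_checkpoint="example.ckpt")
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)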

However, the pandas issue got reopened, so maybe there will be a fix anyway.

@stale

stale bot commented Aug 29, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale bot added the won't fix (This will not be worked on) label on Aug 29, 2021
@stale bot closed this as completed on Sep 6, 2021