DEV: Initial AVA implementation review #699
Conversation
🙌 awesome, thank you @marisbasha! This looks great. I had a quick look and can see you made the changes we discussed. I will review in the next couple of days. I sent you an invite for a meeting two weeks from today. Before we meet, I will add a recipe in a notebook using your implementation, as we discussed. You got this most of the way there; I'm happy to take a first pass at stuff like docstrings + higher-level functions for training. I will make edits to this branch directly; in theory you should be able to pull those changes into your local branch. It will be good to get your feedback on the docstrings, train/eval/predict functions, and tests too, so we have more than one set of eyes on them. Thanks again! I'm excited about getting this first implementation added ASAP so we can test it out on some real data.
@NickleDave I have GitHub Desktop, so it's automatically synced; I can see it from the app. Let's meet on Monday 2 October then! Let me know when I'll have something to test (docstring / training loop)
Excellent 🤔 maybe I should be using GitHub Desktop
Will do, thank you @marisbasha
Force-pushed from 6009b0d to 4fc5d80
Hi @marisbasha, thanks again for all your work here, and sorry I haven't gotten to this sooner -- I've been working on stuff for Yarden and dealing with other life things.
There are two quick changes I need you to make before I continue:

- use `torchmetrics.KLDivergence` as one of the `metrics` in the `vak.models.AVA` definition (see my comment)
- (this is very minor, just for consistency, but) rename `vak.nets.Ava` -> `vak.nets.AVA`

Please do that and test that you can make an instance of the model in a notebook without getting any crashes, and let me know how it goes, thanks!
Also please notice I rewrote the commit history. I promise I'm not trying to take credit for your work, I just needed to remove some changes, and also break up some of the commits so I could better follow the changes. I can give you more detail when we meet, but for now please know you will want to do `git pull --rebase` in your local clone of vak so that you can get the rewritten history. If you do `git pull` (without the rebase) you'll get a bunch of weird conflicts. Just let me know if that's not clear.
Let's keep working async like this for now; I think we are close to getting a toy example in a notebook. I'll check in early next week, but I feel like we can keep going this way for a bit before we need to meet (since I'm sure you're busy too).
src/vak/nets/ava.py (outdated)
def forward(self, x):
    return self.layer(x)


class Ava(nn.Module):
Let's please rename this to `AVA`, since it's an acronym and to be consistent with the model name. We'll rely on the namespace to differentiate them (`nets.AVA` vs `models.AVA`), as you have done already in the model class.
src/vak/models/ava.py (outdated)
optimizer = torch.optim.Adam
metrics = {
    "loss": VaeElboLoss,
    "kl": torch.nn.functional.kl_div
When I try to instantiate the model in a notebook I get an error that I'll paste below. It's not your fault though, it's because of something we haven't made precise in the code yet. I'll raise an issue about that (so I don't dump a bunch of detail here).

Can you please try changing to `torchmetrics.KLDivergence` and we'll see if that fixes the bug?
https://torchmetrics.readthedocs.io/en/stable/regression/kl_divergence.html

Below is the traceback from the error I'm getting, basically because `vak.models.base.Model` expects every metric to implement a `__call__` method (I think we might want to require that every metric be a subclass of the `torchmetrics.Metric` class instead, since that's more explicit and helps us ensure consistent behavior).
TypeError Traceback (most recent call last)
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/models/decorator.py:73, in model.<locals>._model(definition)
72 try:
---> 73 validate_definition(definition)
74 except ValueError as err:
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/models/definition.py:262, in validate(definition)
259 if not (
260 inspect.isclass(metrics_dict_val) and callable(metrics_dict_val)
261 ):
--> 262 raise TypeError(
263 "A model definition's 'metrics' variable must be a dict mapping "
264 "string names to classes that define __call__ methods, "
265 f"but the key '{metrics_dict_key}' maps to a value with type {type(metrics_dict_val)}, "
266 f"not recognized as callable."
267 )
269 # ---- validate default config
TypeError: A model definition's 'metrics' variable must be a dict mapping string names to classes that define __call__ methods, but the key 'kl' maps to a value with type <class 'function'>, not recognized as callable.
The above exception was the direct cause of the following exception:
ModelDefinitionValidationError Traceback (most recent call last)
Cell In[1], line 3
1 import torch
----> 3 from vak.nets.ava import Ava
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/__init__.py:1
----> 1 from . import (
2 __main__,
3 cli,
4 common,
5 config,
6 datasets,
7 eval,
8 learncurve,
9 metrics,
10 models,
11 nets,
12 nn,
13 plot,
14 predict,
15 prep,
16 train,
17 transforms,
18 )
19 from .__about__ import (
20 __author__,
21 __commit__,
(...)
28 __version__,
29 )
31 __all__ = [
32 "__main__",
33 "__author__",
(...)
56 "transforms",
57 ]
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/__main__.py:8
5 import argparse
6 from pathlib import Path
----> 8 from .cli import cli
11 def get_parser():
12 """returns ArgumentParser instance used by main()"""
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/cli/__init__.py:4
1 """command-line interface functions for training,
2 creating learning curves, etc."""
----> 4 from . import cli, eval, learncurve, predict, prep, train
6 __all__ = [
7 "cli",
8 "eval",
(...)
12 "train",
13 ]
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/cli/eval.py:4
1 import logging
2 from pathlib import Path
----> 4 from .. import config
5 from .. import eval as eval_module
6 from ..common.logging import config_logging_for_cli, log_version
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/config/__init__.py:2
1 """sub-package that parses config.toml files and returns config object"""
----> 2 from . import (
3 config,
4 eval,
5 learncurve,
6 model,
7 parse,
8 predict,
9 prep,
10 spect_params,
11 train,
12 validators,
13 )
16 __all__ = [
17 "config",
18 "eval",
(...)
26 "validators",
27 ]
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/config/config.py:4
1 import attr
2 from attr.validators import instance_of, optional
----> 4 from .eval import EvalConfig
5 from .learncurve import LearncurveConfig
6 from .predict import PredictConfig
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/config/eval.py:8
6 from ..common import device
7 from ..common.converters import expanded_user_path
----> 8 from .validators import is_valid_model_name
11 def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
12 post_tfm_kwargs = dict(post_tfm_kwargs)
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/config/validators.py:6
2 from pathlib import Path
4 import toml
----> 6 from .. import models
7 from ..common import constants
10 def is_a_directory(instance, attribute, value):
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/models/__init__.py:12
10 from .tweetynet import TweetyNet
11 from .vae_model import VAEModel
---> 12 from .ava import AVA
14 __all__ = [
15 "base",
16 "ConvEncoderUMAP",
(...)
30
31 ]
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/models/ava.py:11
7 from .vae_model import VAEModel
8 from ..nn.loss import VaeElboLoss
10 @model(family=VAEModel)
---> 11 class AVA:
12 """
13 """
14 network = nets.Ava
File ~/Documents/repos/coding/vocalpy/vak-vocalpy/src/vak/models/decorator.py:79, in model.<locals>._model(definition)
75 raise ModelDefinitionValidationError(
76 f"Validation failed for the following model definition:\n{definition}"
77 ) from err
78 except TypeError as err:
---> 79 raise ModelDefinitionValidationError(
80 f"Validation failed for the following model definition:\n{definition}"
81 ) from err
83 attributes = dict(family.__dict__)
84 attributes.update({"definition": definition})
ModelDefinitionValidationError: Validation failed for the following model definition:
<class 'vak.models.ava.AVA'>
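For reference, here is a minimal sketch of what that change to the `metrics` dict would look like; assuming the rest of the `vak.models.AVA` definition stays as it is in this PR, only the `"kl"` entry changes from the function `torch.nn.functional.kl_div` to the class `torchmetrics.KLDivergence`:

```python
# Minimal sketch of the suggested fix (not the final model definition):
# map metric names to classes, not plain functions, so the definition
# validator sees classes that implement __call__.
import torch
import torchmetrics

from vak.nn.loss import VaeElboLoss  # loss class already used in this definition

optimizer = torch.optim.Adam
metrics = {
    "loss": VaeElboLoss,
    "kl": torchmetrics.KLDivergence,  # a torchmetrics.Metric subclass, so it's a callable class
}
```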
@all-contributors add @marisbasha for code
I've put up a pull request to add @marisbasha! 🎉
@marisbasha following up on our discussion today: let's use the `torchmetrics.KLDivergence` class. But just for future reference, as far as I can tell it would be equivalent to use the functional `torch.nn.functional.kl_div`. I'm basing that on this discussion: https://lightning.ai/forums/t/understanding-logging-and-validation-step-validation-epoch-end/291/2
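Here's a quick sanity check, outside of vak, of why the two should agree, assuming probability inputs and a batch-mean reduction (the tensors below are made up just for illustration):

```python
# Illustration only: compare the class-based metric with the functional form.
import torch
import torchmetrics

p = torch.softmax(torch.randn(4, 10), dim=-1)  # "data" distribution per sample
q = torch.softmax(torch.randn(4, 10), dim=-1)  # approximating distribution

# torchmetrics class: D_KL(p || q), averaged over the batch
kl_class = torchmetrics.KLDivergence(log_prob=False, reduction="mean")(p, q)

# functional form: kl_div expects log-probabilities as `input` and
# probabilities as `target`; "batchmean" averages the per-sample KL
kl_functional = torch.nn.functional.kl_div(q.log(), p, reduction="batchmean")

print(torch.allclose(kl_class, kl_functional))  # expected: True
```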
@NickleDave added the requested changes. Also, as we discussed on the last call, I made the commit message very descriptive about everything that changed in that commit.
🚀 great, thank you @marisbasha -- I will get back to this this week. (I need to get this branch merged in so I can re-start experiments and deal with some life stuff)
Hi @marisbasha I made some progress on this but I get an error when I start training -- if you get a chance to test with the notebook I added before I do, and you have a guess about what's going on, please let me know. It looks like for some reason we end up with 5D input going into the batchnorm?
@NickleDave sure, I'll have a look during the weekend and I'll let you know!
Hello @NickleDave, as far as I was able to infer, the problem is related to batching. On line 98 of nets/ava.py, the error is caused by the unsqueeze call, which transforms the already-4D input (batch, channel, width, height) into a 5D tensor. Another issue that emerges, which we need to talk about, is the view operation after the encoder on the same line. Let me know if you want me to proceed with any change.
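To make that concrete, here is a tiny sketch with made-up shapes (not the actual code from nets/ava.py) showing how the extra unsqueeze turns a 4D batch into 5D input that batchnorm rejects:

```python
# Illustration only: hypothetical shapes, not vak's dataloader or network.
import torch

batch = torch.rand(64, 1, 128, 128)  # already (batch, channel, width, height)
x = batch.unsqueeze(1)               # adds another dim -> torch.Size([64, 1, 1, 128, 128])

bn = torch.nn.BatchNorm2d(num_features=1)
try:
    bn(x)                            # BatchNorm2d only accepts 4D input
except (ValueError, RuntimeError) as err:
    print(err)                       # e.g. "expected 4D input (got 5D input)"
```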
Thank you @marisbasha for figuring out where the error is coming from.
Hi @marisbasha, sorry I wasn't able to look at this sooner. Now that I have, I think you are right -- we should give the network input that is already in the shape it expects; I added a commit making that change. We should also make those assumptions explicit with a default dataloader we use for the model, but we don't need to figure that out now. Before we meet tomorrow I will:
But I think we are closer, and I can start doing stuff like adding docstrings, functions to train/test the model, etc.
Hi @marisbasha, getting closer on this. Changes I made:
Okay, comment two: what do we need to change? I have a branch here with a couple of notebooks in the project root:
So what we need to do is have a series of transforms that further pad the spectrograms to give us an input shape that we'll be able to reshape to when we decode. Basically we'll add a padding transform. I'm not actually sure what valid shapes are. I think we have to work backwards using the equations that determine output sizes of convolutional layers. I'm pretty sure any power of 2 will work, but this would involve adding a lot of padding in some cases, e.g. if we go to the next power of 2 from 256 (for 257 frequency bins) we'd be padding to (512 x 512), which is a lot of wasted computation.
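As a rough sketch of that working-backwards exercise (the kernel/stride/padding values here are hypothetical, not the actual AVA architecture):

```python
# Check what input sizes survive a stack of stride-2 conv layers without
# rounding loss, using the standard output-size formula for convolutions.
def conv_out_size(size, kernel, stride, padding):
    """floor((size + 2 * padding - kernel) / stride) + 1"""
    return (size + 2 * padding - kernel) // stride + 1

size = 128                 # hypothetical input height/width (a power of 2)
for _ in range(5):         # e.g. five stride-2 downsampling layers
    size = conv_out_size(size, kernel=3, stride=2, padding=1)
print(size)                # 128 halves cleanly each time: 64, 32, 16, 8, 4
```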
Hi @marisbasha, I thought about this some more and I feel like it's best to replicate the AVA pre-processing pipeline. I'm working on something else with a deadline, but I expect to be able to get back to this after Dec 6 (next week).
…ataset.py -- AVA default caused entire spectrogram of 1 value
…/make_splits.py, remove unused parameter 'purpose' from (renamed) make_splits function
…k/prep/parametric_umap/parametric_umap.py
Hey @marisbasha thanks for reaching back out and picking this up again. Just writing notes on next steps from our call today:
Just let me know if you have any questions! Like I said, happy to help however -- we can discuss here or jump on a video call, whatever works best for you. Thanks so much for working on this.
@NickleDave Here's the pull request. I've changed some things in the network to ensure it's the same as AVA. Also, I've made it so that you can use any input shape to train it, although we'd need to use 128x128 inputs to use the AVA weights.
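As a rough example, here is a hypothetical helper (not part of vak or this PR) for forcing a spectrogram to the 128 x 128 shape the original AVA weights expect:

```python
# Illustration only: crop or pad a (freq_bins, time_bins) spectrogram to 128 x 128.
import torch
import torch.nn.functional as F

def to_ava_shape(spect: torch.Tensor, target=(128, 128)) -> torch.Tensor:
    """Crop, then right/bottom-pad, a 2D spectrogram to `target`."""
    spect = spect[: target[0], : target[1]]   # crop if too big
    pad_h = target[0] - spect.shape[0]
    pad_w = target[1] - spect.shape[1]
    # F.pad on a 2D tensor takes (left, right, top, bottom)
    return F.pad(spect, (0, pad_w, 0, pad_h))

spect = torch.rand(257, 100)              # hypothetical spectrogram
print(to_ava_shape(spect).shape)          # torch.Size([128, 128])
```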