PyTorch weight_decay ambiguity with shared params #1319

Closed
JackTemaki opened this issue May 3, 2023 · 28 comments · Fixed by #1320

@JackTemaki
Collaborator

JackTemaki commented May 3, 2023

The logic of the _get_optimizer_param_groups branch with weight decay enabled cannot run in its current form. Iterating over all named modules while at the same time iterating recursively over parameters will yield the same parameter multiple times with a different module reference.

In my case this was:

<class 'i6_experiments.users.rossenbach.experiments.librispeech.tts_architecture_improvement_23.pytorch_networks.ctc_aligner_v1.Model'>
<class 'torch.nn.modules.sparse.Embedding'>

for speaker_embedding.weight

This means we need completely new logic if we want to exclude certain modules.
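
To make this concrete, here is a minimal standalone sketch (not the actual RETURNN code; Model below is just a stand-in for the model above):

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.speaker_embedding = nn.Embedding(10, 8)

model = Model()
# Iterating over all named modules and then recursively over each module's
# parameters yields the same Parameter object once per enclosing module:
for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        print(type(m), pn)
# <class '__main__.Model'> speaker_embedding.weight
# <class 'torch.nn.modules.sparse.Embedding'> weight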

@albertz
Member

albertz commented May 3, 2023

I don't understand. Why is it not possible to fix it? Why not make the params unique?

You would just keep a list of visited params, and if some param is already in there, skip it.
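
Roughly, a sketch of that idea (continuing the sketch above, not the actual RETURNN code), keying on the object identity of the param:

decay, no_decay = [], []
visited_ids = set()
for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        if id(p) in visited_ids:
            continue  # this param was already handled under another module
        visited_ids.add(id(p))
        # ... here the existing blacklist heuristic would pick decay vs. no_decay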

@Icemole
Collaborator

Icemole commented May 3, 2023

Can you please give a short example in which different modules share common parameters? I didn't know that this was possible (although it would make sense) and a brief demo would help me grasp the concept.

In principle we can do what @albertz mentioned, keeping a set of Parameter references that have already been visited; then the code wouldn't need much tinkering. This assumes that the Parameter reference is actually the same object, of course, but that seems to be the case in the issue you're describing, so I think this would solve it.
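
For context, one common way different modules end up sharing the same Parameter is weight tying, e.g. tying an output projection to an embedding (illustrative example, not taken from this issue):

from torch import nn

embed = nn.Embedding(1000, 64)
out_proj = nn.Linear(64, 1000, bias=False)
out_proj.weight = embed.weight  # both modules now hold the very same Parameter
assert out_proj.weight is embed.weight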

@JackTemaki
Collaborator Author

JackTemaki commented May 3, 2023

Can you please give a short example in which different modules share common parameters? I didn't know that this was possible (although it would make sense) and a brief demo would help me grasp the concept.

Any nested modules will cause this problem, because the parameter access is recursive. Making it not recursive might already solve the problem, but then we cannot use strings anymore as is currently done, because you get name repetitions.

In principle we can do what @albertz mentioned, keeping a set of Parameter references that have already been visited; then the code wouldn't need much tinkering. This assumes that the Parameter reference is actually the same object, of course, but that seems to be the case in the issue you're describing, so I think this would solve it.

It does not solve this issue, because we are checking the module against the blacklist. Thus, in my case, the same parameter was added to both parameter groups and I got a crash.

For sure there are solutions, but as I did not find one after a few minutes, I wanted to post this.

@albertz
Member

albertz commented May 3, 2023

It does not solve this issue

Sure, it will solve it. You will never visit the same param twice. Thus the problem is solved.

@albertz
Member

albertz commented May 3, 2023

See the code in rf.Module.named_parameters() as an example.

@Icemole
Collaborator

Icemole commented May 3, 2023

I think there might be a solution if we check for collisions only with respect to the Parameter itself (p in the code), instead of the full parameter name (fpn in the code). I'm working on a fix, which you can check out in an isolated branch (an exception to the "working in main" workflow in this case, given the isolated fix).

@albertz
Member

albertz commented May 3, 2023

Yes sure, just like in named_parameters(). I mean, that is already what the error tells you, right?

Icemole added a commit that referenced this issue May 3, 2023
Icemole linked a pull request May 3, 2023 that will close this issue
@Icemole
Collaborator

Icemole commented May 3, 2023

@JackTemaki please check #1320, hopefully it will fix the issue 🙂

@albertz I used similar functionality to named_parameters(), but since Parameter already has a hashing function (it inherits from torch.Tensor), I didn't feel the need to use the RefIdEq class.

@Icemole
Collaborator

Icemole commented May 3, 2023

As per pytorch/pytorch#2569, the hashing function of a Parameter runs on CPU. Would that pose a problem?

@albertz
Member

albertz commented May 3, 2023

Yes, the use of RefIdEq is needed if the underlying object does not support reference-based equality comparison and hashing. I would have expected that this does not work for torch.nn.Parameter? What happens when you do a check like param1 != param2? We really want to avoid that this does an element-wise check. The equality/hashing should not depend on the actual values; it should depend only on the reference.
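
For illustration, the idea behind such a wrapper is roughly the following (a minimal sketch; the actual RefIdEq class in RETURNN may differ):

class RefIdEqSketch:
    """Wraps an object so that hashing and equality use the reference
    (object identity) instead of the tensor values."""
    def __init__(self, obj):
        self.obj = obj
    def __hash__(self):
        return id(self.obj)
    def __eq__(self, other):
        return isinstance(other, RefIdEqSketch) and self.obj is other.obj

# A set of RefIdEqSketch(param) then deduplicates purely by reference,
# independent of the parameter values.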

@Icemole
Collaborator

Icemole commented May 3, 2023

@albertz got it. Both functions are inherited from torch.Tensor; the hashing function uses id(), but the equality function is overridden to compare by value:

>>> import torch
>>> t1 = torch.tensor([1])
>>> t2 = torch.tensor([1])
>>> torch.nn.Parameter(t1, requires_grad=False) == torch.nn.Parameter(t2, requires_grad=False)
tensor([True])

I'll use RefIdEq then.

@JackTemaki
Collaborator Author

It seems my initial post was not clear. The error is NOT only that we visit the same parameter twice, but also that the module we use for the blacklist check is different/wrong. With your fix, speaker_embedding.weight will only be visited together with the ctc_aligner_v1.Model, and will thus be assigned to the weight-decay parameter group, which is wrong.

@albertz
Member

albertz commented May 4, 2023

Why is it wrong? It is just a heuristic anyway; there is no "right way" here. As long as it is not crashing, it is fine, I would say.

Or do you have a suggestion for a better heuristic?

@JackTemaki
Collaborator Author

Why is it wrong? It is just a heuristic anyway; there is no "right way" here. As long as it is not crashing, it is fine, I would say.

Or do you have a suggestion for a better heuristic?

Let's say it is not "wrong", but a no-op then. If we do not really use the blacklist anyway, the whole code can be removed, or we keep only the "bias" part.

@albertz
Member

albertz commented May 4, 2023

I don't understand what you mean. We still have that blacklist? It's working fine for all the normal cases? Why would we remove it?

It is just ambiguous for the case you are describing here, i.e. sharing params. As long as the heuristic is deterministic for such a (rare) ambiguous case, which it is, it is fine, I would say. I don't see a reason to completely remove the heuristic now.

@Icemole
Collaborator

Icemole commented May 4, 2023

I see, I think I understand better now. The solution would then involve something like letting the user specify certain groups of modules they wish to blacklist, as I understand it?

albertz added a commit that referenced this issue May 4, 2023
@albertz
Member

albertz commented May 4, 2023

I see, I think I understand better now. The solution would then involve something like letting the user specify certain groups of modules they wish to blacklist, as I understand it?

This could be another solution. But I would say this should be up to the user. For most users, I think the current heuristic should be fine anyway.

albertz changed the title from "PyTorch weight_decay code includes parameters twice." to "PyTorch weight_decay ambiguity with shared params" May 4, 2023
albertz reopened this May 4, 2023
@Icemole
Collaborator

Icemole commented May 4, 2023

But I would say this should be up to the user.

I think @JackTemaki's point is that this isn't customizable as of now. Maybe we can work on customizing it?

A straightforward solution could involve adding a key additional_blacklist_wd_modules (or similar) containing a list of modules to blacklist, in addition to the current blacklist that we specify. What do you think?

@albertz
Member

albertz commented May 4, 2023

So I will leave this open, and I understand this as a feature request now, not as a bug anymore (it was a bug that it crashed, but that was fixed via #1320). I don't think you can say that the current code is wrong now, as it is just a heuristic. You could just say the heuristic is bad. So I understand this as a feature request to also have other potential heuristics, or maybe the requirement that the user specifies everything explicitly, or so. This should all be configurable then, as soon as we have such other options.

@albertz
Member

albertz commented May 4, 2023

A straightforward solution could involve adding a key additional_blacklist_wd_modules (or similar) containing a list of modules to blacklist, in addition to the current blacklist that we specify. What do you think?

First of all, it should be possible for the user to specify their own custom function to determine the optimizer group for each param. Something like (name, param) -> opt_group. That gives full flexibility to the user.

I'm a bit hesitant to add options for the heuristic. additional_blacklist_wd_modules could be a solution here, but I would not add this unless there is someone who actually wants to use it. So this is a question maybe for @JackTemaki, if this is sth he would want to use then. If this is not the case, we should not add this now.
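
As a rough sketch, such a user-provided function could look like this (purely hypothetical; the function name and signature are not an existing RETURNN option):

import torch

# Hypothetical user config function: map each (name, param) to an optimizer
# param-group key, e.g. "decay" vs. "no_decay".
def get_opt_group(name: str, param: torch.nn.Parameter) -> str:
    if name.endswith("bias") or "speaker_embedding" in name:
        return "no_decay"
    return "decay"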

@Icemole
Collaborator

Icemole commented May 4, 2023

I see. Then we can simply let the user specify their own optim_groups and if specified we don't split them through our own heuristics? It's anyway simply a list of network parameters.

This would work in the PyTorch backend, but I'm wondering if this would have any effect in the TensorFlow net-dict backend (I'm currently thinking of issues when defining a RETURNN frontend config), as we don't have any custom weight decay heuristics set there, do we? But I guess this would be a separate issue.

@albertz
Member

albertz commented May 4, 2023

I see. Then we can simply let the user specify their own optim_groups and if specified we don't split them through our own heuristics? It's anyway simply a list of network parameters.

Yes. But I think it should be a function, because that is usually the more reasonable thing to specify.

But before starting to implement this now, I would wait until someone says that they really want to have it.

I'm wondering if this would have any effect in the TensorFlow net-dict backend

Why would it? This option is purely for the PT engine, or not?

we don't have any custom weight decay heuristics set there, do we?

No heuristic; the user just specifies it per layer. Every layer has the L2 option. Btw, with param sharing, in RETURNN it is still always clear which layer actually owns the variable, and the L2 of that layer is then used.
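
For example, in a TF net-dict config, weight decay is simply set per layer via the L2 option (a minimal sketch; layer names and values are illustrative):

network = {
    "hidden": {"class": "linear", "activation": "relu", "n_out": 512, "from": "data", "L2": 1e-4},
    "output": {"class": "softmax", "loss": "ce", "from": "hidden"},  # no L2 here
}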

@JackTemaki
Collaborator Author

So this is a question maybe for @JackTemaki, if this is sth he would want to use then.

No, no need for that.

I just do not understand where this issue ended up; this was at no point about shared params.

@albertz
Member

albertz commented May 4, 2023

So this is a question maybe for @JackTemaki, if this is sth he would want to use then.

No, no need for that.

So, how do you solve this problem then?

this was at no point about shared params.

I don't understand. In your initial post, you write:

iterating recursively over parameters will yield the same parameter multiple times with a different module reference

So this means shared params, or not?

@JackTemaki
Collaborator Author

So this is a question maybe for @JackTemaki, if this is sth he would want to use then.

No, no need for that.

So, how do you solve this problem then?

What problem? I have no need for a custom blacklist.

this was at no point about shared params.

I don't understand. In your initial post, you write:

iterating recursively over parameters will yield the same parameter multiple times with a different module reference

So this means shared params, or not?

No. As I clearly wrote in my second post:

Any nested modules will cause this problem, because the parameter access is recursive. (recursive meaning it will access parameters of submodules, but I thought this is clear).

@albertz
Member

albertz commented May 4, 2023

So, how do you solve this problem then?

What problem?

You said the current (or original) logic of _get_optimizer_param_groups cannot work for your case. I assume you got an exception? Or wrong behavior? So now, we extended the heuristic to cover that. But from your comments, I understood that you are not happy with this solution. But I assumed you have another solution?

So this means shared params, or not?

No. As I clearly wrote in my second post:

"Any nested modules will cause this problem, because the parameter access is recursive." (recursive meaning it will access parameters of submodules, but I thought this is clear).

Oh, I think I understand now. From this, I originally thought you meant shared parameters.

But the fix is easy, or not? Just this line:

for pn, p in m.named_parameters():

needs to be changed to:

for pn, p in m.named_parameters(recurse=False):

That's all, or not?

@JackTemaki
Collaborator Author

So, how do you solve this problem then?

What problem?

You said the current (or original) logic of _get_optimizer_param_groups cannot work for your case. I assume you got an exception? Or wrong behavior? So now, we extended the heuristic to cover that. But from your comments, I understood that you are not happy with this solution. But I assumed you have another solution?

So this means shared params, or not?

No. As I clearly wrote in my second post:
"Any nested modules will cause this problem, because the parameter access is recursive." (recursive meaning it will access parameters of submodules, but I thought this is clear).

Oh, I think I understand now. From this, I originally thought you meant shared parameters.

But the fix is easy, or not? Just this line:

for pn, p in m.named_parameters():

needs to be changed to:

for pn, p in m.named_parameters(recurse=False):

That's all, or not?

Now maybe yes, but in the beginning this would not have solved the issue because of how the dicts with string parameter names are handled; I'm not sure about the current code.

@albertz
Member

albertz commented May 4, 2023

That's all, or not?

Now maybe yes, but in the beginning this would not have solved the issue because of how the dicts with string parameter names are handled; I'm not sure about the current code.

I don't understand how our current code changes anything about this. We just skip duplicate parameters, that's all. Even without this change, using named_parameters(recurse=False) should have worked for your case (when you don't have shared params).

albertz closed this as completed in 3d8c8ef May 4, 2023