Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TrainOps + Two Stage Optim for depparse #1337

Merged
merged 3 commits into from
Jan 31, 2024
Merged

TrainOps + Two Stage Optim for depparse #1337

merged 3 commits into from
Jan 31, 2024

Conversation

Jemoka
Copy link
Member

@Jemoka Jemoka commented Jan 28, 2024

  • Optional two-stage optimization scheme after first optim converged
  • wandb gradient logging
  • Model checkpointing with optimizer

@Jemoka Jemoka requested a review from AngledLuffa January 28, 2024 18:17
@Jemoka Jemoka changed the base branch from main to dev January 28, 2024 18:17
@Jemoka
Copy link
Member Author

Jemoka commented Jan 28, 2024

ack; there is a patch here which is embedded in a commit in #1336 to get the tests to stop barfing. @AngledLuffa would love your thoughts on this

@AngledLuffa
Copy link
Collaborator

I still hope to be able to review things one piece at a time - mind pulling that patch into this PR? What I normally do in situations like this is args.get('second_optim', None) and the handle the case of no predefined value for the argument in a reasonable manner.

I also do kind of wonder, it doesn't need an optimizer when it's in eval mode, right? So how much memory & time is it wasting by creating that optimizer in this setting anyway?

@Jemoka
Copy link
Member Author

Jemoka commented Jan 29, 2024

sg; will address these first thing tmr if that's ok. Its probably not too much time lost, but you are right that this shouldn't be created during eval mode, etc.

@Jemoka
Copy link
Member Author

Jemoka commented Jan 29, 2024

I still hope to be able to review things one piece at a time - mind pulling that patch into this PR? What I normally do in situations like this is args.get('second_optim', None) and the handle the case of no predefined value for the argument in a reasonable manner.

I also do kind of wonder, it doesn't need an optimizer when it's in eval mode, right? So how much memory & time is it wasting by creating that optimizer in this setting anyway?

@Jemoka Jemoka closed this Jan 29, 2024
@Jemoka Jemoka reopened this Jan 29, 2024
@Jemoka
Copy link
Member Author

Jemoka commented Jan 29, 2024

I still hope to be able to review things one piece at a time - mind pulling that patch into this PR? What I normally do in situations like this is args.get('second_optim', None) and the handle the case of no predefined value for the argument in a reasonable manner.

I also do kind of wonder, it doesn't need an optimizer when it's in eval mode, right? So how much memory & time is it wasting by creating that optimizer in this setting anyway?

Done. Addressed by dac72d0.

import wandb
# track gradients!
wandb.watch(self.model, log_freq=4, log="all", log_graph=True)
if ignore_model_config:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? To make sure the model shapes are the same? But this doesn't allow for training with a different optimizer, unless I'm mistaken about that

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that might work would be separating the parameters into two separate maps, or possibly when the model itself is constructed, the model shapes parameters are reused from the save file whereas the passed in args are used for the embedding locations and the optimizer

self.model = self.model.to(device)
self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0))
self.primary_optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would worry that having primary and secondary optimizers is kind of wasteful in terms of space. The primary optimizer is kept around after switching, but never used again, and its derivatives etc will take up a lot of GPU

What I did with the constituency parser - which is by no means definitive - was to just keep one optimizer as part of the trainer. At the end of each epoch, if the switching condition was triggered, throw away the old one. We could then save a flag in the save file which marked whether the optimizer was the first or second optimizer

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to discuss alternatives for how best to keep track of which optimizer is being used and how best to save / load them

@@ -191,7 +194,10 @@ def train(args):
wandb.run.define_metric('dev_score', summary='max')

logger.info("Training parser...")
trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])
if args["continue_from"]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any mechanism that saves & loads the optimizer? I don't see such a thing. I think that would be a rather useful feature to add to the saving / loading & continuing.

The way I did this for the sentiment & constituency was to keep a "checkpoint" file which was always the latest optimizer state alongside the regular save file, which was the best so far. Then, if the checkpoint file exists when doing a training run with the same save_name, it would automatically load that checkpoint and continue from there. I suppose there should also be a flag which specifies that the user wants to discard the checkpoint, but I haven't done that so far.

I think we want to have a unified approach for the different models and their loading / checkpointing mechanisms

@Jemoka Jemoka requested a review from AngledLuffa January 30, 2024 02:00
@Jemoka
Copy link
Member Author

Jemoka commented Jan 30, 2024

Done, the last few comments should address the two optimizers situation (only loads one, which one to load is stored in args and saved with the model) as well as implement the requested checkpointing w/ optim. Running a training run now to confirm everything.

@AngledLuffa
Copy link
Collaborator

Looks pretty good, I'd say. Thanks for making those changes!

There's a unit test for training the depparse (doesn't check the results, just checks that it runs)

stanza/tests/depparse/test_parser.py

How would you feel about extending it to check that it is producing the expected checkpoints and switches optimizer when expected? I can take that on myself tomorrow afternoon, actually. Probably isn't too difficult to test those items.

I think this can all be squashed into one change, what do you think?

@AngledLuffa
Copy link
Collaborator

I can take on the testing and the squashing tomorrow, actually, a couple of the other tasks I had lined up for this week are in "waiting for response" mode now

wandb.watch(self.model, log_freq=4, log="all", log_graph=True)

def __init_optim(self):
if not (self.args.get("second_stage", False) and self.args.get('second_optim')):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one small readability comment might be to switch the order of the boolean so that it's easier to read: if second_optim and second_stage, build second optimizer, otherwise build the first

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, addressed by 0738498.

logger.info("Switching to AMSGrad")
if not is_second_stage and args.get('second_optim', None) is not None:
logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
args["second_stage"] = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hang on - since the tagger is now copying the whole args with deepcopy (which I still think might not be necessary), will changing this field actually change the args in the tagger?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we pass the edited args into the constructor of Trainer, which, given we passed a model_file, will be passed into

if model_file is not None:
# load everything from file
self.load(model_file, pretrain, args, foundation_cache)

from there, because the args are then passed in to load, it overwrite anything that the Trainer originally had:

if args is not None: self.args.update(args)

Apologies in advance if I mixed something up—but I'm pretty sure this should work in terms of stage switching. Also, empirically it seems to work with a single run I did.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is correct. Thanks for pointing that out

Jemoka and others added 3 commits January 30, 2024 16:28
1) save optimizer in checkpoint 2) two stage tracking using args

Save the checkpoints after switching the optimizer, if applicable, so that reloading uses the new optimizer once it has been created
…s and the checkpoints are loadable

Add a flag which forces the optimizer to switch after a certain number of steps - useful for writing tests which check the behavior of the second optimizer
…t when a checkpoint gets loaded, the training continues from the position it was formerly at rather than restarting from 0

Report some details of the model being loaded after loading it
@AngledLuffa AngledLuffa merged commit 5bc22dd into dev Jan 31, 2024
2 checks passed
@AngledLuffa AngledLuffa deleted the depparse-ops branch January 31, 2024 00:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants