Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

key_based_loss update #188

Merged
merged 3 commits into from
Oct 14, 2024
Merged

key_based_loss update #188

merged 3 commits into from
Oct 14, 2024

Conversation

danielward27
Copy link
Owner

Deprecate fit_to_variational_target, and add fit_to_key_based_loss.

Reason:

  • It was set to return the parameters when the minimum loss was reached by default. This provides some protection against instability in training, but is bad for two reasons: 1) for very stochastic losses, it can result in the "best" parameters based on the minimum loss being far the actual best model 2) some objectives give useful gradients, without any expectation to minimize the "loss", e.g. for some contrastive and adversarial approaches. I have also renamed the function, to reflect the more general utility of the function (doesn't have to use flowjax distributions at all).

We also move the training loops into the same file, meaning the module flowjax.train.data_fit is deprecated.

@danielward27 danielward27 merged commit c006585 into main Oct 14, 2024
1 check passed
@danielward27 danielward27 deleted the key_based_loss branch October 14, 2024 11:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant