Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adaptive BAL #330

Merged
merged 21 commits into from
Sep 17, 2024
Merged

Adaptive BAL #330

merged 21 commits into from
Sep 17, 2024

Conversation

M-R-Schaefer
Copy link
Contributor

This PR adds a few features:

  • adaptive BAL which selects up to a minimum inter sample distance instead of a fixed number of samples.
  • fixed force features for BAL
  • added full gradient features with random projections to the available feature maps
  • new visualizations to the BAL IPS node

@M-R-Schaefer
Copy link
Contributor Author

pre-commit.ci autofix

apax/bal/feature_maps.py Outdated Show resolved Hide resolved
Comment on lines -123 to +161
if self.return_raw:
if self.strategy == "raw":
(gb, gw), _ = jax.tree_util.tree_flatten(g_ll)

# g: n_atoms, 3, n_features
g = gw[:, :, :, 0]
else:
g_flat = jax.tree_map(
elif self.strategy == "sum":
g_summed = jax.tree_map(
lambda arr: jnp.reshape(jnp.sum(jnp.sum(arr, 0), 0), (-1,)), g_ll
)
(gb, gw), _ = jax.tree_util.tree_flatten(g_flat)
(gb, gw), _ = jax.tree_util.tree_flatten(g_summed)
g = [gw, gb]
g = jnp.concatenate(g)

elif self.strategy == "flatten":
g_flat = jax.tree_map(lambda arr: jnp.reshape(arr, (-1,)), g_ll)
(gb, gw), _ = jax.tree_util.tree_flatten(g_flat)
g = gw
else:
raise ValueError(f"unknown strategy: {self.strategy}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe doing that like:

def raw():
.
.
.
strategies = {"raw": raw}
.
.
.
if strategy in strategies:
    g = strategies[strategy](g_ll)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the implementations of the different options is 19 lines of code, I don't think further abstraction is required.

apax/bal/feature_maps.py Outdated Show resolved Hide resolved
@M-R-Schaefer M-R-Schaefer merged commit 6639e07 into main Sep 17, 2024
1 of 2 checks passed
@M-R-Schaefer M-R-Schaefer deleted the adaptive_bal branch September 17, 2024 13:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants