Skip to content

Commit

Permalink
Merge branch 'tmp' into chore/update-libs
Browse files Browse the repository at this point in the history
  • Loading branch information
morgangiraud committed Apr 3, 2024
2 parents f50b3a5 + 464747b commit a15fc51
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
10 changes: 9 additions & 1 deletion examples/qd_cmame.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,15 @@ def run(omegaConf: DictConfig) -> None:
grid_shape = config["grid"]["shape"]
assert len(grid_shape) == len(features_domain)
if True:
archive = GridArchive(solution_dim=len(grid_shape), dims=grid_shape, ranges=features_domain, seed=seed)
archive = GridArchive(
solution_dim=len(grid_shape),
dims=grid_shape,
ranges=features_domain,
seed=seed,
extra_fields={
"metadata": ((), object),
},
)
else:
bins = math.prod(grid_shape)
archive = CVTArchive(bins, features_domain, seed=seed, use_kd_tree=True)
Expand Down
9 changes: 4 additions & 5 deletions leniax/qd.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ def run_qd_search(
fits.append(ind.fitness)
bcs.append(ind.features)
metadata.append(ind.get_config())
# scheduler.tell(fits, bcs, **metadata) # TODO: check the impact of removing metadata
scheduler.tell(fits, bcs)
scheduler.tell(fits, bcs, metadata=metadata)

# Log statistics
if itr % log_freq == 0 or itr == nb_iter:
Expand Down Expand Up @@ -355,10 +354,10 @@ def render_best(grid: ArchiveBase, fitness_threshold: float):
rng_key = leniax_utils.seed_everything(seed)

real_bests = []
for idx in grid._occupied_indices:
if abs(grid._objective_values[idx]) >= fitness_threshold:
for elite in grid:
if abs(elite["objective"]) >= fitness_threshold:
rng_key, subkey = jax.random.split(rng_key)
lenia = LeniaIndividual(grid._metadata[idx], subkey, grid._solutions[idx])
lenia = LeniaIndividual(elite["metadata"], subkey, elite["solution"])
real_bests.append(lenia)

logging.info(f"Found {len(real_bests)} beast!")
Expand Down
10 changes: 9 additions & 1 deletion scripts/qd_rastrigin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,15 @@ def run(omegaConf: DictConfig) -> None:
assert len(grid_shape) == len(features_domain)

bins = math.prod(grid_shape)
archive = CVTArchive(bins, features_domain, seed=seed, use_kd_tree=True)
archive = CVTArchive(
bins,
features_domain,
seed=seed,
use_kd_tree=True,
extra_fields={
"metadata": ((), object),
},
)
archive.qd_config = config

# Emitters
Expand Down
8 changes: 4 additions & 4 deletions tools/vis_behaviours_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def run() -> None:

fitness_threshold = qd_config["run_params"]["max_run_iter"]
real_bests = []
for idx in grid._occupied_indices:
if abs(grid._objective_values[idx]) >= fitness_threshold:
for elite in grid:
if abs(elite["objective"]) >= fitness_threshold:
rng_key, subkey = jax.random.split(rng_key)
lenia = LeniaIndividual(grid._metadata[idx], subkey, grid._solutions[idx])
lenia = LeniaIndividual(elite["metadata"], subkey, elite["solution"])
real_bests.append(lenia)

print(f"Found {len(real_bests)} beast in {file_path}")
Expand Down Expand Up @@ -89,7 +89,7 @@ def run() -> None:
]
behaviour_archive.add(lenia, 1024, behaviour, config)

print(len(behaviour_archive._occupied_indices))
print(len(behaviour_archive))
leniax_qd.save_heatmap(behaviour_archive, fitness_domain, f"{subdir}/behaviour_archive_heatmap.png")
with open(f"{subdir}/behaviour_final.p", "wb") as handle:
pickle.dump(behaviour_archive, handle, protocol=pickle.HIGHEST_PROTOCOL)
Expand Down

0 comments on commit a15fc51

Please sign in to comment.