Update libs and switch from flake8 to ruff
morgangiraud committed Apr 2, 2024
1 parent 32cd07e commit a240480
Showing 94 changed files with 2,281 additions and 2,124 deletions.
69 changes: 37 additions & 32 deletions benchmarks/run.py
@@ -19,7 +19,7 @@
     update_results,
     check_consistency,
     get_task,
-    setup_jax
+    setup_jax,
 )

 from leniax import utils as leniax_utils
@@ -33,44 +33,44 @@ class RunCB(Callback):
     def on_job_end(self, config: DictConfig, job_return) -> None:
         ret = job_return.return_value
         job_id = job_return.hydra_cfg.hydra.job.config_name
-        date_val = job_return.working_dir.split('/')[1]
-        timings = ret['timings']
-        device = ret['device']
-        save_dir = ret['save_dir']
+        date_val = job_return.working_dir.split("/")[1]
+        timings = ret["timings"]
+        device = ret["device"]
+        save_dir = ret["save_dir"]

         stats_df = compute_statistics(timings)

-        stats_df['job_id'] = job_id
-        stats_df['day'] = date_val
+        stats_df["job_id"] = job_id
+        stats_df["day"] = date_val
         stats_df = stats_df.sort_values(
-            ['job_id', 'day', 'size', 'delta'],
+            ["job_id", "day", "size", "delta"],
             ascending=[True, False, True, False],
-        ).set_index(['job_id', 'day', 'size', 'task'])
+        ).set_index(["job_id", "day", "size", "task"])
         logging.info(format_output(stats_df, job_id, device=device))

         shutil.rmtree(save_dir)

-        results_dir = os.path.join(cdir, 'results')
+        results_dir = os.path.join(cdir, "results")
         leniax_utils.check_dir(results_dir)

-        results_fullpath = os.path.join(results_dir, 'results.json')
+        results_fullpath = os.path.join(results_dir, "results.json")

         all_results_df = update_results(results_fullpath, stats_df)
-        max_val = all_results_df['mean'].max() + all_results_df['stdev'].max()
+        max_val = all_results_df["mean"].max() + all_results_df["stdev"].max()
         for indexes, sub_results_df in all_results_df.groupby(level=(0, 1)):
-            means = sub_results_df.loc[indexes]['mean']
-            stds = sub_results_df.loc[indexes]['stdev']
+            means = sub_results_df.loc[indexes]["mean"]
+            stds = sub_results_df.loc[indexes]["stdev"]
             ax = means.unstack().plot.bar(
                 yerr=stds.unstack(),
                 title=job_id,
-                ylabel='Mean duration',
+                ylabel="Mean duration",
                 ylim=[0, max_val],
             )
             fig = ax.get_figure()
-            fig.savefig(os.path.join(results_dir, f'{indexes[0]}-{indexes[1]}.png'))
+            fig.savefig(os.path.join(results_dir, f"{indexes[0]}-{indexes[1]}.png"))


-@hydra.main(config_path=config_path, config_name=config_name)
+@hydra.main(version_base="1.1", config_path=config_path, config_name=config_name)
 def bench(omegaConf: DictConfig) -> None:
     """Leniax benchmark
@@ -82,7 +82,11 @@ def bench(omegaConf: DictConfig) -> None:
     $ taskset -c 0 python run.py bench.task='single_run' bench.device='gpu'
-    $ python run.py bench.task='single_run' bench.device='gpu' run_params.nb_init_search=64 world_params.nb_channels=16
+    $ python run.py \
+        bench.task='single_run' \
+        bench.device='gpu' \
+        run_params.nb_init_search=64 \
+        world_params.nb_channels=16
     """
     device = omegaConf.bench.device
     jax = setup_jax(device)
@@ -97,27 +101,28 @@
     logging.info(f"Output directory: {save_dir}")

     # We seed the whole python environment.
-    rng_key = leniax_utils.seed_everything(config['run_params']['seed'])
+    rng_key = leniax_utils.seed_everything(config["run_params"]["seed"])

-    tasks = config['bench']['tasks']
-    if type(tasks) != list:
+    tasks = config["bench"]["tasks"]
+
+    if not isinstance(tasks, list):
         tasks = [tasks]
-    burnin = config['bench']['burnin']
-    multipliers = config['bench']['multipliers']
-    repetitions = config['bench']['repetitions']
+    burnin = config["bench"]["burnin"]
+    multipliers = config["bench"]["multipliers"]
+    repetitions = config["bench"]["repetitions"]

     all_tasks = {}
     for task in tasks:
         task_module, task_identifier = get_task(task)
-        all_tasks[task_identifier] = {'tm': task_module}
+        all_tasks[task_identifier] = {"tm": task_module}

     runs = sorted(itertools.product(all_tasks.keys(), multipliers))

     for run in runs:
         rng_key, subkey = jax.random.split(rng_key)
         current_task = all_tasks[run[0]]
-        run_fn = current_task['tm'].make_run_fn(subkey, copy.deepcopy(config), run[1])
-        current_task['fn'] = run_fn
+        run_fn = current_task["tm"].make_run_fn(subkey, copy.deepcopy(config), run[1])
+        current_task["fn"] = run_fn

     if len(runs) == 0:
         logging.info("Nothing to do")
@@ -131,7 +136,7 @@ def bench(omegaConf: DictConfig) -> None:

         for run_id, mul in runs:
             # use end-to-end runtime for repetition estimation
-            repetitions[(run_id, mul)] = estimate_repetitions(all_tasks[run_id]['fn'])
+            repetitions[(run_id, mul)] = estimate_repetitions(all_tasks[run_id]["fn"])
     else:
         repetitions = {(run_id, mul): repetitions for run_id, mul in runs}

@@ -147,7 +152,7 @@ def bench(omegaConf: DictConfig) -> None:
     with pbar:
         for run_id, mul in all_runs:
             with Timer() as t:
-                res = all_tasks[run_id]['fn']()
+                res = all_tasks[run_id]["fn"]()

             # YOWO (you only warn once)
             if not checked[(run_id, mul)]:
@@ -172,9 +177,9 @@ def bench(omegaConf: DictConfig) -> None:
         assert len(timings[(run_id, mul)]) == repetitions[(run_id, mul)] + burnin

     return {
-        'timings': timings,
-        'device': device,
-        'save_dir': save_dir,
+        "timings": timings,
+        "device": device,
+        "save_dir": save_dir,
     }


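Two non-mechanical changes in this file deserve a note beyond the quote normalization. First, `@hydra.main` gains `version_base="1.1"`, presumably to pin Hydra's legacy 1.1 behavior (per-job working directories and the like) while silencing the warning newer Hydra releases emit when the argument is omitted. Second, `type(tasks) != list` becomes `not isinstance(tasks, list)`, the fix ruff's E721 type-comparison rule points to. A minimal sketch (hypothetical values, not repository code) of why the `isinstance` form is safer:

# Normalize a config value that may be a scalar or a list, as bench() does.
tasks = "single_run"
if not isinstance(tasks, list):  # preferred over `type(tasks) != list`
    tasks = [tasks]
assert tasks == ["single_run"]

# Unlike an exact type comparison, isinstance also accepts subclasses:
class TaskList(list):
    pass

assert isinstance(TaskList(), list)  # no spurious re-wrapping
assert type(TaskList()) is not list  # the old check would have re-wrapped it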
37 changes: 18 additions & 19 deletions benchmarks/tasks/pipeline.py
@@ -18,8 +18,8 @@ def K_init(rng, shape):


 def all_states_loss_fn(rng_key, preds, targets):
-    x = preds['cells']
-    x_true = targets['cells']
+    x = preds["cells"]
+    x_true = targets["cells"]

     loss = jnp.mean(jnp.square(x - x_true))

@@ -33,8 +33,7 @@ class BenchModel(nn.Module):

     @nn.compact
     def __call__(self, rng_key, cells_state, dt=1.0):
-
-        K = self.param('K', K_init, self.K_shape)
+        K = self.param("K", K_init, self.K_shape)
         potential = self.get_potential_fn(cells_state, K)

         y = potential
@@ -49,33 +48,33 @@ def __call__(self, rng_key, cells_state, dt=1.0):


 def make_run_fn(rng_key, config, multiplier):
-    config['bench']['nb_k'] *= multiplier
-    config['world_params']['R'] = int(config['world_params']['R'] * math.log(multiplier + 2))
-    config['world_params']['nb_channels'] *= multiplier
-    config['run_params']['nb_init_search'] *= multiplier
-    config['run_params']['max_run_iter'] *= multiplier
-
-    R = config['world_params']['R']
-    T = config['world_params']['T']
-    C = config['world_params']['nb_channels']
-    N_init = config['run_params']['nb_init_search']
-    world_size = config['render_params']['world_size']
-    max_iter = config['run_params']['max_run_iter']
+    config["bench"]["nb_k"] *= multiplier
+    config["world_params"]["R"] = int(config["world_params"]["R"] * math.log(multiplier + 2))
+    config["world_params"]["nb_channels"] *= multiplier
+    config["run_params"]["nb_init_search"] *= multiplier
+    config["run_params"]["max_run_iter"] *= multiplier
+
+    R = config["world_params"]["R"]
+    T = config["world_params"]["T"]
+    C = config["world_params"]["nb_channels"]
+    N_init = config["run_params"]["nb_init_search"]
+    world_size = config["render_params"]["world_size"]
+    max_iter = config["run_params"]["max_run_iter"]
     cells_shape = [N_init] + world_size + [C]

     subkeys = jax.random.split(rng_key)
     init_cells = jax.random.uniform(subkeys[0], [N_init] + world_size + [C])
-    targets = {'cells': jax.random.uniform(subkeys[1], [N_init] + world_size + [C])}
+    targets = {"cells": jax.random.uniform(subkeys[1], [N_init] + world_size + [C])}

     K_shape = (2 * R, 2 * R, 1, 32)
     m = BenchModel(
         features=[4, C],
         K_shape=K_shape,
-        get_potential_fn=build_get_potential_fn(K_shape, fft=False, channel_first=False)
+        get_potential_fn=build_get_potential_fn(K_shape, fft=False, channel_first=False),
     )
     rng_key, subkey = jax.random.split(rng_key)
     variables = m.init(rng_key, subkey, jnp.ones(cells_shape))
-    vars, params = variables.pop('params')
+    vars, params = variables.pop("params")
     del variables  # Delete variables to avoid wasting resources

     t_state = train_state.TrainState.create(
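The tail of this diff is truncated at `t_state = train_state.TrainState.create(`; the surrounding code follows the usual Flax recipe of initializing the module, splitting the `params` collection out of the returned variables, and wrapping it in a `TrainState`. A self-contained sketch of that recipe with a hypothetical module and optimizer, assuming a Flax version whose `init` returns a `FrozenDict` (which the tuple-unpacking of `variables.pop("params")` implies):

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

class Tiny(nn.Module):  # hypothetical stand-in for BenchModel
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

m = Tiny()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
rest, params = variables.pop("params")  # FrozenDict.pop returns (rest, value)
del variables  # mirror the diff's cleanup to avoid holding duplicate buffers

t_state = train_state.TrainState.create(
    apply_fn=m.apply,
    params=params,
    tx=optax.adam(1e-3),  # hypothetical optimizer; the repo's choice is not shown
)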
37 changes: 18 additions & 19 deletions benchmarks/tasks/pipeline_with_rng.py
@@ -18,8 +18,8 @@ def K_init(rng, shape):


 def all_states_loss_fn(rng_key, preds, targets):
-    x = preds['cells']
-    x_true = targets['cells']
+    x = preds["cells"]
+    x_true = targets["cells"]

     loss = jnp.mean(jnp.square(x - x_true))

@@ -33,8 +33,7 @@ class BenchModel(nn.Module):

     @nn.compact
     def __call__(self, rng_key, cells_state, dt=1.0):
-
-        K = self.param('K', K_init, self.K_shape)
+        K = self.param("K", K_init, self.K_shape)
         potential = self.get_potential_fn(cells_state, K)

         y = potential
@@ -49,33 +48,33 @@ def __call__(self, rng_key, cells_state, dt=1.0):


 def make_run_fn(rng_key, config, multiplier):
-    config['bench']['nb_k'] *= multiplier
-    config['world_params']['R'] = int(config['world_params']['R'] * math.log(multiplier + 2))
-    config['world_params']['nb_channels'] *= multiplier
-    config['run_params']['nb_init_search'] *= multiplier
-    config['run_params']['max_run_iter'] *= multiplier
-
-    R = config['world_params']['R']
-    T = config['world_params']['T']
-    C = config['world_params']['nb_channels']
-    N_init = config['run_params']['nb_init_search']
-    world_size = config['render_params']['world_size']
-    max_iter = config['run_params']['max_run_iter']
+    config["bench"]["nb_k"] *= multiplier
+    config["world_params"]["R"] = int(config["world_params"]["R"] * math.log(multiplier + 2))
+    config["world_params"]["nb_channels"] *= multiplier
+    config["run_params"]["nb_init_search"] *= multiplier
+    config["run_params"]["max_run_iter"] *= multiplier
+
+    R = config["world_params"]["R"]
+    T = config["world_params"]["T"]
+    C = config["world_params"]["nb_channels"]
+    N_init = config["run_params"]["nb_init_search"]
+    world_size = config["render_params"]["world_size"]
+    max_iter = config["run_params"]["max_run_iter"]
     cells_shape = [N_init] + world_size + [C]

     subkeys = jax.random.split(rng_key)
     init_cells = jax.random.uniform(subkeys[0], [N_init] + world_size + [C])
-    targets = {'cells': jax.random.uniform(subkeys[1], [N_init] + world_size + [C])}
+    targets = {"cells": jax.random.uniform(subkeys[1], [N_init] + world_size + [C])}

     K_shape = (2 * R, 2 * R, 1, 32)
     m = BenchModel(
         features=[4, C],
         K_shape=K_shape,
-        get_potential_fn=build_get_potential_fn(K_shape, fft=False, channel_first=False)
+        get_potential_fn=build_get_potential_fn(K_shape, fft=False, channel_first=False),
     )
     rng_key, subkey = jax.random.split(rng_key)
     variables = m.init(rng_key, subkey, jnp.ones(cells_shape))
-    vars, params = variables.pop('params')
+    vars, params = variables.pop("params")
     del variables  # Delete variables to avoid wasting resources

     t_state = train_state.TrainState.create(
20 changes: 10 additions & 10 deletions benchmarks/tasks/potential_cfirst.py
@@ -5,16 +5,16 @@


 def make_run_fn(rng_key, config, multiplier):
-    config['bench']['nb_k'] *= multiplier
-    config['world_params']['R'] = int(config['world_params']['R'] * math.log(multiplier + 2))
-    config['world_params']['nb_channels'] *= multiplier
-    config['run_params']['nb_init_search'] *= multiplier
-    config['run_params']['max_run_iter'] *= multiplier
-
-    R = config['world_params']['R']
-    C = config['world_params']['nb_channels']
-    N_init = config['run_params']['nb_init_search']
-    world_size = config['render_params']['world_size']
+    config["bench"]["nb_k"] *= multiplier
+    config["world_params"]["R"] = int(config["world_params"]["R"] * math.log(multiplier + 2))
+    config["world_params"]["nb_channels"] *= multiplier
+    config["run_params"]["nb_init_search"] *= multiplier
+    config["run_params"]["max_run_iter"] *= multiplier
+
+    R = config["world_params"]["R"]
+    C = config["world_params"]["nb_channels"]
+    N_init = config["run_params"]["nb_init_search"]
+    world_size = config["render_params"]["world_size"]
     subkeys = jax.random.split(rng_key)

     state = jax.random.uniform(subkeys[0], [N_init, C] + world_size)
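Every task's `make_run_fn` scales the workload the same way: channel count, init-search size, and iteration count grow linearly with `multiplier`, while the kernel radius `R` grows only logarithmically, which keeps convolution cost from blowing up across runs. A worked example with hypothetical base values (the real ones come from the Hydra config):

import math

R, C, N_init, multiplier = 13, 1, 32, 4  # hypothetical base values

R_scaled = int(R * math.log(multiplier + 2))  # int(13 * log(6)) == 23
C_scaled = C * multiplier                     # 4
N_scaled = N_init * multiplier                # 128
print(R_scaled, C_scaled, N_scaled)           # -> 23 4 128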
24 changes: 12 additions & 12 deletions benchmarks/tasks/potential_cfirst_raw_conv.py
@@ -3,16 +3,16 @@


 def make_run_fn(rng_key, config, multiplier):
-    config['bench']['nb_k'] *= multiplier
-    config['world_params']['R'] = int(config['world_params']['R'] * math.log(multiplier + 2))
-    config['world_params']['nb_channels'] *= multiplier
-    config['run_params']['nb_init_search'] *= multiplier
-    config['run_params']['max_run_iter'] *= multiplier
-
-    R = config['world_params']['R']
-    C = config['world_params']['nb_channels']
-    N_init = config['run_params']['nb_init_search']
-    world_size = config['render_params']['world_size']
+    config["bench"]["nb_k"] *= multiplier
+    config["world_params"]["R"] = int(config["world_params"]["R"] * math.log(multiplier + 2))
+    config["world_params"]["nb_channels"] *= multiplier
+    config["run_params"]["nb_init_search"] *= multiplier
+    config["run_params"]["max_run_iter"] *= multiplier
+
+    R = config["world_params"]["R"]
+    C = config["world_params"]["nb_channels"]
+    N_init = config["run_params"]["nb_init_search"]
+    world_size = config["render_params"]["world_size"]
     subkeys = jax.random.split(rng_key)

     state = jax.random.uniform(subkeys[0], [N_init, C] + world_size)
@@ -33,8 +33,8 @@ def make_run_fn(rng_key, config, multiplier):

     @jax.jit
     def apply_fn(state):
-        padded_state = jax.numpy.pad(state, padding, mode='wrap')
-        return jax.lax.conv_general_dilated(padded_state, K, (1, 1), 'VALID', feature_group_count=C)
+        padded_state = jax.numpy.pad(state, padding, mode="wrap")
+        return jax.lax.conv_general_dilated(padded_state, K, (1, 1), "VALID", feature_group_count=C)

     def bench_fn():
         potential = apply_fn(state)
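`apply_fn` above implements a periodic (torus) convolution: the state is wrap-padded and then convolved with `'VALID'` boundaries, and `feature_group_count=C` makes the convolution depthwise, one filter group per channel. A self-contained sketch of the same pattern with made-up shapes; `K` and `padding` are hypothetical here, since the diff truncates before their definitions:

import jax
import jax.numpy as jnp

N, C, H, W, R = 2, 3, 32, 32, 4  # hypothetical sizes
key_s, key_k = jax.random.split(jax.random.PRNGKey(0))
state = jax.random.uniform(key_s, (N, C, H, W))      # NCHW, channel-first
K = jax.random.uniform(key_k, (C, 1, 2 * R, 2 * R))  # OIHW, one filter per channel

# Wrap-pad only the spatial dims so the 'VALID' convolution behaves circularly.
padding = [(0, 0), (0, 0), (R, R - 1), (R, R - 1)]
padded_state = jnp.pad(state, padding, mode="wrap")
potential = jax.lax.conv_general_dilated(
    padded_state, K, (1, 1), "VALID", feature_group_count=C
)
assert potential.shape == (N, C, H, W)  # same spatial size as the input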
20 changes: 10 additions & 10 deletions benchmarks/tasks/potential_cfirst_tcnone.py
@@ -5,16 +5,16 @@


 def make_run_fn(rng_key, config, multiplier):
-    config['bench']['nb_k'] *= multiplier
-    config['world_params']['R'] = int(config['world_params']['R'] * math.log(multiplier + 2))
-    config['world_params']['nb_channels'] *= multiplier
-    config['run_params']['nb_init_search'] *= multiplier
-    config['run_params']['max_run_iter'] *= multiplier
-
-    R = config['world_params']['R']
-    C = config['world_params']['nb_channels']
-    N_init = config['run_params']['nb_init_search']
-    world_size = config['render_params']['world_size']
+    config["bench"]["nb_k"] *= multiplier
+    config["world_params"]["R"] = int(config["world_params"]["R"] * math.log(multiplier + 2))
+    config["world_params"]["nb_channels"] *= multiplier
+    config["run_params"]["nb_init_search"] *= multiplier
+    config["run_params"]["max_run_iter"] *= multiplier
+
+    R = config["world_params"]["R"]
+    C = config["world_params"]["nb_channels"]
+    N_init = config["run_params"]["nb_init_search"]
+    world_size = config["render_params"]["world_size"]
     subkeys = jax.random.split(rng_key)

     state = jax.random.uniform(subkeys[0], [N_init, C] + world_size)

