Skip to content

Commit

Permalink
Rewrote model checkpoints cleaning script.
Browse files Browse the repository at this point in the history
  • Loading branch information
madlag committed Mar 30, 2021
1 parent fd8333f commit d93cbae
Show file tree
Hide file tree
Showing 4 changed files with 3,631 additions and 1,202 deletions.
12 changes: 10 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY style test
.PHONY: style test

# Run code quality checks
style:
Expand All @@ -7,4 +7,12 @@ style:

# Run tests for the library
test:
python -m pytest nn_pruning
python -m pytest nn_pruning

# Rebuild the distribution artifacts from a clean state (stale build/ and dist/ are removed first)
build_dist:
rm -fr build
rm -fr dist
python -m build

# Upload everything in dist/ with twine; depends on build_dist so the artifacts are rebuilt first
pypi_upload: build_dist
python -m twine upload dist/*
104 changes: 83 additions & 21 deletions analysis/cleanup.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,87 @@
import click
from pathlib import Path
import sys
import json
from datetime import datetime

# Read the whitelist of checkpoint paths from the JSON file named by the
# first command-line argument.
with open(sys.argv[1]) as f:
    whitelist = json.load(f)["checkpoints"]

print("whitelist len= ", len(whitelist))

base = Path("/data_2to/devel_data/nn_pruning/output")

# Lazily walk base/<set>/<hp_name>/<checkpoint>; each set directory is
# resolved before its children are listed.
all_checkpoints = (
    checkpoint
    for set_entry in base.iterdir()
    for hp_name in set_entry.resolve().iterdir()
    for checkpoint in hp_name.iterdir()
)

for checkpoint in all_checkpoints:
    checkpoint_str = str(checkpoint)
    print(checkpoint)
    whitelisted = checkpoint_str in whitelist
    # One experiment is always kept, whether or not it is whitelisted.
    if whitelisted or "squad_test_large_regu_10_d0.25" in checkpoint_str:
        print("KEEPING", checkpoint)
        continue
    model_file = checkpoint / "pytorch_model.bin"
    if model_file.exists():
        print("REMOVING", model_file)
        # Deletion is left disabled: the script only reports what it would remove.
        # model_file.unlink()
@click.group()
@click.pass_context
def cli(ctx):
    # Give the group's context a fresh, empty dict as its shared object.
    ctx.obj = dict()

def _load_whitelist(result_files):
    """Return the set of checkpoint paths listed under "checkpoints" in each result file."""
    whitelist = set()
    for filename in result_files:
        with open(filename) as f:
            # Iterating works for both a list of paths and a dict keyed by path.
            whitelist.update(json.load(f)["checkpoints"])
    return whitelist


def _scan_checkpoints(basedir, whitelist):
    """Walk basedir/<set>/<hp_name>/<checkpoint> and classify every checkpoint.

    Returns a ``(kept, removed)`` pair: ``kept`` is the set of whitelisted
    checkpoint paths found on disk, ``removed`` maps each non-whitelisted
    ``pytorch_model.bin`` path to its size in bytes.
    """
    kept = set()
    removed = {}
    for entry in Path(basedir).iterdir():
        set_dir = entry.resolve()
        # Stray files at any level are skipped instead of crashing iterdir().
        if not set_dir.is_dir():
            continue
        for hp_dir in set_dir.iterdir():
            if not hp_dir.is_dir():
                continue
            for checkpoint in hp_dir.iterdir():
                checkpoint_str = str(checkpoint)
                if checkpoint_str in whitelist:
                    kept.add(checkpoint_str)
                    continue
                model_file = checkpoint / "pytorch_model.bin"
                if model_file.exists():
                    # Keyed by path so a file reached twice is counted once.
                    removed[model_file] = model_file.stat().st_size
    return kept, removed


@cli.command()
@click.pass_context
@click.argument("basedir", type=click.Path(resolve_path=True), nargs=1)
# RESULT_FILES: result files used as whitelist (files/results_*.json for example).
@click.argument("result_files", type=click.Path(resolve_path=True), nargs=-1)
@click.option("--execute", is_flag=True, help="Actually delete files (default is a dry run).")
def main(ctxt, basedir, result_files, execute):
    """Remove pytorch_model.bin from every checkpoint that is not whitelisted.

    BASEDIR is the output directory to scan, laid out as
    BASEDIR/<set>/<hp_name>/<checkpoint>.

    RESULT_FILES are JSON files whose "checkpoints" entries form the
    whitelist of checkpoint paths to keep. At least one is required.

    Without --execute only a summary is printed. With --execute the list of
    removed files is written under files/ before anything is deleted.
    """
    click.echo("EXECUTING" if execute else "DRY RUN")
    click.echo("Base dir")
    click.echo(" " + basedir)
    click.echo()
    click.echo("Result files:")
    for r in result_files:
        click.echo(" " + r)
    click.echo()

    # An empty whitelist would mark every model file for removal, so this
    # must stop the command (the exception has to be raised, not just built).
    if not result_files:
        raise click.UsageError("Empty result files")

    whitelist = _load_whitelist(result_files)

    click.echo("Whitelisted checkpoints:")
    whitelisted = len(whitelist)
    click.echo(f" {whitelisted}")
    click.echo()

    kept, removed = _scan_checkpoints(basedir, whitelist)
    removed_size = sum(removed.values())

    click.echo("Kept / Whitelisted")
    click.echo(f" {len(kept)} / {whitelisted}")

    click.echo()
    click.echo("Removed")
    click.echo(f" {len(removed)} pytorch_model.bin files")
    click.echo(f" {removed_size / (1024**3):0.2f}GB")

    if execute:
        d = datetime.now().replace(microsecond=0)
        d = d.isoformat().replace(":", "_").replace("T", "_")
        removed_filename = f"files/removed_files_{d}.json"
        click.echo()

        # Record what is about to be deleted before deleting anything.
        # NOTE: despite the .json suffix, the log is plain text, one path per line.
        removed_path = Path(removed_filename)
        removed_path.parent.mkdir(parents=True, exist_ok=True)
        with removed_path.open("w") as f:
            f.writelines(f"{model_file}\n" for model_file in removed)

        for model_file in removed:
            model_file.unlink()

        click.echo("Wrote removed files list to:")
        click.echo(f" {removed_filename}")


if __name__ == "__main__":
    # The `main` command is called directly (not through the `cli` group).
    main()
4 changes: 3 additions & 1 deletion analysis/create_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def run(self):
checkpoint_path = "/data_2to/devel_data/nn_pruning/output/squad_test_9_fullpatch4/hp_od-__data_2to__devel_data__nn_pruning__output__squad_test_9_fullpatch4___es-steps_nte20_ls250_stl50_est5000_rn-__data_2to__devel_data__nn_pruning__output__squad_test_9_fullpatch4_--6cb2db64e9a885f1/checkpoint-110000"
checkpoint_path = "/data_2to/devel_data/nn_pruning/output/squad_test_9_fullpatch6/hp_od-__data_2to__devel_data__nn_pruning__output__squad_test_9_fullpatch6___es-steps_nte20_ls250_stl50_est5000_rn-__data_2to__devel_data__nn_pruning__output__squad_test_9_fullpatch6_--5f772c87c5edbc85/checkpoint-100000"
checkpoint_path = "/data_2to/devel_data/nn_pruning/output/squad_test_8_mvp_lt/hp_od-__data_2to__devel_data__nn_pruning__output__squad_test_8_mvp_lt___es-steps_nte20_ls250_stl50_est5000_rn-__data_2to__devel_data__nn_pruning__output__squad_test_8_mvp_lt___dpm-si--7fe43555f854fbb6/checkpoint-110000"
kind = "unstruct"
checkpoint_path = "/data_2to/devel_data/nn_pruning/output/squad_test4/hp_od-__data_2to__devel_data__nn_pruning__output__squad4___es-steps_nte20_ls250_stl50_est5000_rn-__data_2to__devel_data__nn_pruning__output__squad4___dpm-sigmoied_threshold:1d_alt_ap--17cd29ad8a563746/checkpoint-110000"
#kind = "unstruct"
kind = "hybrid"
task = "squadv1"

git_base_path = (Path(__file__).resolve().parent.parent.parent / "models").resolve()
Expand Down
Loading

0 comments on commit d93cbae

Please sign in to comment.