Skip to content

Commit

Permalink
Merge branch 'v5_models' into 'master'
Browse files (browse the repository at this point in the history)
V5 models

See merge request machine-learning/bonito!181
  • Loading branch information
iiSeymour committed May 20, 2024
2 parents 9bc1d58 + facf3e2 commit 454324a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion bonito/cli/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main(args):
else:
sys.stderr.write(f"> outputting {fmt.aligned} {fmt.name}\n")

if args.model_directory in models and args.model_directory not in os.listdir(__models_dir__):
if args.model_directory in models and not (__models_dir__ / args.model_directory).exists():
sys.stderr.write("> downloading model\n")
Downloader(__models_dir__).download(args.model_directory)

Expand Down
28 changes: 20 additions & 8 deletions bonito/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,18 @@ def download(self, fname):
return fpath

def _unzip(self, fpath):
unzip_path = fpath.with_suffix("")
unzip_path = fpath.parent.with_suffix("")
with ZipFile(fpath, 'r') as zfile:
zfile.extractall(path=unzip_path)
fpath.unlink()
return unzip_path


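models = [
"dna_r10.4.1_e8.2_400bps_fast@v5.0.0",
"dna_r10.4.1_e8.2_400bps_hac@v5.0.0",
"dna_r10.4.1_e8.2_400bps_sup@v5.0.0",

"dna_r10.4.1_e8.2_400bps_fast@v4.3.0",
"dna_r10.4.1_e8.2_400bps_hac@v4.3.0",
"dna_r10.4.1_e8.2_400bps_sup@v4.3.0",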
Expand Down Expand Up @@ -110,25 +114,33 @@ def _unzip(self, fpath):
"[email protected]",
"[email protected]",

"rna002_70bps_fast@v3",
"rna002_70bps_hac@v3",
"rna002_70bps_sup@v3",
"[email protected]",
"[email protected]",
"[email protected]",

"[email protected]",
"[email protected]",
"[email protected]",

"rna002_70bps_fast@v3",
"rna002_70bps_hac@v3",
"rna002_70bps_sup@v3",
]

training = [
"example_data_dna_r9.4.1_v0"
"example_data_dna_r9.4.1_v0",
"example_data_dna_r10.4.1_v0",
"example_data_rna004_v0",
]


def download_files(out_dir, file_list, show, force):
dl = Printer() if show else Downloader(out_dir, force)
for remote_file in file_list:
with contextlib.suppress(FileNotFoundError):
try:
dl.download(remote_file)
except FileNotFoundError:
print(f" - Failed to download: {remote_file}")


def main(args):
Expand All @@ -138,7 +150,7 @@ def main(args):
if args.models or args.all:
out_dir = __models_dir__ if args.out_dir is None else args.out_dir
download_files(out_dir, models, args.show, args.force)
elif args.training:
if args.training or args.all:
out_dir = __data_dir__ if args.out_dir is None else args.out_dir
download_files(out_dir, training, args.show, args.force)

Expand All @@ -152,7 +164,7 @@ def argparser():
parser.add_argument('--list', '--show', dest='show', action='store_true')
parser.add_argument('-f', '--force', action='store_true')

group = parser.add_mutually_exclusive_group()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--all', action='store_true')
group.add_argument('--models', action='store_true')
group.add_argument('--training', action='store_true')
Expand Down
9 changes: 5 additions & 4 deletions bonito/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from operator import itemgetter
from importlib import import_module
from collections import defaultdict, OrderedDict
from pathlib import Path

import toml
import torch
Expand All @@ -26,12 +27,12 @@
pass


__dir__ = os.path.dirname(os.path.realpath(__file__))
__models_dir__ = os.path.join(__dir__, "models")
__data_dir__ = os.path.join(__dir__, "data")
__dir__ = Path(__file__).parent
__models_dir__ = __dir__ / "models"
__data_dir__ = __dir__ / "data"

split_cigar = re.compile(r"(?P<len>\d+)(?P<op>\D+)")
default_config = os.path.join(__dir__, "models/configs", "dna_r9.4.1@v3.1.toml")
default_config = __dir__ / "models/configs/dna_r9.4.1@v3.1.toml"


logger = getLogger('bonito')
Expand Down

0 comments on commit 454324a

Please sign in to comment.