diff --git a/bonito/cli/basecaller.py b/bonito/cli/basecaller.py index 804007a..2253af0 100644 --- a/bonito/cli/basecaller.py +++ b/bonito/cli/basecaller.py @@ -43,7 +43,7 @@ def main(args): else: sys.stderr.write(f"> outputting {fmt.aligned} {fmt.name}\n") - if args.model_directory in models and args.model_directory not in os.listdir(__models_dir__): + if args.model_directory in models and not (__models_dir__ / args.model_directory).exists(): sys.stderr.write("> downloading model\n") Downloader(__models_dir__).download(args.model_directory) diff --git a/bonito/cli/download.py b/bonito/cli/download.py index 1f9b31c..952fd47 100755 --- a/bonito/cli/download.py +++ b/bonito/cli/download.py @@ -66,7 +66,7 @@ def download(self, fname): return fpath def _unzip(self, fpath): - unzip_path = fpath.with_suffix("") + unzip_path = fpath.parent.with_suffix("") with ZipFile(fpath, 'r') as zfile: zfile.extractall(path=unzip_path) fpath.unlink() @@ -74,6 +74,10 @@ def _unzip(self, fpath): models = [ + "dna_r10.4.1_e8.2_400bps_fast@v5.0.0", + "dna_r10.4.1_e8.2_400bps_hac@v5.0.0", + "dna_r10.4.1_e8.2_400bps_sup@v5.0.0", + "dna_r10.4.1_e8.2_400bps_fast@v4.3.0", "dna_r10.4.1_e8.2_400bps_hac@v4.3.0", "dna_r10.4.1_e8.2_400bps_sup@v4.3.0", @@ -110,25 +114,33 @@ def _unzip(self, fpath): "dna_r9.4.1_e8_hac@v3.3", "dna_r9.4.1_e8_fast@v3.4", - "rna002_70bps_fast@v3", - "rna002_70bps_hac@v3", - "rna002_70bps_sup@v3", + "rna004_130bps_fast@v5.0.0", + "rna004_130bps_hac@v5.0.0", + "rna004_130bps_sup@v5.0.0", "rna004_130bps_fast@v3.0.1", "rna004_130bps_hac@v3.0.1", "rna004_130bps_sup@v3.0.1", + + "rna002_70bps_fast@v3", + "rna002_70bps_hac@v3", + "rna002_70bps_sup@v3", ] training = [ - "example_data_dna_r9.4.1_v0" + "example_data_dna_r9.4.1_v0", + "example_data_dna_r10.4.1_v0", + "example_data_rna004_v0", ] def download_files(out_dir, file_list, show, force): dl = Printer() if show else Downloader(out_dir, force) for remote_file in file_list: - with contextlib.suppress(FileNotFoundError): + try: dl.download(remote_file) + except FileNotFoundError: + print(f" - Failed to download: {remote_file}") def main(args): @@ -138,7 +150,7 @@ def main(args): if args.models or args.all: out_dir = __models_dir__ if args.out_dir is None else args.out_dir download_files(out_dir, models, args.show, args.force) - elif args.training: + if args.training or args.all: out_dir = __data_dir__ if args.out_dir is None else args.out_dir download_files(out_dir, training, args.show, args.force) @@ -152,7 +164,7 @@ def argparser(): parser.add_argument('--list', '--show', dest='show', action='store_true') parser.add_argument('-f', '--force', action='store_true') - group = parser.add_mutually_exclusive_group() + group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--all', action='store_true') group.add_argument('--models', action='store_true') group.add_argument('--training', action='store_true') diff --git a/bonito/util.py b/bonito/util.py index 7574ec0..b296ec1 100755 --- a/bonito/util.py +++ b/bonito/util.py @@ -11,6 +11,7 @@ from operator import itemgetter from importlib import import_module from collections import defaultdict, OrderedDict +from pathlib import Path import toml import torch @@ -26,12 +27,12 @@ pass -__dir__ = os.path.dirname(os.path.realpath(__file__)) -__models_dir__ = os.path.join(__dir__, "models") -__data_dir__ = os.path.join(__dir__, "data") +__dir__ = Path(__file__).parent +__models_dir__ = __dir__ / "models" +__data_dir__ = __dir__ / "data" split_cigar = re.compile(r"(?P\d+)(?P\D+)") -default_config = os.path.join(__dir__, "models/configs", "dna_r9.4.1@v3.1.toml") +default_config = __dir__ / "models/configs/dna_r9.4.1@v3.1.toml" logger = getLogger('bonito')