diff --git a/ariba/cdhit.py b/ariba/cdhit.py index 1d85210d..85749c2e 100644 --- a/ariba/cdhit.py +++ b/ariba/cdhit.py @@ -1,5 +1,6 @@ import tempfile import shutil +import sys import os import pyfastaq from ariba import common, external_progs @@ -47,14 +48,26 @@ def fake_run(self): @staticmethod - def _load_user_clusters_file(filename): + def _load_user_clusters_file(filename, all_ref_seqs, rename_dict=None): + if rename_dict is None: + rename_dict = {} + f = pyfastaq.utils.open_file_read(filename) clusters = {} used_names = set() for line in f: names_list = line.rstrip().split() - new_names = set(names_list) + to_remove = set() + + for name in names_list: + new_name = rename_dict.get(name, name) + if new_name not in all_ref_seqs: + to_remove.add(name) + print('WARNING: ignoring sequence', name, 'from clusters file because not in fasta file. This probably means it failed sanity checks - see the log files 01.filter.check_genes.log, 01.filter.check_metadata.log.', file=sys.stderr) + + names_list = [x for x in names_list if x not in to_remove] + new_names = set([rename_dict.get(name, name) for name in names_list]) if len(names_list) != len(new_names) or not new_names.isdisjoint(used_names): pyfastaq.utils.close(f) raise Error('Error in user-provided clusters file ' + filename + '. Non unique name found at this line:\n' + line) @@ -66,9 +79,10 @@ def _load_user_clusters_file(filename): return clusters - def run_get_clusters_from_file(self, clusters_infile): + def run_get_clusters_from_file(self, clusters_infile, all_ref_seqs, rename_dict=None): '''Instead of running cdhit, gets the clusters info from the input file.''' - clusters = self._load_user_clusters_file(clusters_infile) + if rename_dict is None: + rename_dict = {} # check that every sequence in the clusters file can be # found in the fasta file @@ -76,6 +90,8 @@ def run_get_clusters_from_file(self, clusters_infile): names_list_from_fasta_file = [seq.id for seq in seq_reader] names_set_from_fasta_file = set(names_list_from_fasta_file) + clusters = self._load_user_clusters_file(clusters_infile, all_ref_seqs, rename_dict=rename_dict) + if len(names_set_from_fasta_file) != len(names_list_from_fasta_file): raise Error('At least one duplicate name in fasta file ' + self.infile + '. Cannot continue') diff --git a/ariba/reference_data.py b/ariba/reference_data.py index 8d25e923..a945a4a4 100644 --- a/ariba/reference_data.py +++ b/ariba/reference_data.py @@ -8,7 +8,7 @@ class Error (Exception): pass -rename_sub_regex = re.compile(r'''[':!@,-]''') +rename_sub_regex = re.compile(r'''[|()\[\];"':!@,-]''') class ReferenceData: @@ -30,6 +30,7 @@ def __init__(self, self.genetic_code = genetic_code pyfastaq.sequences.genetic_code = self.genetic_code + self.rename_dict = None @classmethod @@ -353,15 +354,15 @@ def _rename_names_in_metadata(cls, meta_dict, rename_dict): def rename_sequences(self, outfile): - rename_dict = ReferenceData._seq_names_to_rename_dict(self.sequences.keys()) - if len(rename_dict): + self.rename_dict = ReferenceData._seq_names_to_rename_dict(self.sequences.keys()) + if len(self.rename_dict): print('Had to rename some sequences. See', outfile, 'for old -> new names', file=sys.stderr) with open(outfile, 'w') as f: - for old_name, new_name in sorted(rename_dict.items()): + for old_name, new_name in sorted(self.rename_dict.items()): print(old_name, new_name, sep='\t', file=f) - self.sequences = ReferenceData._rename_names_in_seq_dict(self.sequences, rename_dict) - self.metadata = ReferenceData._rename_names_in_metadata(self.metadata, rename_dict) + self.sequences = ReferenceData._rename_names_in_seq_dict(self.sequences, self.rename_dict) + self.metadata = ReferenceData._rename_names_in_metadata(self.metadata, self.rename_dict) def sequence_type(self, sequence_name): @@ -427,7 +428,7 @@ def cluster_with_cdhit(self, outprefix, seq_identity_threshold=0.9, threads=1, l ) if clusters_file is not None: - new_clusters = cdhit_runner.run_get_clusters_from_file(clusters_file) + new_clusters = cdhit_runner.run_get_clusters_from_file(clusters_file, self.sequences, rename_dict=self.rename_dict) elif nocluster: new_clusters = cdhit_runner.fake_run() else: diff --git a/ariba/tests/cdhit_test.py b/ariba/tests/cdhit_test.py index 82b74eb2..da2d9d26 100644 --- a/ariba/tests/cdhit_test.py +++ b/ariba/tests/cdhit_test.py @@ -102,7 +102,26 @@ def test_load_user_clusters_file_good_file(self): '2': {'seq5', 'seq6'} } - got = cdhit.Runner._load_user_clusters_file(infile) + got = cdhit.Runner._load_user_clusters_file(infile, {'seq' + str(i) for i in range(1,7,1)}) + self.assertEqual(expected, got) + + expected['2'] = {'seq5'} + got = cdhit.Runner._load_user_clusters_file(infile, {'seq' + str(i) for i in range(1,6,1)}) + self.assertEqual(expected, got) + + + def test_load_user_clusters_file_good_file_with_renaming(self): + '''test _load_user_clusters_file with good input file with some renamed''' + rename_dict = {'seq2': 'seq2_renamed', 'seq6': 'seq6_renamed'} + infile = os.path.join(data_dir, 'cdhit_test_load_user_clusters_file.good') + expected = { + '0': {'seq1', 'seq2_renamed', 'seq3'}, + '1': {'seq4'}, + '2': {'seq5', 'seq6_renamed'} + } + + names = {'seq1', 'seq2_renamed', 'seq3', 'seq4', 'seq5', 'seq6_renamed'} + got = cdhit.Runner._load_user_clusters_file(infile, names, rename_dict=rename_dict) self.assertEqual(expected, got) @@ -115,7 +134,7 @@ def test_load_user_clusters_file_bad_file(self): ] for filename in infiles: with self.assertRaises(cdhit.Error): - cdhit.Runner._load_user_clusters_file(filename) + cdhit.Runner._load_user_clusters_file(filename, {'seq1', 'seq2', 'seq3'}) def test_run_get_clusters_from_file(self): @@ -123,9 +142,23 @@ def test_run_get_clusters_from_file(self): fa_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict.in.fa') clusters_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict.in.clusters') r = cdhit.Runner(fa_infile) - clusters = r.run_get_clusters_from_file(clusters_infile) + clusters = r.run_get_clusters_from_file(clusters_infile, {'seq1', 'seq2', 'seq3'}) expected_clusters = { '0': {'seq1', 'seq2'}, '1': {'seq3'}, } self.assertEqual(clusters, expected_clusters) + + + def test_run_get_clusters_from_file_with_renaming(self): + '''test run_get_clusters_from_file with renaming''' + rename_dict = {'seq2': 'seq2_renamed'} + fa_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict_rename.in.fa') + clusters_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict.in.clusters') + r = cdhit.Runner(fa_infile) + clusters = r.run_get_clusters_from_file(clusters_infile, {'seq1', 'seq2_renamed', 'seq3'}, rename_dict=rename_dict) + expected_clusters = { + '0': {'seq1', 'seq2_renamed'}, + '1': {'seq3'}, + } + self.assertEqual(clusters, expected_clusters) diff --git a/ariba/tests/data/cdhit_test_run_get_clusters_from_dict_rename.in.fa b/ariba/tests/data/cdhit_test_run_get_clusters_from_dict_rename.in.fa new file mode 100644 index 00000000..6760b692 --- /dev/null +++ b/ariba/tests/data/cdhit_test_run_get_clusters_from_dict_rename.in.fa @@ -0,0 +1,6 @@ +>seq1 +ACGT +>seq2_renamed +AAAA +>seq3 +CCCC diff --git a/ariba/tests/reference_data_test.py b/ariba/tests/reference_data_test.py index f23d4cc7..9b12088a 100644 --- a/ariba/tests/reference_data_test.py +++ b/ariba/tests/reference_data_test.py @@ -278,12 +278,33 @@ def test_new_seq_name(self): def test_seq_names_to_rename_dict(self): '''Test _seq_names_to_rename_dict''' - names = {'foo', 'bar!', 'bar:', 'bar,', 'spam', 'eggs,123'} + names = { + 'foo', + 'bar!', + 'bar:', + 'bar,', + 'spam', + 'eggs,123', + 'ab(c1', + 'ab(c)2', + 'ab[c]3', + 'abc;4', + "abc'5", + 'abc"6', + 'abc|7', + } got = reference_data.ReferenceData._seq_names_to_rename_dict(names) expected = { 'bar!': 'bar_', 'bar,': 'bar__1', 'bar:': 'bar__2', + 'ab(c1': 'ab_c1', + 'ab(c)2': 'ab_c_2', + 'ab[c]3': 'ab_c_3', + 'abc;4': 'abc_4', + "abc'5": 'abc_5', + 'abc"6': 'abc_6', + 'abc|7': 'abc_7', 'eggs,123': 'eggs_123' } @@ -423,6 +444,17 @@ def test_rename_sequences(self): self.assertEqual(expected_seqs_dict, refdata.sequences) + expected_rename_dict = { + 'pres!abs3': 'pres_abs3', + 'pres\'abs1': 'pres_abs1', + 'pres_abs1': 'pres_abs1_1', + 'var,only1': 'var_only1', + 'var:only1': 'var_only1_1', + 'var_only1': 'var_only1_2', + } + + self.assertEqual(expected_rename_dict, refdata.rename_dict) + def test_sequence_type(self): '''Test sequence_type''' diff --git a/setup.py b/setup.py index 9261ebb0..5a7e806e 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ setup( ext_modules=[minimap_mod, fermilite_mod], name='ariba', - version='2.2.1', + version='2.2.2', description='ARIBA: Antibiotic Resistance Identification By Assembly', packages = find_packages(), package_data={'ariba': ['test_run_data/*']},