From 6127784969af69126770e3b06246b61c884bab3f Mon Sep 17 00:00:00 2001 From: Martin Hunt Date: Mon, 18 Apr 2016 10:29:58 +0100 Subject: [PATCH] Do not assume cluster representative comes first --- ariba/cdhit.py | 27 ++++++++++--------- ariba/tests/cdhit_test.py | 5 ++-- .../cdhit_test_parse_cluster_info_file.infile | 2 ++ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/ariba/cdhit.py b/ariba/cdhit.py index 208f4350..7d5ff00b 100644 --- a/ariba/cdhit.py +++ b/ariba/cdhit.py @@ -56,8 +56,8 @@ def _get_ids(self, infile): @staticmethod def _parse_cluster_info_file(infile, cluster_representatives): f = pyfastaq.utils.open_file_read(infile) - clusters = {} - current_cluster = None + cluster_sets = {} + found_representatives = {} # store cluster number -> representative name for line in f: data = line.rstrip().split() @@ -66,21 +66,24 @@ def _parse_cluster_info_file(infile, cluster_representatives): raise Error('Unexpected format of line from cdhit output file "' + infile + '". Line is:\n' + line) seqname = seqname[1:-3] - if data[3] == '*': - current_cluster = seqname - assert current_cluster not in clusters - clusters[current_cluster] = {current_cluster} - else: - assert current_cluster in clusters - if seqname in clusters[current_cluster]: - raise Error('Duplicate name "' + seqname + '" found in cluster ' + current_cluster) + cluster_number = int(data[0]) # this is the cluster number used by cdhit + if cluster_number not in cluster_sets: + cluster_sets[cluster_number] = set() - clusters[current_cluster].add(seqname) + cluster_sets[cluster_number].add(seqname) + + if data[3] == '*': + found_representatives[cluster_number] = seqname pyfastaq.utils.close(f) - if set(clusters.keys()) != cluster_representatives: + + if set(found_representatives.values()) != cluster_representatives: raise Error('Mismatch in cdhit output sequence names between fasta file and clusters file. Cannot continue') + clusters = {} + for cluster_number, cluster_name in found_representatives.items(): + clusters[cluster_name] = cluster_sets[cluster_number] + return clusters diff --git a/ariba/tests/cdhit_test.py b/ariba/tests/cdhit_test.py index c0421d7c..11fb2cbc 100644 --- a/ariba/tests/cdhit_test.py +++ b/ariba/tests/cdhit_test.py @@ -25,12 +25,13 @@ def test_get_ids(self): def test_parse_cluster_info_file(self): '''test _parse_cluster_info_file''' - cluster_representatives = {'seq1', 'seq4'} + cluster_representatives = {'seq1', 'seq4', 'seq6'} infile = os.path.join(data_dir, 'cdhit_test_parse_cluster_info_file.infile') got_clusters = cdhit.Runner._parse_cluster_info_file(infile, cluster_representatives) expected_clusters = { 'seq1': {'seq1', 'seq2', 'seq3'}, - 'seq4': {'seq4'} + 'seq4': {'seq4'}, + 'seq6': {'seq5', 'seq6'}, } self.assertEqual(expected_clusters, got_clusters) diff --git a/ariba/tests/data/cdhit_test_parse_cluster_info_file.infile b/ariba/tests/data/cdhit_test_parse_cluster_info_file.infile index 69a6f14b..548e060b 100644 --- a/ariba/tests/data/cdhit_test_parse_cluster_info_file.infile +++ b/ariba/tests/data/cdhit_test_parse_cluster_info_file.infile @@ -2,3 +2,5 @@ 0 499aa, >seq2... at 99.40% 0 499aa, >seq3... at 98.40% 1 500aa, >seq4... * +2 300aa, >seq5... at 90.42% +2 301aa, >seq6... *