Skip to content

Commit

Permalink
Merge pull request #55 from martinghunt/cdhit_cluster_bug
Browse files Browse the repository at this point in the history
Do not assume cluster representative comes first
  • Loading branch information
martinghunt committed Apr 18, 2016
2 parents 2ead25c + 6127784 commit e31c6ea
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
27 changes: 15 additions & 12 deletions ariba/cdhit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def _get_ids(self, infile):
@staticmethod
def _parse_cluster_info_file(infile, cluster_representatives):
f = pyfastaq.utils.open_file_read(infile)
clusters = {}
current_cluster = None
cluster_sets = {}
found_representatives = {} # store cluster number -> representative name

for line in f:
data = line.rstrip().split()
Expand All @@ -66,21 +66,24 @@ def _parse_cluster_info_file(infile, cluster_representatives):
raise Error('Unexpected format of line from cdhit output file "' + infile + '". Line is:\n' + line)
seqname = seqname[1:-3]

if data[3] == '*':
current_cluster = seqname
assert current_cluster not in clusters
clusters[current_cluster] = {current_cluster}
else:
assert current_cluster in clusters
if seqname in clusters[current_cluster]:
raise Error('Duplicate name "' + seqname + '" found in cluster ' + current_cluster)
cluster_number = int(data[0]) # this is the cluster number used by cdhit
if cluster_number not in cluster_sets:
cluster_sets[cluster_number] = set()

clusters[current_cluster].add(seqname)
cluster_sets[cluster_number].add(seqname)

if data[3] == '*':
found_representatives[cluster_number] = seqname

pyfastaq.utils.close(f)
if set(clusters.keys()) != cluster_representatives:

if set(found_representatives.values()) != cluster_representatives:
raise Error('Mismatch in cdhit output sequence names between fasta file and clusters file. Cannot continue')

clusters = {}
for cluster_number, cluster_name in found_representatives.items():
clusters[cluster_name] = cluster_sets[cluster_number]

return clusters


Expand Down
5 changes: 3 additions & 2 deletions ariba/tests/cdhit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ def test_get_ids(self):

def test_parse_cluster_info_file(self):
'''test _parse_cluster_info_file'''
cluster_representatives = {'seq1', 'seq4'}
cluster_representatives = {'seq1', 'seq4', 'seq6'}
infile = os.path.join(data_dir, 'cdhit_test_parse_cluster_info_file.infile')
got_clusters = cdhit.Runner._parse_cluster_info_file(infile, cluster_representatives)
expected_clusters = {
'seq1': {'seq1', 'seq2', 'seq3'},
'seq4': {'seq4'}
'seq4': {'seq4'},
'seq6': {'seq5', 'seq6'},
}
self.assertEqual(expected_clusters, got_clusters)

Expand Down
2 changes: 2 additions & 0 deletions ariba/tests/data/cdhit_test_parse_cluster_info_file.infile
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
0 499aa, >seq2... at 99.40%
0 499aa, >seq3... at 98.40%
1 500aa, >seq4... *
2 300aa, >seq5... at 90.42%
2 301aa, >seq6... *

0 comments on commit e31c6ea

Please sign in to comment.