Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes #70

Merged
merged 2 commits into from
Apr 28, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions ariba/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,36 @@ def _gather_output_rows(self):
return rows


@classmethod
def _filter_clusters(cls, rows):
    '''Removes any cluster where no sample has assembled == "yes".

    rows is a dict of filename -> dict of cluster name -> dict of values,
    where each cluster's values include the key 'assembled'.

    The input dict is modified in place, and is also returned.
    Clusters are gathered from every file, not only the first one, and a
    file is allowed to lack a cluster that is being removed.

    Returns tuple: (filtered rows, number of remaining clusters)'''
    all_clusters = set()
    found_a_yes = set()

    for clusters in rows.values():
        for name, values in clusters.items():
            all_clusters.add(name)

            # Only look at 'assembled' until the first "yes" is seen
            # for this cluster; after that the answer cannot change.
            if name not in found_a_yes and values['assembled'] == 'yes':
                found_a_yes.add(name)

    to_delete = all_clusters - found_a_yes

    for clusters in rows.values():
        for name in to_delete:
            # pop with a default: not every file has to contain every cluster
            clusters.pop(name, None)

    return rows, len(found_a_yes)


@classmethod
def _write_csv(cls, filenames, rows, outfile, phandango=False):
lines = []
Expand Down Expand Up @@ -223,6 +253,11 @@ def run(self):
self._check_files_exist()
self.samples = self._load_input_files(self.filenames, self.min_id)
self.rows = self._gather_output_rows()
self.rows, remaining_clusters = Summary._filter_clusters(self.rows)
if remaining_clusters == 0:
print('No clusters found that are present in any sample. Will not write any output files', file=sys.stderr)
sys.exit(1)

Summary._write_csv(self.filenames, self.rows, self.outprefix + '.csv', phandango=False)

if len(self.samples) > 1:
Expand Down
34 changes: 34 additions & 0 deletions ariba/tests/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,40 @@ def test_gather_output_rows(self):
self.assertEqual(expected, got)


def test_filter_clusters(self):
    '''_filter_clusters should drop clusters that no sample assembled'''
    # cluster4 is "no" in both files, so it is the only one that must go.
    # cluster2 and cluster3 are each "yes" in exactly one file and must stay.
    unfiltered = {
        'file1': {
            'cluster1': {'assembled': 'yes'},
            'cluster2': {'assembled': 'yes'},
            'cluster3': {'assembled': 'no'},
            'cluster4': {'assembled': 'no'},
        },
        'file2': {
            'cluster1': {'assembled': 'yes'},
            'cluster2': {'assembled': 'no'},
            'cluster3': {'assembled': 'yes'},
            'cluster4': {'assembled': 'no'},
        },
    }

    wanted_rows = {
        'file1': {
            'cluster1': {'assembled': 'yes'},
            'cluster2': {'assembled': 'yes'},
            'cluster3': {'assembled': 'no'},
        },
        'file2': {
            'cluster1': {'assembled': 'yes'},
            'cluster2': {'assembled': 'no'},
            'cluster3': {'assembled': 'yes'},
        },
    }
    wanted = (wanted_rows, 3)

    observed = summary.Summary._filter_clusters(unfiltered)
    self.assertEqual(wanted, observed)


def test_write_csv(self):
'''Test _write_csv'''
tmp_out = 'tmp.out.tsv'
Expand Down