From 3fe893328d7590364f7f47e334008b2050edde0e Mon Sep 17 00:00:00 2001 From: martinghunt Date: Thu, 28 Apr 2016 14:40:32 +0000 Subject: [PATCH 1/2] Add method _filter_clusters --- ariba/summary.py | 30 ++++++++++++++++++++++++++++++ ariba/tests/summary_test.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/ariba/summary.py b/ariba/summary.py index 82fab952..881ac7b3 100644 --- a/ariba/summary.py +++ b/ariba/summary.py @@ -114,6 +114,36 @@ def _gather_output_rows(self): return rows + @classmethod + def _filter_clusters(cls, rows): + '''Removes any cluster where every sample has assembled == "no"''' + found_a_yes = set() + first_filename = True + all_clusters = set() + + for filename in rows: + for cluster in rows[filename]: + if first_filename: + all_clusters.add(cluster) + + if cluster in found_a_yes: + continue + + if rows[filename][cluster]['assembled'] == 'yes': + found_a_yes.add(cluster) + break + + first_filename = False + + to_delete = all_clusters.difference(found_a_yes) + + for filename in rows: + for cluster in to_delete: + del rows[filename][cluster] + + return rows + + @classmethod def _write_csv(cls, filenames, rows, outfile, phandango=False): lines = [] diff --git a/ariba/tests/summary_test.py b/ariba/tests/summary_test.py index 0b7a9716..fc13b01c 100644 --- a/ariba/tests/summary_test.py +++ b/ariba/tests/summary_test.py @@ -120,6 +120,40 @@ def test_gather_output_rows(self): self.assertEqual(expected, got) + def test_filter_clusters(self): + '''Test _filter_clusters''' + rows = { + 'file1': { + 'cluster1': {'assembled': 'yes'}, + 'cluster2': {'assembled': 'yes'}, + 'cluster3': {'assembled': 'no'}, + 'cluster4': {'assembled': 'no'}, + }, + 'file2': { + 'cluster1': {'assembled': 'yes'}, + 'cluster2': {'assembled': 'no'}, + 'cluster3': {'assembled': 'yes'}, + 'cluster4': {'assembled': 'no'}, + } + } + + expected = { + 'file1': { + 'cluster1': {'assembled': 'yes'}, + 'cluster2': {'assembled': 'yes'}, + 'cluster3': {'assembled': 'no'}, + }, + 'file2': { + 'cluster1': {'assembled': 'yes'}, + 'cluster2': {'assembled': 'no'}, + 'cluster3': {'assembled': 'yes'}, + } + } + + got = summary.Summary._filter_clusters(rows) + self.assertEqual(expected, got) + + def test_write_csv(self): '''Test _write_csv''' tmp_out = 'tmp.out.tsv' From 4f0caffcfee41d133b24f080461c2fb03e5d9c66 Mon Sep 17 00:00:00 2001 From: martinghunt Date: Thu, 28 Apr 2016 14:53:00 +0000 Subject: [PATCH 2/2] Report error if all columns get removed --- ariba/summary.py | 11 ++++++++--- ariba/tests/summary_test.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ariba/summary.py b/ariba/summary.py index 881ac7b3..3e354f68 100644 --- a/ariba/summary.py +++ b/ariba/summary.py @@ -116,7 +116,8 @@ def _gather_output_rows(self): @classmethod def _filter_clusters(cls, rows): - '''Removes any cluster where every sample has assembled == "no"''' + '''Removes any cluster where every sample has assembled == "no". + Returns tuple: (filtered rows, number of remaining columns)''' found_a_yes = set() first_filename = True all_clusters = set() @@ -131,7 +132,6 @@ def _filter_clusters(cls, rows): if rows[filename][cluster]['assembled'] == 'yes': found_a_yes.add(cluster) - break first_filename = False @@ -141,7 +141,7 @@ def _filter_clusters(cls, rows): for cluster in to_delete: del rows[filename][cluster] - return rows + return rows, len(found_a_yes) @classmethod @@ -253,6 +253,11 @@ def run(self): self._check_files_exist() self.samples = self._load_input_files(self.filenames, self.min_id) self.rows = self._gather_output_rows() + self.rows, remaining_clusters = Summary._filter_clusters(self.rows) + if remaining_clusters == 0: + print('No clusters found that are present in any sample. Will not write any output files', file=sys.stderr) + sys.exit(1) + Summary._write_csv(self.filenames, self.rows, self.outprefix + '.csv', phandango=False) if len(self.samples) > 1: diff --git a/ariba/tests/summary_test.py b/ariba/tests/summary_test.py index fc13b01c..8f5094e3 100644 --- a/ariba/tests/summary_test.py +++ b/ariba/tests/summary_test.py @@ -151,7 +151,7 @@ def test_filter_clusters(self): } got = summary.Summary._filter_clusters(rows) - self.assertEqual(expected, got) + self.assertEqual((expected, 3), got) def test_write_csv(self):