Merge pull request #36 from martinghunt/syn_variants_flag
Improvements to summary output; new flag for non-synonymous variants
bewt85 committed Oct 28, 2015
2 parents 876004f + a01c303 commit dbc9f44
Showing 9 changed files with 108 additions and 15 deletions.
22 changes: 21 additions & 1 deletion ariba/cluster.py
@@ -874,7 +874,10 @@ def _get_vcf_variant_counts(self):
self.status_flag.add('variants_suggest_collapsed_repeat')


def _make_report_lines(self):
def _initial_make_report_lines(self):
'''Makes report lines. While they are being made, we discover if there were
any non-synonymous variants. This affects the flag, which also gets updated
by the function. To then fix the report lines, run _update_flag_in_report_lines()'''
self.report_lines = []
total_reads = self._get_read_counts()

@@ -903,6 +906,9 @@ def _make_report_lines(self):
t = self._get_variant_effect(variants)
if t is not None:
effect, new_bases = t
if effect != 'SYN':
self.status_flag.add('has_nonsynonymous_variants')

for v in variants:
depths = self._get_assembly_read_depths(contig, v.qry_start)
if depths is None:
@@ -987,6 +993,20 @@ def _make_report_lines(self):

self.report_lines.sort(key=itemgetter(0, 14, 15))


def _update_flag_in_report_lines(self):
'''This corrects the flag in all the report lines made by _initial_make_report_lines()'''
flag_column = 1
if self.status_flag.has('has_nonsynonymous_variants'):
for line in self.report_lines:
line[flag_column] = self.status_flag.to_number()


def _make_report_lines(self):
self._initial_make_report_lines()
self._update_flag_in_report_lines()


def _clean(self):
if self.verbose:
print('Cleaning', self.root_dir)
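
The reason _make_report_lines() above is split into two passes: the status flag's integer value is written into every report row as the rows are generated, but the has_nonsynonymous_variants bit may only be set part-way through the loop, so rows produced before the first non-synonymous variant would carry a stale flag. The second pass rewrites the flag column once the final value is known. A minimal self-contained sketch of the pattern (illustrative names and data, not ARIBA's code; 1 << 9 is the new flag's bit value as implied by the updated tests):

def make_rows_then_patch_flag(variant_effects, flag_value, nonsyn_bit=1 << 9):
    """Toy two-pass report builder: build rows, then rewrite the flag column."""
    flag_column = 1
    rows = []
    for effect in variant_effects:        # e.g. ['SYN', 'NONSYN', 'SYN']
        if effect != 'SYN':
            flag_value |= nonsyn_bit      # only discovered mid-loop
        rows.append(['gene', flag_value, effect])
    for row in rows:                      # second pass: fix the earlier rows
        row[flag_column] = flag_value
    return rows

print(make_rows_then_patch_flag(['SYN', 'NONSYN', 'SYN'], 3))
# [['gene', 515, 'SYN'], ['gene', 515, 'NONSYN'], ['gene', 515, 'SYN']]
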
2 changes: 1 addition & 1 deletion ariba/common.py
@@ -1,7 +1,7 @@
import sys
import subprocess

version = '0.5.0'
version = '0.6.0'

def syscall(cmd, allow_fail=False, verbose=False):
if verbose:
1 change: 1 addition & 0 deletions ariba/flag.py
@@ -11,6 +11,7 @@ class Error (Exception): pass
'assembly_fail',
'variants_suggest_collapsed_repeat',
'hit_both_strands',
'has_nonsynonymous_variants',
]


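
For context on the flag change: ariba.flag stores the status as one bit per named flag, in list order, so appending has_nonsynonymous_variants as the tenth entry gives it bit 9 (value 512) and raises the all-flags-set value to 2**10 - 1 = 1023, which is why the expected list in flag_test.py below gains a final 1023. A simplified, self-contained sketch of the scheme (the placeholder names stand in for the real flags; this is not the library code):

class ToyFlag:
    """Simplified stand-in for ariba.flag.Flag: one bit per named flag."""

    def __init__(self, names):
        self.names = list(names)    # bit i belongs to names[i]
        self.value = 0

    def add(self, name):
        self.value |= 1 << self.names.index(name)

    def has(self, name):
        return bool(self.value & (1 << self.names.index(name)))

    def to_number(self):
        return self.value

names = ['flag_%d' % i for i in range(9)] + ['has_nonsynonymous_variants']
f = ToyFlag(names)
for name in names:
    f.add(name)
print(f.to_number())                                     # 1023 == 2**10 - 1
print(1 << names.index('has_nonsynonymous_variants'))    # 512, the new flag's bit
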
16 changes: 12 additions & 4 deletions ariba/summary.py
@@ -48,6 +48,7 @@ def __init__(
outfile,
filenames=None,
fofn=None,
filter_output=True,
min_id=90.0
):
if filenames is None and fofn is None:
@@ -61,6 +62,7 @@ def __init__(
if fofn is not None:
self.filenames.extend(self._load_fofn(fofn))

self.filter_output = filter_output
self.min_id = min_id
self.outfile = outfile

@@ -122,10 +124,13 @@ def _to_summary_number(self, l):
if f.has('hit_both_strands') or (not f.has('complete_orf')):
return 1

if f.has('unique_contig') and f.has('gene_assembled_into_one_contig'):
return 3

return 2
if f.has('unique_contig') and f.has('gene_assembled_into_one_contig') and f.has('complete_orf'):
if f.has('has_nonsynonymous_variants'):
return 3
else:
return 4
else:
return 2


def _pc_id_of_longest(self, l):
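
The practical effect on summary numbers: a gene that assembles into one unique contig with a complete ORF now scores 4 when all its variants are synonymous and 3 when any are non-synonymous (previously both cases scored 3). The updated pair in summary_test.py further down, (539, 3) and (27, 4), reflects exactly this: the two flag values differ only in the new has_nonsynonymous_variants bit. A quick check (the bit position is inferred from the flag being the tenth entry in flags_in_order):

# 539 and 27 encode the same flag combination apart from bit 9 (value 512),
# so 539 maps to summary number 3 and 27 maps to 4 under the new scheme.
assert 539 - 27 == 512 == 1 << 9
assert 539 == 27 | (1 << 9)
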
@@ -165,6 +170,9 @@ def _gather_output_rows(self):


def _filter_output_rows(self):
if not self.filter_output:
return

# remove rows that are all zeros
self.rows_out = [x for x in self.rows_out if x[1:] != [0]*(len(x)-1)]

2 changes: 2 additions & 0 deletions ariba/tasks/summary.py
@@ -8,6 +8,7 @@ def run():
epilog = 'Files must be listed after the output file and/or the option --fofn must be used. If both used, all files in the filename specified by --fofn AND the files listed after the output file will be used as input. The input report files must be in tsv format, not xls.')
parser.add_argument('-f', '--fofn', help='File of filenames of ariba reports in tsv format (not xls) to be summarised. Must be used if no input files listed after the outfile.', metavar='FILENAME')
parser.add_argument('--min_id', type=float, help='Minimum percent identity cutoff to count as assembled [%(default)s]', default=90, metavar='FLOAT')
parser.add_argument('--no_filter', action='store_true', help='Do not filter rows or columns of output that are all 0 (by default, they are removed from the output)')
parser.add_argument('outfile', help='Name of output file. If file ends with ".xls", then an excel spreadsheet is written. Otherwise a tsv file is written')
parser.add_argument('infiles', nargs='*', help='Files to be summarised')
options = parser.parse_args()
@@ -18,6 +19,7 @@ def run():
options.outfile,
fofn=options.fofn,
filenames=options.infiles,
filter_output=(not options.no_filter),
min_id=options.min_id
)
s.run()
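
With the new option threaded through from the command line, callers can keep summary rows and columns that are all zeros. A usage sketch of the Python API as changed here (the output and report file names are placeholders; on the command line the equivalent is passing --no_filter to the summary task):

from ariba import summary

# filter_output=False keeps all-zero rows and columns; the default (True)
# drops them, matching the previous behaviour.
s = summary.Summary(
    'summary_out.tsv',
    filenames=['sample1_report.tsv', 'sample2_report.tsv'],
    filter_output=False,
    min_id=90,
)
s.run()
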
45 changes: 44 additions & 1 deletion ariba/tests/cluster_test.py
@@ -730,7 +730,50 @@ def test_get_vcf_variant_counts(self):
clean_cluster_dir(cluster_dir)


def test_make_report_lines(self):
def test_make_report_lines_nonsynonymous(self):
'''test _make_report_lines'''
cluster_dir = os.path.join(data_dir, 'cluster_test_generic')
clean_cluster_dir(cluster_dir)
c = cluster.Cluster(cluster_dir, 'cluster_name')
c.gene = pyfastaq.sequences.Fasta('gene', 'GATCGCGAAGCGATGACCCATGAAGCGACCGAACGCTGA')
v1 = pymummer.variant.Variant(pymummer.snp.Snp('8\tA\tG\t8\tx\tx\t39\t39\tx\tx\tgene\tcontig'))

nucmer_hit = ['1', '10', '1', '10', '10', '10', '90.00', '1000', '1000', '1', '1', 'gene', 'contig']
c.nucmer_hits = {'contig': [pymummer.alignment.Alignment('\t'.join(nucmer_hit))]}
c.mummer_variants = {'contig': [[v1]]}
c.percent_identities = {'contig': 92.42}
c.status_flag.set_flag(42)
c.assembled_ok = True
c.final_assembly_read_depths = os.path.join(data_dir, 'cluster_test_make_report_lines.read_depths.gz')
c._make_report_lines()
expected = [[
'gene',
554,
2,
'cluster_name',
39,
10,
92.42,
'SNP',
'NONSYN',
'E3G',
8,
8,
'A',
'contig',
39,
8,
8,
'G',
'.',
'.',
'.'
]]
self.assertEqual(expected, c.report_lines)
clean_cluster_dir(cluster_dir)
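
Where the expected values in this new test come from: position 8 of the gene sequence is the middle base of codon 3 (GAA, glutamate), and the A-to-G SNP turns that codon into GGA (glycine) under the standard codon table, hence the 'NONSYN' effect and the 'E3G' annotation; the flag column is the preset flag 42 plus the new has_nonsynonymous_variants bit (2**9 = 512), i.e. 554. A quick check:

gene = 'GATCGCGAAGCGATGACCCATGAAGCGACCGAACGCTGA'
print(gene[6:9])                      # GAA, codon 3 before the SNP (glutamate, E)
mutated = gene[:7] + 'G' + gene[8:]   # apply the A->G change at position 8 (1-based)
print(mutated[6:9])                   # GGA, codon 3 after the SNP (glycine, G)
print(42 | (1 << 9))                  # 554, the expected flag column value
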


def test_make_report_lines_synonymous(self):
'''test _make_report_lines'''
cluster_dir = os.path.join(data_dir, 'cluster_test_generic')
clean_cluster_dir(cluster_dir)
3 changes: 2 additions & 1 deletion ariba/tests/flag_test.py
@@ -24,7 +24,7 @@ def test_set_flag(self):
def test_add(self):
'''Test add'''
f = flag.Flag()
expected = [1, 3, 7, 15, 31, 63, 127, 255, 511]
expected = [1, 3, 7, 15, 31, 63, 127, 255, 511, 1023]
for i in range(len(flag.flags_in_order)):
f.add(flag.flags_in_order[i])
self.assertEqual(f.to_number(), expected[i])
@@ -50,6 +50,7 @@ def test_to_long_str(self):
'[ ] assembly_fail',
'[ ] variants_suggest_collapsed_repeat',
'[ ] hit_both_strands',
'[ ] has_nonsynonymous_variants',
])

self.assertEqual(expected, f.to_long_string())
26 changes: 22 additions & 4 deletions ariba/tests/summary_test.py
@@ -1,4 +1,5 @@
import unittest
import copy
import filecmp
import os
from ariba import summary, flag
@@ -75,7 +76,8 @@ def test_to_summary_number(self):
(7, 1),
(259, 1),
(15, 2),
(27, 3),
(539, 3),
(27, 4),
]

for t in tests:
@@ -96,13 +98,13 @@ def test_gather_output_rows(self):
s._gather_output_rows()
expected = [
['filename', 'gene1', 'gene2', 'gene3'],
[infiles[0], 3, 2, 0],
[infiles[1], 3, 0, 3],
[infiles[0], 4, 2, 0],
[infiles[1], 4, 0, 4],
]
self.assertEqual(expected, s.rows_out)


def test_filter_output_rows(self):
def test_filter_output_rows_filter_true(self):
'''Test _filter_output_rows'''
s = summary.Summary('out', filenames=['spam', 'eggs'])
s.rows_out = [
@@ -122,6 +124,22 @@ def test_filter_output_rows(self):
self.assertEqual(s.rows_out, expected)


def test_filter_output_rows_filter_false(self):
'''Test _filter_output_rows'''
s = summary.Summary('out', filenames=['spam', 'eggs'], filter_output=False)
rows_out = [
['filename', 'gene1', 'gene2', 'gene3'],
['file1', 0, 0, 0],
['file2', 1, 0, 3],
['file3', 2, 0, 4],
]

s.rows_out = copy.copy(rows_out)

s._filter_output_rows()
self.assertEqual(s.rows_out, rows_out)


def test_write_tsv(self):
'''Test _write_tsv'''
tmp_out = 'tmp.out.tsv'
6 changes: 3 additions & 3 deletions setup.py
@@ -7,7 +7,7 @@

setup(
name='ariba',
version='0.5.0',
version='0.6.0',
description='ARIBA: Antibiotic Resistance Identification By Assembly',
packages = find_packages(),
author='Martin Hunt',
@@ -18,9 +18,9 @@
tests_require=['nose >= 1.3'],
install_requires=[
'openpyxl',
'pyfastaq >= 3.0.1',
'pyfastaq >= 3.10.0',
'pysam >= 0.8.1',
'pymummer>=0.0.2'
'pymummer>=0.6.1'
],
license='GPLv3',
classifiers=[
