class Error(Exception):
    """Raised when the input report cannot be processed."""


class ReportFlagExpander:
    """Copy a report file, replacing the numeric flag column with flag names.

    The first line of the input is treated as the header and must contain a
    field that is exactly ``flag``. On every later line the field at that
    position is parsed as an integer, wrapped in ``flag.Flag`` and replaced
    by ``Flag.to_comma_separated_string()``. Every line, including the
    header, is written to the output with its fields joined by tabs.
    """

    def __init__(self, infile, outfile):
        """Store the input and output filenames; no I/O happens here.

        Args:
            infile: name of the report file to read.
            outfile: name of the file to write the expanded report to.
        """
        self.infile = infile
        self.outfile = outfile

    def run(self):
        """Read ``self.infile`` and write the expanded report to ``self.outfile``.

        Raises:
            Error: if the first line has no ``flag`` field.
            ValueError: if a flag field on a data line is not an integer.
            IndexError: if a data line has fewer fields than the position of
                the ``flag`` column.
        """
        f_in = pyfastaq.utils.open_file_read(self.infile)
        try:
            f_out = pyfastaq.utils.open_file_write(self.outfile)
            try:
                flag_index = None

                for line in f_in:
                    # NOTE(review): split() breaks on any whitespace, not only
                    # tabs, so a field containing spaces (or an empty field)
                    # would shift the columns -- confirm against real reports
                    # before relying on this for free-text columns.
                    fields = line.rstrip().split()

                    if flag_index is None:
                        # Header line: find the flag column, leave it as-is.
                        try:
                            flag_index = fields.index('flag')
                        except ValueError:
                            # list.index raises only ValueError; catching
                            # just that keeps KeyboardInterrupt and friends
                            # from being turned into a misleading Error.
                            raise Error('"flag" column not found in first line of file ' + self.infile + '. Cannot continue')
                    else:
                        expanded = flag.Flag(int(fields[flag_index]))
                        fields[flag_index] = expanded.to_comma_separated_string()

                    print(*fields, sep='\t', file=f_out)
            finally:
                # Close the output even when a line fails to parse.
                f_out.close()
        finally:
            # Close the input even when opening or writing the output fails.
            f_in.close()
# Test data lives in tests/data next to the installed report_flag_expander module.
modules_dir = os.path.dirname(os.path.abspath(report_flag_expander.__file__))
data_dir = os.path.join(modules_dir, 'tests', 'data')


class TestReportFlagExpander(unittest.TestCase):
    def test_run(self):
        '''test run'''
        infile = os.path.join(data_dir, 'report_flag_expander.run.in.tsv')
        expected = os.path.join(data_dir, 'report_flag_expander.run.out.tsv')
        tmp_out = 'tmp.report_flag_expander.out.tsv'
        expander = report_flag_expander.ReportFlagExpander(infile, tmp_out)

        try:
            expander.run()
            self.assertTrue(filecmp.cmp(expected, tmp_out, shallow=False))
        finally:
            # Remove the temporary output even when run() raises or the
            # comparison fails, so a failing test leaves nothing behind in
            # the working directory.
            if os.path.exists(tmp_out):
                os.unlink(tmp_out)