diff --git a/ariba/cluster.py b/ariba/cluster.py index 91089710..38026939 100644 --- a/ariba/cluster.py +++ b/ariba/cluster.py @@ -1,8 +1,10 @@ import os +import random +import math import shutil import sys import pyfastaq -from ariba import assembly, assembly_compare, assembly_variants, bam_parse, best_seq_chooser, external_progs, flag, mapping, report, samtools_variants +from ariba import assembly, assembly_compare, assembly_variants, bam_parse, best_seq_chooser, common, external_progs, flag, mapping, report, samtools_variants class Error (Exception): pass @@ -13,6 +15,9 @@ def __init__(self, root_dir, name, refdata, + total_reads, + total_reads_bases, + assembly_coverage=100, assembly_kmer=21, assembler='spades', max_insert=1000, @@ -34,6 +39,7 @@ def __init__(self, spades_other_options=None, clean=1, extern_progs=None, + random_seed=42, ): self.root_dir = os.path.abspath(root_dir) @@ -42,6 +48,9 @@ def __init__(self, self.name = name self.refdata = refdata + self.total_reads = total_reads + self.total_reads_bases = total_reads_bases + self.assembly_coverage = assembly_coverage self.assembly_kmer = assembly_kmer self.assembler = assembler self.sspace_k = sspace_k @@ -49,12 +58,14 @@ def __init__(self, self.reads_insert = reads_insert self.spades_other_options = spades_other_options - self.reads1 = os.path.join(self.root_dir, 'reads_1.fq') - self.reads2 = os.path.join(self.root_dir, 'reads_2.fq') + self.all_reads1 = os.path.join(self.root_dir, 'reads_1.fq') + self.all_reads2 = os.path.join(self.root_dir, 'reads_2.fq') + self.reads_for_assembly1 = os.path.join(self.root_dir, 'reads_for_assembly_1.fq') + self.reads_for_assembly2 = os.path.join(self.root_dir, 'reads_for_assembly_2.fq') self.reference_fa = os.path.join(self.root_dir, 'reference.fa') self.references_fa = os.path.join(self.root_dir, 'references.fa') - for fname in [self.reads1, self.reads2, self.references_fa]: + for fname in [self.all_reads1, self.all_reads2, self.references_fa]: if not os.path.exists(fname): raise Error('File ' + fname + ' not found. Cannot continue') @@ -92,7 +103,6 @@ def __init__(self, self.mummer_variants = {} self.variant_depths = {} self.percent_identities = {} - self.total_reads = self._count_reads(self.reads1, self.reads2) # The log filehandle self.log_fh is set at the start of the run() method. # Lots of other methods use self.log_fh. But for unit testing, run() isn't @@ -109,13 +119,7 @@ def __init__(self, else: self.extern_progs = extern_progs - - @staticmethod - def _count_reads(reads1, reads2): - count1 = pyfastaq.tasks.count_sequences(reads1) - count2 = pyfastaq.tasks.count_sequences(reads2) - assert(count1 == count2) - return count1 + count2 + self.random_seed = random_seed def _clean(self): @@ -132,8 +136,8 @@ def _clean(self): shutil.rmtree(self.assembly_dir) to_delete = [ - self.reads1, - self.reads2, + self.all_reads1, + self.all_reads2, self.references_fa, self.references_fa + '.fai', self.final_assembly_bam + '.read_depths.gz', @@ -153,14 +157,64 @@ def _clean(self): raise Error('Error deleting file', filename) + @staticmethod + def _number_of_reads_for_assembly(reference_fa, insert_size, total_bases, total_reads, coverage): + file_reader = pyfastaq.sequences.file_reader(reference_fa) + ref_length = sum([len(x) for x in file_reader]) + assert ref_length > 0 + ref_length += 2 * insert_size + mean_read_length = total_bases / total_reads + wanted_bases = coverage * ref_length + wanted_reads = int(math.ceil(wanted_bases / mean_read_length)) + wanted_reads += wanted_reads % 2 + return wanted_reads + + + @staticmethod + def _make_reads_for_assembly(number_of_wanted_reads, total_reads, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=None): + '''Makes fastq files that are random subset of input files. Returns total number of reads in output files. + If the number of wanted reads is >= total reads, then just makes symlinks instead of making + new copies of the input files.''' + random.seed(random_seed) + + if number_of_wanted_reads < total_reads: + reads_written = 0 + percent_wanted = 100 * number_of_wanted_reads / total_reads + file_reader1 = pyfastaq.sequences.file_reader(reads_in1) + file_reader2 = pyfastaq.sequences.file_reader(reads_in2) + out1 = pyfastaq.utils.open_file_write(reads_out1) + out2 = pyfastaq.utils.open_file_write(reads_out2) + + for read1 in file_reader1: + try: + read2 = next(file_reader2) + except StopIteration: + pyfastaq.utils.close(out1) + pyfastaq.utils.close(out2) + raise Error('Error subsetting reads. No mate found for read ' + read1.id) + + if random.randint(0, 100) <= percent_wanted: + print(read1, file=out1) + print(read2, file=out2) + reads_written += 2 + + pyfastaq.utils.close(out1) + pyfastaq.utils.close(out2) + return reads_written + else: + os.symlink(reads_in1, reads_out1) + os.symlink(reads_in2, reads_out2) + return total_reads + + def run(self): self.logfile = os.path.join(self.root_dir, 'log.txt') self.log_fh = pyfastaq.utils.open_file_write(self.logfile) print('Choosing best reference sequence:', file=self.log_fh, flush=True) seq_chooser = best_seq_chooser.BestSeqChooser( - self.reads1, - self.reads2, + self.all_reads1, + self.all_reads2, self.references_fa, self.log_fh, samtools_exe=self.extern_progs.exe('samtools'), @@ -174,12 +228,15 @@ def run(self): self.status_flag.add('ref_seq_choose_fail') self.assembled_ok = False else: - print('\nAssembling reads:', file=self.log_fh, flush=True) + wanted_reads = self._number_of_reads_for_assembly(self.reference_fa, self.reads_insert, self.total_reads_bases, self.total_reads, self.assembly_coverage) + made_reads = self._make_reads_for_assembly(wanted_reads, self.total_reads, self.all_reads1, self.all_reads2, self.reads_for_assembly1, self.reads_for_assembly2, random_seed=self.random_seed) + print('\nUsing', made_reads, 'from a total of', self.total_reads, 'for assembly.', file=self.log_fh, flush=True) + print('Assembling reads:', file=self.log_fh, flush=True) self.ref_sequence_type = self.refdata.sequence_type(self.ref_sequence.id) assert self.ref_sequence_type is not None self.assembly = assembly.Assembly( - self.reads1, - self.reads2, + self.reads_for_assembly1, + self.reads_for_assembly2, self.reference_fa, self.assembly_dir, self.final_assembly_fa, @@ -202,8 +259,8 @@ def run(self): print('\nAssembly was successful\n\nMapping reads to assembly:', file=self.log_fh, flush=True) mapping.run_bowtie2( - self.reads1, - self.reads2, + self.all_reads1, + self.all_reads2, self.final_assembly_fa, self.final_assembly_bam[:-4], threads=1, diff --git a/ariba/clusters.py b/ariba/clusters.py index 28b718c7..d78849fb 100644 --- a/ariba/clusters.py +++ b/ariba/clusters.py @@ -29,6 +29,7 @@ def __init__(self, outdir, extern_progs, assembly_kmer=21, + assembly_coverage=100, threads=1, verbose=False, assembler='spades', @@ -57,6 +58,7 @@ def __init__(self, self.assembler = assembler assert self.assembler in ['spades'] self.assembly_kmer = assembly_kmer + self.assembly_coverage = assembly_coverage self.spades_other = spades_other self.refdata_files_prefix = os.path.join(self.outdir, 'refdata') @@ -91,6 +93,8 @@ def __init__(self, self.cluster_to_dir = {} # gene name -> abs path of cluster directory self.clusters = {} # gene name -> Cluster object + self.cluster_read_counts = {} # gene name -> number of reads + self.cluster_base_counts = {} # gene name -> number of bases self.cdhit_seq_identity_threshold = cdhit_seq_identity_threshold self.cdhit_length_diff_cutoff = cdhit_length_diff_cutoff @@ -210,6 +214,8 @@ def _bam_to_clusters_reads(self): print(read1, file=filehandles_1[ref]) print(read2, file=filehandles_2[ref]) + self.cluster_read_counts[ref] = self.cluster_read_counts.get(ref, 0) + 2 + self.cluster_base_counts[ref] = self.cluster_base_counts.get(ref, 0) + len(read1) + len(read2) sam1 = None @@ -257,7 +263,10 @@ def _init_and_run_clusters(self): cluster_list.append(cluster.Cluster( new_dir, seq_name, - refdata=self.refdata, + self.refdata, + self.cluster_read_counts[seq_name], + self.cluster_base_counts[seq_name], + assembly_coverage=self.assembly_coverage, assembly_kmer=self.assembly_kmer, assembler=self.assembler, max_insert=self.insert_proper_pair_max, diff --git a/ariba/tasks/run.py b/ariba/tasks/run.py index e558385b..9c488021 100644 --- a/ariba/tasks/run.py +++ b/ariba/tasks/run.py @@ -31,6 +31,7 @@ def run(): nucmer_group.add_argument('--nucmer_breaklen', type=int, help='Value to use for -breaklen when running nucmer [%(default)s]', default=50, metavar='INT') assembly_group = parser.add_argument_group('Assembly options') + assembly_group.add_argument('--assembly_cov', type=int, help='Target read coverage when sampling reads for assembly [%(default)s]', default=100, metavar='INT') assembly_group.add_argument('--assembler_k', type=int, help='kmer size to use with assembler. You can use 0 to set kmer to 2/3 of the read length. Warning - lower kmers are usually better. [%(default)s]', metavar='INT', default=21) assembly_group.add_argument('--spades_other', help='Put options string to be used with spades in quotes. This will NOT be sanity checked. Do not use -k or -t: for these options you should use the ariba run options --assembler_k and --threads [%(default)s]', default="--only-assembler", metavar="OPTIONS") assembly_group.add_argument('--min_scaff_depth', type=int, help='Minimum number of read pairs needed as evidence for scaffold link between two contigs. This is also the value used for sspace -k when scaffolding [%(default)s]', default=10, metavar='INT') @@ -72,6 +73,7 @@ def run(): options.outdir, extern_progs, assembly_kmer=options.assembler_k, + assembly_coverage=options.assembly_cov, assembler='spades', threads=options.threads, verbose=options.verbose, diff --git a/ariba/tests/cluster_test.py b/ariba/tests/cluster_test.py index 2bb9f140..3585e5e6 100644 --- a/ariba/tests/cluster_test.py +++ b/ariba/tests/cluster_test.py @@ -48,18 +48,60 @@ def test_init_fail_files_missing(self): tmpdir = 'tmp.cluster_test_init_fail_files_missing' shutil.copytree(d, tmpdir) with self.assertRaises(cluster.Error): - c = cluster.Cluster(tmpdir, 'name', refdata=refdata) + c = cluster.Cluster(tmpdir, 'name', refdata=refdata, total_reads=42, total_reads_bases=4242) shutil.rmtree(tmpdir) with self.assertRaises(cluster.Error): - c = cluster.Cluster('directorydoesnotexistshouldthrowerror', 'name', refdata=refdata) - + c = cluster.Cluster('directorydoesnotexistshouldthrowerror', 'name', refdata=refdata, total_reads=42, total_reads_bases=4242) + + + def test_number_of_reads_for_assembly(self): + '''Test _number_of_reads_for_assembly''' + # ref is 100bp long + ref_fa = os.path.join(data_dir, 'cluster_test_number_of_reads_for_assembly.ref.fa') + tests = [ + (50, 1000, 10, 20, 40), + (50, 999, 10, 20, 42), + (50, 1000, 10, 10, 20), + (50, 1000, 10, 5, 10), + ] - def test_count_reads(self): - '''test _count_reads pass''' - reads1 = os.path.join(data_dir, 'cluster_test_count_reads_1.fq') - reads2 = os.path.join(data_dir, 'cluster_test_count_reads_2.fq') - self.assertEqual(4, cluster.Cluster._count_reads(reads1, reads2)) + for insert, bases, reads, coverage, expected in tests: + self.assertEqual(expected, cluster.Cluster._number_of_reads_for_assembly(ref_fa, insert, bases, reads, coverage)) + + + def test_make_reads_for_assembly_proper_sample(self): + '''Test _make_reads_for_assembly when sampling from reads''' + reads_in1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in1.fq') + reads_in2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in2.fq') + expected_out1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out1.fq') + expected_out2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out2.fq') + reads_out1 = 'tmp.test_make_reads_for_assembly.reads.out1.fq' + reads_out2 = 'tmp.test_make_reads_for_assembly.reads.out2.fq' + reads_written = cluster.Cluster._make_reads_for_assembly(10, 20, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=42) + self.assertEqual(14, reads_written) + self.assertTrue(filecmp.cmp(expected_out1, reads_out1, shallow=False)) + self.assertTrue(filecmp.cmp(expected_out2, reads_out2, shallow=False)) + os.unlink(reads_out1) + os.unlink(reads_out2) + + + def test_make_reads_for_assembly_symlinks(self): + '''Test _make_reads_for_assembly when just makes symlinks''' + reads_in1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in1.fq') + reads_in2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in2.fq') + expected_out1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out1.fq') + expected_out2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out2.fq') + reads_out1 = 'tmp.test_make_reads_for_assembly.reads.out1.fq' + reads_out2 = 'tmp.test_make_reads_for_assembly.reads.out2.fq' + reads_written = cluster.Cluster._make_reads_for_assembly(20, 20, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=42) + self.assertEqual(20, reads_written) + self.assertTrue(os.path.islink(reads_out1)) + self.assertTrue(os.path.islink(reads_out2)) + self.assertEqual(os.readlink(reads_out1), reads_in1) + self.assertEqual(os.readlink(reads_out2), reads_in2) + os.unlink(reads_out1) + os.unlink(reads_out2) def test_full_run_choose_ref_fail(self): @@ -70,7 +112,7 @@ def test_full_run_choose_ref_fail(self): tmpdir = 'tmp.test_full_run_choose_ref_fail' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_choose_ref_fail'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata) + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, total_reads=2, total_reads_bases=108) c.run() expected = '\t'.join(['.', '.', '1088', '2', 'cluster_name'] + ['.'] * 23) @@ -88,7 +130,7 @@ def test_full_run_assembly_fail(self): tmpdir = 'tmp.test_full_run_assembly_fail' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_assembly_fail'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata) + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, total_reads=4, total_reads_bases=304) c.run() expected = '\t'.join(['noncoding_ref_seq', 'non_coding', '64', '4', 'cluster_name'] + ['.'] * 23) @@ -108,7 +150,7 @@ def test_full_run_ok_non_coding(self): tmpdir = 'tmp.test_full_run_ok_non_coding' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_non_coding'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler') + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=72, total_reads_bases=3600) c.run() expected = [ @@ -134,7 +176,7 @@ def test_full_run_ok_presence_absence(self): tmpdir = 'tmp.cluster_test_full_run_ok_presence_absence' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_presence_absence'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler') + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=64, total_reads_bases=3200) c.run() expected = [ @@ -160,7 +202,7 @@ def test_full_run_ok_variants_only_variant_not_present(self): tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.not_present' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler') + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300) c.run() expected = [ 'variants_only1\tvariants_only\t27\t66\tcluster_name\t96\t96\t100.0\tvariants_only1.scaffold.1\t215\t1\tSNP\tp\tR3S\t0\t.\t.\t7\t9\tC;G;C\t65\t67\tC;G;C\t18;18;19\t.;.;.\t18;18;19\tvariants_only1_p_R3S_Ref and assembly have wild type, so do not report\tGeneric description of variants_only1' @@ -179,7 +221,7 @@ def test_full_run_ok_variants_only_variant_not_present_always_report(self): tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.not_present.always_report' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler') + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300) c.run() expected = [ 'variants_only1\tvariants_only\t27\t66\tcluster_name\t96\t96\t100.0\tvariants_only1.scaffold.1\t215\t1\tSNP\tp\tR3S\t0\t.\t.\t7\t9\tC;G;C\t65\t67\tC;G;C\t18;18;19\t.;.;.\t18;18;19\tvariants_only1_p_R3S_Ref and assembly have wild type, but always report anyway\tGeneric description of variants_only1' @@ -198,7 +240,7 @@ def test_full_run_ok_variants_only_variant_is_present(self): tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.present' shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir) - c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler') + c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300) c.run() expected = [ diff --git a/ariba/tests/clusters_test.py b/ariba/tests/clusters_test.py index 3d20c310..528ad618 100644 --- a/ariba/tests/clusters_test.py +++ b/ariba/tests/clusters_test.py @@ -93,6 +93,8 @@ def test_bam_to_clusters_reads(self): self.assertTrue(filecmp.cmp(expected[i], got[i], shallow=False)) self.assertEqual({780:1}, c.insert_hist.bins) + self.assertEqual({'ref1': 4, 'ref2': 2}, c.cluster_read_counts) + self.assertEqual({'ref1': 240, 'ref2': 120}, c.cluster_base_counts) shutil.rmtree(clusters_dir) diff --git a/ariba/tests/data/cluster_test_make_reads_for_assembly.in1.fq b/ariba/tests/data/cluster_test_make_reads_for_assembly.in1.fq new file mode 100644 index 00000000..976e11d7 --- /dev/null +++ b/ariba/tests/data/cluster_test_make_reads_for_assembly.in1.fq @@ -0,0 +1,40 @@ +@read1/1 +ACGT ++ +ABCD +@read2/1 +ACGT ++ +ABCD +@read3/1 +ACGT ++ +ABCD +@read4/1 +ACGT ++ +ABCD +@read5/1 +ACGT ++ +ABCD +@read6/1 +ACGT ++ +ABCD +@read7/1 +ACGT ++ +ABCD +@read8/1 +ACGT ++ +ABCD +@read9/1 +ACGT ++ +ABCD +@read10/1 +ACGT ++ +ABCD diff --git a/ariba/tests/data/cluster_test_make_reads_for_assembly.in2.fq b/ariba/tests/data/cluster_test_make_reads_for_assembly.in2.fq new file mode 100644 index 00000000..a7f82df9 --- /dev/null +++ b/ariba/tests/data/cluster_test_make_reads_for_assembly.in2.fq @@ -0,0 +1,40 @@ +@read1/2 +ACGTA ++ +DEFGH +@read2/2 +ACGTA ++ +DEFGH +@read3/2 +ACGTA ++ +DEFGH +@read4/2 +ACGTA ++ +DEFGH +@read5/2 +ACGTA ++ +DEFGH +@read6/2 +ACGTA ++ +DEFGH +@read7/2 +ACGTA ++ +DEFGH +@read8/2 +ACGTA ++ +DEFGH +@read9/2 +ACGTA ++ +DEFGH +@read10/2 +ACGTA ++ +DEFGH diff --git a/ariba/tests/data/cluster_test_make_reads_for_assembly.out1.fq b/ariba/tests/data/cluster_test_make_reads_for_assembly.out1.fq new file mode 100644 index 00000000..0155caf8 --- /dev/null +++ b/ariba/tests/data/cluster_test_make_reads_for_assembly.out1.fq @@ -0,0 +1,28 @@ +@read2/1 +ACGT ++ +ABCD +@read3/1 +ACGT ++ +ABCD +@read5/1 +ACGT ++ +ABCD +@read6/1 +ACGT ++ +ABCD +@read7/1 +ACGT ++ +ABCD +@read8/1 +ACGT ++ +ABCD +@read10/1 +ACGT ++ +ABCD diff --git a/ariba/tests/data/cluster_test_make_reads_for_assembly.out2.fq b/ariba/tests/data/cluster_test_make_reads_for_assembly.out2.fq new file mode 100644 index 00000000..dad30a6b --- /dev/null +++ b/ariba/tests/data/cluster_test_make_reads_for_assembly.out2.fq @@ -0,0 +1,28 @@ +@read2/2 +ACGTA ++ +DEFGH +@read3/2 +ACGTA ++ +DEFGH +@read5/2 +ACGTA ++ +DEFGH +@read6/2 +ACGTA ++ +DEFGH +@read7/2 +ACGTA ++ +DEFGH +@read8/2 +ACGTA ++ +DEFGH +@read10/2 +ACGTA ++ +DEFGH diff --git a/ariba/tests/data/cluster_test_number_of_reads_for_assembly.ref.fa b/ariba/tests/data/cluster_test_number_of_reads_for_assembly.ref.fa new file mode 100644 index 00000000..15481846 --- /dev/null +++ b/ariba/tests/data/cluster_test_number_of_reads_for_assembly.ref.fa @@ -0,0 +1,3 @@ +>ref +TTTCTCGGTACCTCATCACGAGCCTCGTCCATACGCGTACCTTTAGAGGTTATGGACGTA +TGGCTAGTACGTTGATGACAAAGTTGATGTCGGAGCCTAT