Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cdhit #10

Merged
merged 18 commits into from
Feb 20, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Installation
------------

ARIBA has the following dependencies, which need to be installed:
* [cd-hit] [cdhit] version >= 4.6
* [samtools and bcftools] [samtools] version >= 1.2
* [SSPACE-basic scaffolder] [sspace]
* [GapFiller] [gapfiller]
Expand Down Expand Up @@ -39,6 +40,7 @@ Usage
Please read the [ARIBA wiki page] [ARIBA wiki] for usage instructions.


[cdhit]: http://weizhongli-lab.org/cd-hit/
[ARIBA wiki]: https://github.com/sanger-pathogens/ariba/wiki
[gapfiller]: http://www.baseclear.com/genomics/bioinformatics/basetools/gapfiller
[mummer]: http://mummer.sourceforge.net/
Expand Down
2 changes: 2 additions & 0 deletions ariba/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
__all__ = [
'bam_parse',
'cdhit',
'cluster',
'clusters',
'common',
'external_progs',
'faidx',
'flag',
'histogram',
'link',
Expand Down
121 changes: 121 additions & 0 deletions ariba/cdhit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import tempfile
import shutil
import os
import pyfastaq
from ariba import common

class Error (Exception): pass



class Runner:
def __init__(
self,
infile,
outfile,
seq_identity_threshold=0.9,
threads=1,
length_diff_cutoff=0.9,
verbose=False,
):

if not os.path.exists(infile):
raise Error('File not found: "' + infile + '". Cannot continue')

self.infile = os.path.abspath(infile)
self.outfile = os.path.abspath(outfile)
self.seq_identity_threshold = seq_identity_threshold
self.threads = threads
self.length_diff_cutoff = length_diff_cutoff
self.verbose = verbose


def run(self):
tmpdir = tempfile.mkdtemp(prefix='tmp.run_cd-hit.', dir=os.getcwd())
cdhit_fasta = os.path.join(tmpdir, 'cdhit')
cluster_info_outfile = cdhit_fasta + '.bak.clstr'
infile_renamed = os.path.join(tmpdir, 'input.renamed.fa')

# cd-hit truncates all names to 19 bases in its report of which
# sequences belong to which clusters. So need to temporarily
# rename all sequences to have short enough names. Grrr.
new_to_old_name = self._enumerate_fasta(self.infile, infile_renamed)

cmd = ' '.join([
'cd-hit',
'-i', infile_renamed,
'-o', cdhit_fasta,
'-c', str(self.seq_identity_threshold),
'-T', str(self.threads),
'-s', str(self.length_diff_cutoff),
'-bak 1',
])

common.syscall(cmd, verbose=self.verbose)

cluster_representatives = self._get_ids(cdhit_fasta)
clusters, cluster_rep_to_cluster = self._parse_cluster_info_file(cluster_info_outfile, new_to_old_name, cluster_representatives)
self._rename_fasta(cdhit_fasta, self.outfile, cluster_rep_to_cluster)
shutil.rmtree(tmpdir)
return clusters


def _enumerate_fasta(self, infile, outfile):
rename_file = outfile + '.tmp.rename_info'
assert not os.path.exists(rename_file)
pyfastaq.tasks.enumerate_names(infile, outfile, rename_file=rename_file)

with open(rename_file) as f:
lines = [x.rstrip().split('\t') for x in f.readlines() if x != '#old\tnew\n']
new_to_old_name = {x[1]: x[0] for x in lines}
if len(lines) != len(new_to_old_name):
raise Error('Sequence names in input file not unique! Cannot continue')

os.unlink(rename_file)
return new_to_old_name


def _rename_fasta(self, infile, outfile, names_dict):
seq_reader = pyfastaq.sequences.file_reader(infile)
f = pyfastaq.utils.open_file_write(outfile)
for seq in seq_reader:
seq.id = names_dict[seq.id]
print(seq, file=f)

pyfastaq.utils.close(f)


def _parse_cluster_info_file(self, infile, names_dict, cluster_representatives):
f = pyfastaq.utils.open_file_read(infile)
clusters = {}
cluster_representative_to_cluster_number = {}
for line in f:
data = line.rstrip().split()
cluster = data[0]
seqname = data[2]
if not (seqname.startswith('>') and seqname.endswith('...')):
raise Error('Unexpected format of sequence name in line:\n' + line)
seqname = seqname[1:-3]

if seqname in cluster_representatives:
cluster_representative_to_cluster_number[seqname] = cluster

seqname = names_dict[seqname]

if cluster not in clusters:
clusters[cluster] = set()

if seqname in clusters[cluster]:
raise Error('Duplicate name "' + seqname + '" found in cluster ' + str(cluster))

clusters[cluster].add(seqname)

pyfastaq.utils.close(f)

return clusters, cluster_representative_to_cluster_number


def _get_ids(self, infile):
seq_reader = pyfastaq.sequences.file_reader(infile)
return set([seq.id for seq in seq_reader])

119 changes: 114 additions & 5 deletions ariba/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import operator
import pyfastaq
import pymummer
from ariba import common, mapping, bam_parse, flag
from ariba import common, mapping, bam_parse, flag, faidx

class Error (Exception): pass


class Cluster:
def __init__(self,
root_dir,
name,
assembly_kmer=0,
assembler='velvet',
max_insert=1000,
Expand All @@ -39,22 +40,24 @@ def __init__(self,
sspace_exe='SSPACE_Basic_v2.0.pl',
velvet_exe='velvet', # prefix of velvet{g,h}
spades_other=None,
clean=1,
):

self.root_dir = os.path.abspath(root_dir)
if not os.path.exists(self.root_dir):
raise Error('Directory ' + self.root_dir + ' not found. Cannot continue')

self.name = name
self.reads1 = os.path.join(self.root_dir, 'reads_1.fq')
self.reads2 = os.path.join(self.root_dir, 'reads_2.fq')
self.gene_fa = os.path.join(self.root_dir, 'gene.fa')
self.genes_fa = os.path.join(self.root_dir, 'genes.fa')
self.gene_bam = os.path.join(self.root_dir, 'gene.reads_mapped.bam')

for fname in [self.reads1, self.reads2, self.gene_fa]:
for fname in [self.reads1, self.reads2, self.genes_fa]:
if not os.path.exists(fname):
raise Error('File ' + fname + ' not found. Cannot continue')

self.gene = self._get_gene()

self.max_insert = max_insert
self.min_scaff_depth = min_scaff_depth
Expand Down Expand Up @@ -104,6 +107,7 @@ def __init__(self,
self.unique_threshold = unique_threshold
self.status_flag = flag.Flag()
self.flag_file = os.path.join(self.root_dir, 'flag')
self.clean = clean

self.assembly_dir = os.path.join(self.root_dir, 'Assembly')
try:
Expand All @@ -123,7 +127,64 @@ def __init__(self,
self.variants = {}


def _get_gene(self):
def _get_total_alignment_score(self, gene_name):
tmp_bam = os.path.join(self.root_dir, 'tmp.get_total_alignment_score.bam')
assert not os.path.exists(tmp_bam)
tmp_fa = os.path.join(self.root_dir, 'tmp.get_total_alignment_score.ref.fa')
assert not os.path.exists(tmp_fa)
faidx.write_fa_subset([gene_name], self.genes_fa, tmp_fa, samtools_exe=self.samtools_exe, verbose=self.verbose)
mapping.run_smalt(
self.reads1,
self.reads2,
tmp_fa,
tmp_bam[:-4],
threads=self.threads,
samtools=self.samtools_exe,
smalt=self.smalt_exe,
verbose=self.verbose,
)

score = mapping.get_total_alignment_score(tmp_bam)
os.unlink(tmp_bam)
os.unlink(tmp_fa)
os.unlink(tmp_fa + '.fai')
return score


def _get_best_gene_by_alignment_score(self):
cluster_size = pyfastaq.tasks.count_sequences(self.genes_fa)
if cluster_size == 1:
seqs = {}
pyfastaq.tasks.file_to_dict(self.genes_fa, seqs)
assert len(seqs) == 1
gene_name = list(seqs.values())[0].id
if self.verbose:
print('No need to choose gene for this cluster because only has one gene:', gene_name)
return gene_name

if self.verbose:
print('\nChoosing best gene from cluster of', cluster_size, 'genes...')
file_reader = pyfastaq.sequences.file_reader(self.genes_fa)
best_score = 0
best_gene_name = None
for seq in file_reader:
score = self._get_total_alignment_score(seq.id)
if self.verbose:
print('Total alignment score for gene', seq.id, 'is', score)
if score > best_score:
best_score = score
best_gene_name = seq.id

if self.verbose:
print('Best gene is', best_gene_name, 'with total alignment score of', best_score)
print()

return best_gene_name


def _choose_best_gene(self):
gene_name = self._get_best_gene_by_alignment_score()
faidx.write_fa_subset([gene_name], self.genes_fa, self.gene_fa, samtools_exe=self.samtools_exe, verbose=self.verbose)
seqs = {}
pyfastaq.tasks.file_to_dict(self.gene_fa, seqs)
assert len(seqs) == 1
Expand Down Expand Up @@ -342,6 +403,7 @@ def _fix_contig_orientation(self):
else:
to_revcomp.add(hit.qry_name)

os.unlink(tmp_coords)
in_both = to_revcomp.intersection(not_revcomp)
for name in in_both:
print('WARNING: hits to both strands of gene for scaffold. Interpretation of any variants cannot be trusted', name, file=sys.stderr)
Expand Down Expand Up @@ -649,7 +711,7 @@ def _make_report_lines(self):
self.report_lines = []

if len(self.variants) == 0:
self.report_lines.append([self.gene.id, self.status_flag.to_number(), len(self.gene)] + ['.'] * 11)
self.report_lines.append([self.gene.id, self.status_flag.to_number(), self.name, len(self.gene)] + ['.'] * 11)

for contig in self.variants:
for variants in self.variants[contig]:
Expand All @@ -660,6 +722,7 @@ def _make_report_lines(self):
self.report_lines.append([
self.gene.id,
self.status_flag.to_number(),
self.name,
len(self.gene),
pymummer.variant.var_types[v.var_type],
effect,
Expand All @@ -675,7 +738,52 @@ def _make_report_lines(self):
])


def _clean(self):
if self.verbose:
print('Cleaning', self.root_dir)

if self.clean > 0:
if self.verbose:
print(' rm -r', self.assembly_dir)
shutil.rmtree(self.assembly_dir)

to_clean = [
[
'assembly.reads_mapped.unsorted.bam',
],
[
'assembly.fa.fai',
'assembly.reads_mapped.bam.scaff',
'assembly.reads_mapped.bam.soft_clipped',
'assembly.reads_mapped.bam.unmapped_mates',
'assembly_vs_gene.coords',
'assembly_vs_gene.coords.snps',
'genes.fa',
'genes.fa.fai',
'reads_1.fq',
'reads_2.fq',
],
[
'assembly.fa.fai',
'assembly.reads_mapped.bam',
'assembly.reads_mapped.bam.vcf',
'assembly_vs_gene.coords',
'assembly_vs_gene.coords.snps',
]
]

for i in range(self.clean + 1):
for fname in to_clean[i]:
fullname = os.path.join(self.root_dir, fname)
if os.path.exists(fullname):
if self.verbose:
print(' rm', fname)
os.unlink(fullname)


def run(self):
self.gene = self._choose_best_gene()

if self.assembler == 'velvet':
self._assemble_with_velvet()
elif self.assembler == 'spades':
Expand Down Expand Up @@ -720,3 +828,4 @@ def run(self):
self._get_vcf_variant_counts()

self._make_report_lines()
self._clean()
Loading