Skip to content

Commit

Permalink
Merge pull request #242 from martinghunt/tb_amr
Browse files Browse the repository at this point in the history
Tb amr
  • Loading branch information
martinghunt authored Nov 16, 2018
2 parents cf8e126 + 32d1969 commit dc378cf
Show file tree
Hide file tree
Showing 38 changed files with 315,836 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ addons:
- libgfortran3
- libncurses5-dev
python:
- '3.4'
- '3.5'
sudo: false
install:
- source ./install_dependencies.sh
Expand All @@ -17,4 +17,4 @@ before_script:
script:
- coverage run setup.py test
after_success:
- codecov
- codecov
1 change: 1 addition & 0 deletions ariba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
'summary_cluster_variant',
'summary_sample',
'tasks',
'tb',
'versions',
'vfdb_parser',
]
Expand Down
16 changes: 15 additions & 1 deletion ariba/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import time
import os
import copy
import json
import tempfile
import pickle
import itertools
import sys
import multiprocessing
import pyfastaq
import minimap_ariba
from ariba import cluster, common, histogram, mlst_reporter, read_store, report, report_filter, reference_data
from ariba import cluster, common, histogram, mlst_reporter, read_store, report, report_filter, reference_data, tb

class Error (Exception): pass

Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(self,
self.report_file_filtered = os.path.join(self.outdir, 'report.tsv')
self.mlst_reports_prefix = os.path.join(self.outdir, 'mlst_report')
self.mlst_profile_file = os.path.join(self.refdata_dir, 'pubmlst.profile.txt')
self.tb_resistance_calls_file = os.path.join(self.outdir, 'tb.resistance.json')
self.catted_assembled_seqs_fasta = os.path.join(self.outdir, 'assembled_seqs.fa.gz')
self.catted_genes_matching_refs_fasta = os.path.join(self.outdir, 'assembled_genes.fa.gz')
self.catted_assemblies_fasta = os.path.join(self.outdir, 'assemblies.fa.gz')
Expand Down Expand Up @@ -226,12 +228,14 @@ def _load_reference_data_from_dir(indir):
fasta_file = os.path.join(indir, '02.cdhit.all.fa')
metadata_file = os.path.join(indir, '01.filter.check_metadata.tsv')
info_file = os.path.join(indir, '00.info.txt')
parameters_file = os.path.join(indir, '00.params.json')
clusters_pickle_file = os.path.join(indir, '02.cdhit.clusters.pickle')
params = Clusters._load_reference_data_info_file(info_file)
refdata = reference_data.ReferenceData(
[fasta_file],
[metadata_file],
genetic_code=params['genetic_code'],
parameters_file=parameters_file,
)

with open(clusters_pickle_file, 'rb') as f:
Expand Down Expand Up @@ -587,6 +591,13 @@ def _write_mlst_reports(cls, mlst_profile_file, ariba_report_tsv, outprefix, ver
reporter.run()


@classmethod
def _write_tb_resistance_calls_json(cls, ariba_report_tsv, outfile):
    '''Extract TB drug resistance calls from the final ARIBA report TSV
    and dump them to outfile as pretty-printed, key-sorted JSON.'''
    resistance_calls = tb.report_to_resistance_dict(ariba_report_tsv)
    with open(outfile, 'w') as json_out:
        json.dump(resistance_calls, json_out, indent=4, sort_keys=True)


def write_versions_file(self, original_dir):
with open('version_info.txt', 'w') as f:
print('ARIBA run with this command:', file=f)
Expand Down Expand Up @@ -667,6 +678,9 @@ def _run(self):

Clusters._write_mlst_reports(self.mlst_profile_file, self.report_file_filtered, self.mlst_reports_prefix, verbose=self.verbose)

if 'tb' in self.refdata.extra_parameters and self.refdata.extra_parameters['tb']:
Clusters._write_tb_resistance_calls_json(self.report_file_filtered, self.tb_resistance_calls_file)

if self.clusters_all_ran_ok and self.verbose:
print('\nAll done!\n')
finally:
Expand Down
8 changes: 8 additions & 0 deletions ariba/reference_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import sys
import re
Expand All @@ -19,6 +20,7 @@ def __init__(self,
min_gene_length=6,
max_gene_length=10000,
genetic_code=11,
parameters_file=None,
):
self.seq_filenames = {}
self.seq_dicts = {}
Expand All @@ -38,6 +40,12 @@ def __init__(self,
else:
self.ariba_to_original_name = ReferenceData._load_rename_file(rename_file)

if parameters_file is None or not os.path.exists(parameters_file):
self.extra_parameters = {}
else:
with open(parameters_file) as f:
self.extra_parameters = json.load(f)


@classmethod
def _load_rename_file(cls, filename):
Expand Down
1 change: 1 addition & 0 deletions ariba/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
'getref',
'micplot',
'prepareref',
'prepareref_tb',
'pubmlstget',
'pubmlstspecies',
'refquery',
Expand Down
7 changes: 7 additions & 0 deletions ariba/tasks/prepareref_tb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import sys
import argparse
from ariba import tb

def run(options):
    '''Task entry point for "ariba prepareref_tb": builds a complete
    TB prepareref directory at options.outdir.'''
    output_directory = options.outdir
    tb.make_prepareref_dir(output_directory)

249 changes: 249 additions & 0 deletions ariba/tb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import csv
import json
import os
import re
import sys
import tempfile

from Bio import SeqIO
import pyfastaq

from ariba import common, flag, ref_preparer

data_dir = os.path.join(os.path.dirname(__file__), 'tb_data')
assert os.path.exists(data_dir)


def report_to_resistance_dict(infile):
    '''Takes final ariba report.tsv file, and extracts resistance calls,
    returning a dict of drug name -> list of mutations.
    Each "mutation" in the list is a tuple of (gene name, mutation).
    Mutation is of the form X42Y, or "Incomplete_gene" for katG and
    pncA when they are not complete.
    This all assumes that the reference data are in the "correct"
    form, where the variant descriptions in the var_description column of the
    TSV file end with a comma-separated list of the drug names.'''
    # Genes where an incomplete assembly is itself a resistance call
    complete_genes = {'katG': 'Isoniazid', 'pncA': 'Pyrazinamide'}
    res_calls = {}
    incomplete_genes = set()

    with open(infile) as tsv_fh:
        reader = csv.DictReader(tsv_fh, delimiter='\t')
        for d in reader:
            if d['ref_name'] in complete_genes and d['gene'] == '1':
                # BUG FIX: this previously rebound 'f', shadowing the open
                # file handle (it only worked because DictReader keeps its
                # own reference). Use a distinct name for the Flag object.
                cluster_flag = flag.Flag(int(d['flag']))
                if not cluster_flag.has('complete_gene'):
                    incomplete_genes.add(d['ref_name'])

            if d['has_known_var'] == '1':
                if 'Original mutation' in d['var_description']:
                    # Upstream variants were renumbered during prepareref;
                    # recover the drug list and the original coordinates
                    # from the free-text description.
                    drugs = d['var_description'].split(':')[-1].split('.')[0].split()[-1].split(',')
                    change = d['var_description'].split()[-1]
                else:
                    drugs = d['var_description'].split()[-1].split(',')
                    change = d['known_var_change']
                for drug in drugs:
                    res_calls.setdefault(drug, []).append((d['ref_name'], change))

    for gene in incomplete_genes:
        drug = complete_genes[gene]
        res_calls.setdefault(drug, []).append((gene, 'Incomplete_gene'))

    return res_calls


def genbank_to_gene_coords(infile, genes):
    '''Input file in genbank format. genes = list of gene names to find.
    Returns dict of gene name -> {start: x, end: y}, where x and y are
    zero-based. x<y iff gene is on forwards strand.'''
    found = {}

    for record in SeqIO.parse(infile, "genbank"):
        gene_features = (feat for feat in record.features if feat.type == 'gene')
        for feat in gene_features:
            name = feat.qualifiers.get('gene', [None])[0]
            if name not in genes:
                continue

            left = int(feat.location.start)
            right = int(feat.location.end) - 1
            if feat.location.strand == 1:
                found[name] = {'start': left, 'end': right}
            else:
                # Reverse strand: report coords swapped so start > end
                found[name] = {'start': right, 'end': left}

    return found


def load_mutations(gene_coords, mutation_to_drug_json, variants_txt, upstream_before=100):
    '''Load mutations from "mykrobe-style" files.

    mutation_to_drug_json is a json file of mutation -> list of drugs.
    variants_txt is a tab-delimited text file of variants used by mykrobe's
    make probes (columns: gene, variant, DNA-or-PROT).
    gene_coords should be a dict of gene coords made by the function
    genbank_to_gene_coords.

    Returns a tuple of:
      - list of mutation dicts (keys: gene, var, coding, upstream, and
        optionally drugs, original_mutation)
      - set of genes that have indel variants (these variants are skipped)
      - set of genes needing an upstream sequence in the output fasta
      - set of genes with at least one non-upstream variant.'''
    with open(mutation_to_drug_json) as f:
        drug_data = json.load(f)

    mutations = []
    genes_with_indels = set()
    genes_need_upstream = set()
    genes_non_upstream = set()

    with open(variants_txt) as f:
        for line in f:
            gene, variant, d_or_p = line.rstrip().split('\t')
            # rrs is ribosomal RNA, so it is treated as non-coding
            coding = 0 if gene == 'rrs' else 1
            d = {'gene': gene, 'var': variant, 'coding': coding, 'upstream': False}
            drug_data_key = d['gene'] + '_' + d['var']
            if drug_data_key not in drug_data:
                print('KEY', drug_data_key, 'NOT FOUND', file=sys.stderr)
            else:
                d['drugs'] = ','.join(sorted(drug_data[drug_data_key]))

            if d_or_p == 'DNA' and gene != 'rrs':
                # (The redundant "assert gene != 'rrs'" that used to sit
                # here was removed - the condition above guarantees it.)
                re_match = re.match('([ACGT]+)(-?[0-9]+)([ACGTX]+)', d['var'])
                # BUG FIX: was a bare "except:". re.match returns None on
                # failure, so only AttributeError needs catching here.
                try:
                    ref, pos, alt = re_match.groups()
                except AttributeError:
                    print('regex error:', d['var'], file=sys.stderr)
                    continue

                pos = int(pos)
                if len(ref) != len(alt):
                    # Indels are not handled: remember the gene and skip
                    genes_with_indels.add(d['gene'])
                    continue
                elif pos > 0:
                    # Synonymous change within the gene: not implemented
                    continue
                elif pos < 0:
                    # Promoter/upstream variant. Keep the original lookup's
                    # behaviour of raising KeyError for an unknown gene.
                    if d['gene'] not in gene_coords:
                        raise KeyError(d['gene'])
                    d['upstream'] = True
                    # Renumber the position relative to the extracted
                    # upstream+gene sequence. The offset is the same for
                    # both strands because the output sequence is always
                    # written 5' -> 3' (the old strand if/else computed
                    # the identical expression in both branches).
                    variant_pos_in_output_seq = upstream_before + pos + 1
                    assert variant_pos_in_output_seq > 0
                    d['var'] = ref + str(variant_pos_in_output_seq) + alt
                    d['original_mutation'] = variant
                    genes_need_upstream.add(d['gene'])
                else:
                    # pos == 0 is the only case left; coords are 1-based
                    # either side of the gene start, so zero is an error
                    print('Zero coord!', d, file=sys.stderr)
                    continue

            mutations.append(d)
            if not d['upstream']:
                genes_non_upstream.add(d['gene'])

    return mutations, genes_with_indels, genes_need_upstream, genes_non_upstream


def write_prepareref_fasta_file(outfile, gene_coords, genes_need_upstream, genes_non_upstream, upstream_before=100, upstream_after=100):
    '''Writes fasta file to be used with -f option of prepareref'''
    seqs = {}
    reference_fasta = os.path.join(data_dir, 'NC_000962.3.fa.gz')
    pyfastaq.tasks.file_to_dict(reference_fasta, seqs)
    ref_seq = seqs['NC_000962.3']

    with open(outfile, 'w') as fa_out:
        # Whole-gene sequences. start < end means forwards strand;
        # otherwise extract the other way round and reverse-complement.
        for gene_name in genes_non_upstream:
            start = gene_coords[gene_name]['start']
            end = gene_coords[gene_name]['end']
            if start < end:
                record = pyfastaq.sequences.Fasta(gene_name, ref_seq[start:end + 1])
            else:
                record = pyfastaq.sequences.Fasta(gene_name, ref_seq[end:start + 1])
                record.revcomp()

            print(record, file=fa_out)

        # Upstream (promoter) sequences, centred on the gene start and
        # named with an "_upstream" suffix.
        for gene_name in genes_need_upstream:
            start = gene_coords[gene_name]['start']
            end = gene_coords[gene_name]['end']
            if start < end:
                record = pyfastaq.sequences.Fasta(gene_name, ref_seq[start - upstream_before:start + upstream_after])
            else:
                record = pyfastaq.sequences.Fasta(gene_name, ref_seq[start - upstream_after + 1:start + upstream_before + 1])
                record.revcomp()

            record.id += '_upstream'
            print(record, file=fa_out)


def write_prepareref_metadata_file(mutations, outfile):
    '''Writes the metadata TSV file (for the -m option of prepareref)
    describing each mutation dict in the list made by load_mutations.
    A trailing "X" in a variant means "any change", and is expanded into
    one output line per possible alternative residue/base.'''
    aa_letters = {'G', 'P', 'A', 'V', 'L', 'I', 'M', 'C', 'F', 'Y',
                  'W', 'H', 'K', 'R', 'Q', 'N', 'E', 'D', 'S', 'T'}

    with open(outfile, 'w') as f:
        for d in mutations:
            if d['upstream']:
                gene = d['gene'] + '_upstream'
                # Upstream sequences are always treated as non-coding
                coding = '0'
            else:
                gene = d['gene']
                coding = d['coding']

            if 'original_mutation' in d:
                original_mutation_string = '. Original mutation ' + d['original_mutation']
            else:
                original_mutation_string = ''

            if d['var'].endswith('X'):
                # "any change": expand to each alternative letter
                ref = d['var'][0]
                if d['coding'] == 1 and not d['upstream']:
                    letters = aa_letters
                else:
                    letters = {'A', 'C', 'G', 'T'}

                assert ref in letters

                # BUG FIX: iterate in sorted order so the output file is
                # deterministic (set order varies with hash randomization)
                for x in sorted(letters - {ref}):
                    variant = d['var'][:-1] + x
                    print(gene, coding, 1, variant, '.', 'Resistance to ' + d['drugs'] + original_mutation_string, sep='\t', file=f)
            else:
                print(gene, coding, 1, d['var'], '.', 'Resistance to ' + d['drugs'] + original_mutation_string, sep='\t', file=f)


def make_prepareref_files(outprefix):
    '''Builds the fasta (outprefix.fa) and metadata tsv (outprefix.tsv)
    files needed by prepareref, from the bundled TB panel data.'''
    genbank_file = os.path.join(data_dir, 'NC_000962.3.gb')
    mut_to_drug_json = os.path.join(data_dir, 'panel.20181115.json')
    panel_txt_file = os.path.join(data_dir, 'panel.20181115.txt')

    # First column of the panel file is the gene name
    with open(panel_txt_file) as f:
        panel_genes = set(line.split()[0] for line in f)

    coords = genbank_to_gene_coords(genbank_file, panel_genes)
    mutations, _, need_upstream, non_upstream = load_mutations(coords, mut_to_drug_json, panel_txt_file)
    write_prepareref_fasta_file(outprefix + '.fa', coords, need_upstream, non_upstream)
    write_prepareref_metadata_file(mutations, outprefix + '.tsv')


def make_prepareref_dir(outdir):
    '''Runs the whole TB prepareref pipeline, creating a new reference
    directory outdir. Raises if outdir already exists.'''
    if os.path.exists(outdir):
        raise Exception('Output directory ' + outdir + ' already exists. Cannot continue')

    # Build the input fasta/tsv in a temp directory, then run prepareref
    scratch_dir = tempfile.mkdtemp(prefix=outdir + '.tmp', dir=os.getcwd())
    scratch_prefix = os.path.join(scratch_dir, 'out')
    make_prepareref_files(scratch_prefix)
    ref_prep = ref_preparer.RefPreparer(
        [scratch_prefix + '.fa'],
        None,
        metadata_tsv_files=[scratch_prefix + '.tsv'],
        run_cdhit=False,
        threads=1,
    )
    ref_prep.run(outdir)
    common.rmtree(scratch_dir)

    # Record that this is a TB reference, so that "ariba run" knows to
    # also write the TB resistance calls json report
    with open(os.path.join(outdir, '00.params.json'), 'w') as f:
        json.dump({'tb': True}, f)
Binary file added ariba/tb_data/NC_000962.3.fa.gz
Binary file not shown.
Loading

0 comments on commit dc378cf

Please sign in to comment.