Merge pull request #115 from martinghunt/het_snp_reporting
Het snp reporting
martinghunt authored Aug 2, 2016
2 parents 981bdf0 + fe6c8fb commit a342ec6
Showing 72 changed files with 6,420 additions and 29 deletions.
4 changes: 0 additions & 4 deletions .travis.yml
@@ -7,10 +7,6 @@ addons:
     - liblapack-dev
     - libgfortran3
     - libncurses5-dev
-cache:
-  directories:
-  - "build"
-  - "$HOME/.cache/pip"
 python:
 - "3.4"
 sudo: false
4 changes: 3 additions & 1 deletion ariba/cluster.py
@@ -103,6 +103,7 @@ def __init__(self,
         self.final_assembly_vcf = os.path.join(self.root_dir, 'assembly.reads_mapped.bam.vcf')
         self.samtools_vars_prefix = self.final_assembly_bam
         self.assembly_compare = None
+        self.variants_from_samtools = {}
         self.assembly_compare_prefix = os.path.join(self.root_dir, 'assembly_compare')
 
         self.mummer_variants = {}
@@ -415,7 +416,8 @@ def _run(self):
 
             self.total_contig_depths = self.samtools_vars.total_depth_per_contig(self.samtools_vars.read_depths_file)
 
-            if self.samtools_vars.variants_in_coords(self.assembly_compare.assembly_match_coords(), self.samtools_vars.vcf_file):
+            self.variants_from_samtools = self.samtools_vars.variants_in_coords(self.assembly_compare.assembly_match_coords(), self.samtools_vars.vcf_file)
+            if len(self.variants_from_samtools):
                 self.status_flag.add('variants_suggest_collapsed_repeat')
         elif not self.assembled_ok:
             print('\nAssembly failed\n', file=self.log_fh, flush=True)
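_run() now stores which samtools variant positions fall inside nucmer matches, instead of only a count, so report.py can later tell which positions have already been reported. A minimal sketch of the shape this attribute takes (contig names and positions invented for illustration):

# Hypothetical value of self.variants_from_samtools after _run():
# contig name -> set of 0-based variant positions inside nucmer matches.
variants_from_samtools = {
    'cluster_1.contig.1': {41, 107},
    'cluster_1.contig.2': {13},
}

# As in _run() above, any non-empty dict raises the collapsed-repeat flag.
if len(variants_from_samtools):
    print('add status flag: variants_suggest_collapsed_repeat')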
31 changes: 30 additions & 1 deletion ariba/report.py
@@ -1,3 +1,4 @@
+import copy
 import sys
 import pymummer
 
@@ -169,6 +170,7 @@ def _report_lines_for_one_contig(cluster, contig_name, ref_cov_per_contig, pymum
     else:
         free_text_column = ';'.join(['.'])
 
+    remaining_samtools_variants = copy.copy(cluster.variants_from_samtools)
     if cluster.assembled_ok and contig_name in cluster.assembly_variants and len(cluster.assembly_variants[contig_name]) > 0:
         for (position, var_seq_type, ref_ctg_change, var_effect, contributing_vars, matching_vars_set, metainfo_set) in cluster.assembly_variants[contig_name]:
             if len(matching_vars_set) > 0:
@@ -196,7 +198,11 @@ def _report_lines_for_one_contig(cluster, contig_name, ref_cov_per_contig, pymum
                     smtls_alt_depth = []
 
                     for var in contributing_vars:
+                        if contig_name in remaining_samtools_variants:
+                            remaining_samtools_variants[contig_name].discard(var.qry_start)
+
                         depths_tuple = cluster.samtools_vars.get_depths_at_position(contig_name, var.qry_start)
+
                         if depths_tuple is not None:
                             smtls_alt_nt.append(depths_tuple[1])
                             smtls_total_depth.append(str(depths_tuple[2]))
@@ -231,6 +237,10 @@ def _report_lines_for_one_contig(cluster, contig_name, ref_cov_per_contig, pymum
 
                 if samtools_columns is None:
                     samtools_columns = [['.'] * 9]
+                else:
+                    for ctg_pos in range(int(samtools_columns[3]) - 1, int(samtools_columns[4]), 1):
+                        if contig_name in remaining_samtools_variants:
+                            remaining_samtools_variants[contig_name].discard(ctg_pos)
 
                 lines.append('\t'.join(common_first_columns + variant_columns + samtools_columns + [matching_vars_column] + [free_text_column]))
             else:
@@ -239,7 +249,26 @@ def _report_lines_for_one_contig(cluster, contig_name, ref_cov_per_contig, pymum
                         samtools_columns + \
                         [matching_vars_column] + [free_text_column]
                 ))
-    else:
+
+    for contig_name in remaining_samtools_variants:
+        for var_position in remaining_samtools_variants[contig_name]:
+            depths_tuple = cluster.samtools_vars.get_depths_at_position(contig_name, var_position)
+            if depths_tuple is not None:
+                new_cols = [
+                    '0',       # known_var column
+                    'HET',     # var_type
+                    '.', '.', '.', '.', '.', '.', '.', '.',        # var_seq_type ... ref_nt
+                    str(var_position + 1), str(var_position + 1),  # ctg_start, ctg_end
+                    depths_tuple[0],       # ctg_nt
+                    str(depths_tuple[2]),  # smtls_total_depth
+                    depths_tuple[1],       # smtls_alt_nt
+                    str(depths_tuple[3]),  # smtls_alt_depth
+                    '.',
+                    free_text_column,
+                ]
+                lines.append('\t'.join(common_first_columns + new_cols))
+
+    if len(lines) == 0:
         lines.append('\t'.join(common_first_columns + ['.'] * (len(columns) - len(common_first_columns) - 1) + [free_text_column]))
 
     return lines
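remaining_samtools_variants starts as a copy of cluster.variants_from_samtools; positions covered by the known-variant rows above are discarded from it, and whatever is left is emitted as extra 'HET' rows. (Note that copy.copy() is shallow, so those discard() calls operate on the same set objects the cluster holds.) A self-contained sketch of how one leftover position becomes a row — the depths tuple and values are invented; in the real code they come from SamtoolsVariants.get_depths_at_position(), and the real row is prefixed with common_first_columns:

# Invented example values: get_depths_at_position() is assumed to return
# (ctg_nt, alt_nts, total_depth, alt_depths) for a 0-based position.
var_position = 41
depths_tuple = ('A', 'A,T', 30, '18,12')
free_text_column = '.'

new_cols = [
    '0',                                    # known_var: not a known variant
    'HET',                                  # var_type reported for the het SNP
    '.', '.', '.', '.', '.', '.', '.', '.', # var_seq_type ... ref_nt left empty
    str(var_position + 1),                  # ctg_start (1-based)
    str(var_position + 1),                  # ctg_end
    depths_tuple[0],                        # ctg_nt
    str(depths_tuple[2]),                   # smtls_total_depth
    depths_tuple[1],                        # smtls_alt_nt
    str(depths_tuple[3]),                   # smtls_alt_depth
    '.',
    free_text_column,
]
print('\t'.join(new_cols))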
1 change: 1 addition & 0 deletions ariba/report_filter.py
@@ -120,6 +120,7 @@ def _report_dict_passes_essential_filters(self, report_dict):
         return ReportFilter._flag_passes_filter(report_dict['flag'], self.exclude_flags) \
             and report_dict['pc_ident'] >= self.min_pc_ident \
             and report_dict['ref_base_assembled'] >= self.min_ref_base_assembled \
+            and report_dict['var_type'] != 'HET'
 
 
     def _filter_list_of_dicts(self, dicts_list):
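With this change, rows whose var_type is 'HET' never pass the essential filters, so het calls are dropped from filtered reports by default. A simplified stand-in for the predicate (not the real ReportFilter class, which also applies the flag check and configurable thresholds):

def passes_essential_filters(report_dict, min_pc_ident=90.0, min_ref_base_assembled=1):
    # Simplified: the real method also applies ReportFilter._flag_passes_filter().
    return report_dict['pc_ident'] >= min_pc_ident \
        and report_dict['ref_base_assembled'] >= min_ref_base_assembled \
        and report_dict['var_type'] != 'HET'

print(passes_essential_filters({'pc_ident': 99.5, 'ref_base_assembled': 500, 'var_type': 'HET'}))  # False
print(passes_essential_filters({'pc_ident': 99.5, 'ref_base_assembled': 500, 'var_type': 'SNP'}))  # True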
8 changes: 5 additions & 3 deletions ariba/samtools_variants.py
@@ -160,7 +160,7 @@ def total_depth_per_contig(read_depths_file):
     def variants_in_coords(nucmer_matches, vcf_file):
         '''nucmer_matches = made by assembly_compare.assembly_match_coords().
         Returns number of variants that lie in nucmer_matches'''
-        vcf_variant_counts = {}
+        found_variants = {}
         f = pyfastaq.utils.open_file_read(vcf_file)
         for line in f:
             if line.startswith('#'):
@@ -174,10 +174,12 @@ def variants_in_coords(nucmer_matches, vcf_file):
                 i = pyfastaq.intervals.Interval(position, position)
                 intersects = len([x for x in nucmer_matches[scaff] if x.intersects(i)]) > 0
                 if intersects:
-                    vcf_variant_counts[scaff] = vcf_variant_counts.get(scaff, 0) + 1
+                    if scaff not in found_variants:
+                        found_variants[scaff] = set()
+                    found_variants[scaff].add(position)
 
         pyfastaq.utils.close(f)
-        return sum(list(vcf_variant_counts.values()))
+        return found_variants
 
 
     def get_depths_at_position(self, seq_name, position):
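variants_in_coords() previously returned a single count; it now returns a dict mapping each contig to the set of 0-based variant positions that fall inside a nucmer match (the docstring above still describes the old return value). A toy re-implementation of the new logic, using plain (start, end) tuples in place of pyfastaq Interval objects and made-up coordinates:

# Made-up data; the real code reads a VCF and uses pyfastaq.intervals.Interval.
nucmer_matches = {'contig.1': [(0, 99)], 'contig.2': [(50, 60)]}
vcf_positions = {'contig.1': [5, 41, 250], 'contig.2': [10]}

found_variants = {}
for scaff, positions in vcf_positions.items():
    for position in positions:
        if any(start <= position <= end for start, end in nucmer_matches.get(scaff, [])):
            if scaff not in found_variants:
                found_variants[scaff] = set()
            found_variants[scaff].add(position)

print(found_variants)   # -> {'contig.1': {5, 41}} (set order may vary)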
40 changes: 39 additions & 1 deletion ariba/summary.py
@@ -18,6 +18,7 @@ def __init__(
             filter_rows=True,
             filter_columns=True,
             min_id=90.0,
+            show_known_het=False,
             cluster_cols='assembled,match,ref_seq,pct_id,known_var,novel_var',
             variant_cols='groups,grouped,ungrouped,novel',
             verbose=False,
@@ -33,6 +34,7 @@ def __init__(
         if fofn is not None:
             self.filenames.extend(self._load_fofn(fofn))
 
+        self.show_known_het = show_known_het
         self.cluster_columns = self._determine_cluster_cols(cluster_cols)
         self.var_columns = self._determine_var_cols(variant_cols)
         self.filter_rows = filter_rows
@@ -112,6 +114,17 @@ def _get_all_variant_columns(cls, samples_dict):
         return columns
 
 
+    @classmethod
+    def _get_all_het_snps(cls, samples_dict):
+        snps = set()
+        for filename, sample in samples_dict.items():
+            for cluster, snp_dict in sample.het_snps.items():
+                if len(snp_dict):
+                    for snp in snp_dict:
+                        snps.add((cluster, snp))
+
+        return snps
+
     @classmethod
     def _get_all_var_groups(cls, samples_dict):
         groups = {}
@@ -126,6 +139,8 @@ def _get_all_var_groups(cls, samples_dict):
     def _gather_output_rows(self):
         all_cluster_names = Summary._get_all_cluster_names(self.samples)
         all_var_columns = Summary._get_all_variant_columns(self.samples)
+        all_het_snps = Summary._get_all_het_snps(self.samples)
+
         if self.var_columns['groups']:
             var_groups = Summary._get_all_var_groups(self.samples)
         else:
@@ -163,12 +178,22 @@ def _gather_output_rows(self):
                         continue
 
                     key = ref_name + '.' + variant
+
                     if rows[filename][cluster]['assembled'] == 'no':
                         rows[filename][cluster][key] = 'NA'
                     elif cluster in sample.variant_column_names_tuples and (ref_name, variant, grouped_or_novel, group_name) in sample.variant_column_names_tuples[cluster]:
                         rows[filename][cluster][key] = 'yes'
+                        if self.show_known_het:
+                            if cluster in sample.het_snps and variant in sample.het_snps[cluster]:
+                                rows[filename][cluster][key] = 'het'
+                                rows[filename][cluster][key + '.%'] = sample.het_snps[cluster][variant]
                     else:
                         rows[filename][cluster][key] = 'no'
+                        if self.show_known_het and (cluster, variant) in all_het_snps:
+                            rows[filename][cluster][key + '.%'] = 'NA'
+
+                    if self.show_known_het and (cluster, variant) in all_het_snps and key + '.%' not in rows[filename][cluster]:
+                        rows[filename][cluster][key + '.%'] = 'NA'
 
                 for key, wanted in self.cluster_columns.items():
                     if not wanted:
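The per-sample cell logic above yields, for each known-variant column, one of 'yes', 'het', 'no' or 'NA', plus a companion '.%' column carrying the het read percentage whenever any sample has that SNP as het. An invented example of one sample's row dict after _gather_output_rows() with show_known_het=True:

# Invented example of rows[filename][cluster]; keys follow ref_name + '.' + variant.
row = {
    'assembled': 'yes',
    '23S.A2058G': 'het',     # known SNP called heterozygous in this sample
    '23S.A2058G.%': 40.0,    # % of reads supporting the variant allele
    '23S.C2611T': 'no',
    '23S.C2611T.%': 'NA',    # het in some other sample, so the column exists
                             # here too and is padded with 'NA'
}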
@@ -213,7 +238,8 @@ def _to_matrix(cls, filenames, rows, cluster_cols):
 
                 if making_header_lines:
                     csv_header.append(cluster_name + '.' + col)
-                    phandango_header.append(cluster_name + '.' + col + ':o1')
+                    suffix = ':c2' if col.endswith('.%') else ':o1'
+                    phandango_header.append(cluster_name + '.' + col + suffix)
 
                 line.append(rows[filename][cluster_name][col])
 
@@ -270,6 +296,7 @@ def _add_phandango_colour_columns(cls, header, matrix):
             'yes_nonunique': '#a6cee3',
             'no': '#33a02c',
             'NA': '#b2df8a',
+            'het': '#fb9a99',
         }
 
         cols_to_add_colour_col.reverse()
@@ -351,6 +378,12 @@ def run(self):
         self.rows = self._gather_output_rows()
         phandango_header, csv_header, matrix = Summary._to_matrix(self.filenames, self.rows, self.cluster_columns)
 
+        # sanity check same number of columns in headers and matrix
+        lengths = {len(x) for x in matrix}
+        print(lengths, len(phandango_header), len(csv_header))
+        assert len(lengths) == 1
+        assert len(matrix[0]) == len(phandango_header) == len(csv_header)
+
         if self.filter_rows:
             if self.verbose:
                 print('Filtering rows', flush=True)
@@ -368,6 +401,11 @@ def run(self):
             if len(matrix) == 0 or len(matrix[0]) == 0:
                 print('No columns left after filtering columns. Cannot continue', file=sys.stderr)
 
+        # sanity check same number of columns in headers and matrix
+        lengths = {len(x) for x in matrix}
+        assert len(lengths) == 1
+        assert len(matrix[0]) == len(phandango_header) == len(csv_header)
+
         csv_file = self.outprefix + '.csv'
         if self.verbose:
             print('Writing csv file', csv_file, flush=True)
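The new '.%' columns are typed as continuous (':c2') in the phandango header so they render as a colour gradient, while the call columns stay ordinal (':o1'). A small sketch of the suffix rule from _to_matrix(), with invented column names:

# Mirrors the suffix logic added in _to_matrix() above.
for col in ['cluster1.ref_seq', 'cluster1.23S.A2058G', 'cluster1.23S.A2058G.%']:
    suffix = ':c2' if col.endswith('.%') else ':o1'
    print(col + suffix)

# cluster1.ref_seq:o1
# cluster1.23S.A2058G:o1
# cluster1.23S.A2058G.%:c2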
41 changes: 41 additions & 0 deletions ariba/summary_cluster.py
@@ -186,6 +186,38 @@ def _to_cluster_summary_has_nonsynonymous(self, assembled_summary):
         return self._has_any_nonsynonymous()
 
 
+    @staticmethod
+    def _get_known_noncoding_het_snp(data_dict):
+        '''If ref is coding, return None. If the data dict has a known SNP
+        and samtools made a call, return a tuple of the string ref_name_change
+        and the % of reads supporting the variant. If noncoding but there is
+        no samtools call, return None.'''
+        if data_dict['gene'] == '1':
+            return None
+
+        if data_dict['known_var'] == '1' and data_dict['ref_ctg_effect'] == 'SNP' \
+          and data_dict['smtls_alt_nt'] != '.' and ';' not in data_dict['smtls_alt_nt']:
+            nucleotides = [data_dict['ctg_nt']] + data_dict['smtls_alt_nt'].split(',')
+            depths = data_dict['smtls_alt_depth'].split(',')
+
+            if len(nucleotides) != len(depths):
+                raise Error('Mismatch in number of inferred nucleotides from ctg_nt, smtls_alt_nt, smtls_alt_depth columns. Cannot continue\n' + str(data_dict))
+
+            try:
+                var_nucleotide = data_dict['known_var_change'][-1]
+                depths = [int(x) for x in depths]
+                nuc_to_depth = dict(zip(nucleotides, depths))
+                total_depth = sum(depths)
+                var_depth = nuc_to_depth.get(var_nucleotide, 0)
+                percent_depth = 100 * var_depth / total_depth
+            except:
+                return None
+
+            return data_dict['known_var_change'], percent_depth
+        else:
+            return None
+
+
     @staticmethod
     def _get_nonsynonymous_var(data_dict):
         '''if data_dict has a non synonymous variant, return string:
@@ -256,3 +288,12 @@ def non_synon_variants(self):
         variants = {self._get_nonsynonymous_var(d) for d in self.data}
         variants.discard(None)
         return variants
+
+
+    def known_noncoding_het_snps(self):
+        snps = {}
+        for d in self.data:
+            snp_tuple = self._get_known_noncoding_het_snp(d)
+            if snp_tuple is not None:
+                snps[snp_tuple[0]] = snp_tuple[1]
+        return snps
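A worked example of the percentage computed by _get_known_noncoding_het_snp(), with all report values invented: a known noncoding SNP A42T where samtools sees 18 reads supporting A and 12 supporting T.

# Invented data_dict fields; mirrors the arithmetic in the method above.
data_dict = {
    'gene': '0',                # noncoding, so the method does not bail out
    'known_var': '1',
    'ref_ctg_effect': 'SNP',
    'known_var_change': 'A42T',
    'ctg_nt': 'A',
    'smtls_alt_nt': 'T',
    'smtls_alt_depth': '18,12',
}

nucleotides = [data_dict['ctg_nt']] + data_dict['smtls_alt_nt'].split(',')  # ['A', 'T']
depths = [int(x) for x in data_dict['smtls_alt_depth'].split(',')]          # [18, 12]
nuc_to_depth = dict(zip(nucleotides, depths))
percent_depth = 100 * nuc_to_depth.get(data_dict['known_var_change'][-1], 0) / sum(depths)
print(data_dict['known_var_change'], percent_depth)   # A42T 40.0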
10 changes: 7 additions & 3 deletions ariba/summary_sample.py
@@ -44,18 +44,22 @@ def _var_groups(self):
         return {c: self.clusters[c].has_var_groups() for c in self.clusters}
 
 
-    def _variant_column_names_tuples(self):
+    def _variant_column_names_tuples_and_het_snps(self):
         variants = {}
+        het_snps = {}
         for cluster_name, cluster in self.clusters.items():
             cluster_vars = cluster.non_synon_variants()
+            cluster_noncoding_het_snps = cluster.known_noncoding_het_snps()
+
             if len(cluster_vars):
                 variants[cluster_name] = cluster_vars
-        return variants
+                het_snps[cluster_name] = cluster_noncoding_het_snps
+        return variants, het_snps
 
 
     def run(self):
         self.clusters = self._load_file(self.report_tsv, self.min_pc_id)
         self.column_summary_data = self._column_summary_data()
-        self.variant_column_names_tuples = self._variant_column_names_tuples()
+        self.variant_column_names_tuples, self.het_snps = self._variant_column_names_tuples_and_het_snps()
         self.var_groups = self._var_groups()

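After run(), each sample now exposes self.het_snps alongside the variant tuples, which is what Summary._get_all_het_snps() and _gather_output_rows() consume above. Hypothetical shapes of the two attributes (cluster and variant names invented; the tuples follow the (ref_name, variant, grouped_or_novel, group_name) order used in summary.py):

# Invented illustration of SummarySample attributes after run():
variant_column_names_tuples = {
    'cluster_23S': {('23S', 'A2058G', 'grouped', 'macrolides')},
}
het_snps = {
    'cluster_23S': {'A2058G': 40.0},   # variant -> % reads supporting variant allele
}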
1 change: 1 addition & 0 deletions ariba/tasks/summary.py
@@ -93,6 +93,7 @@ def run(options):
         filter_rows=options.col_filter == 'y',
         filter_columns=options.row_filter == 'y',
         min_id=options.min_id,
+        show_known_het=options.het,
         cluster_cols=options.cluster_cols,
         variant_cols=options.var_cols,
         verbose=options.verbose
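The task wires the new option through to Summary; this diff only shows that argparse stores it in options.het, not the exact CLI flag. A hedged sketch of driving the class directly, assuming the constructor takes an output prefix first and a filenames keyword (suggested by self.outprefix and self.filenames above; file names invented):

# Sketch under stated assumptions, not a confirmed API.
from ariba import summary

s = summary.Summary(
    'out.summary',                                     # assumed outprefix argument
    filenames=['sample1/report.tsv', 'sample2/report.tsv'],
    show_known_het=True,                               # report het calls in the matrix
)
s.run()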