Merge pull request #31 from zellerlab/refactor/next_big_cleanup_20210428
Refactor/next big cleanup 20210428 Merging 0.8.2
Showing 11 changed files with 515 additions and 487 deletions.
@@ -0,0 +1,38 @@
import logging

import pandas as pd
import numpy as np

# Function to identify the rownames and number of columns in an alignment
def find_raw_names_ncol(file_name):
    gene_names = list()
    with open(file_name) as f:
        for line in f:
            gene_names.append(line[:line.find("\t")].replace("/", "-"))
    return gene_names, line.count("\t")

# function to load an alignment produced by the "align" option =================
# Input:
#  - a file created by "align"
# Output:
#  - a pandas DataFrame
# note: numpy.loadtxt is much slower than pandas read_csv
# it also works on .gz files
def load_alignment_from_file(file_name, safe_mode=False):
    gene_names, align_length = find_raw_names_ncol(file_name)
    alignment = pd.DataFrame(False, index=gene_names, columns=range(align_length))
    with open(file_name) as f:
        if safe_mode:
            for pos, line in enumerate(f):
                align = [int(c) for c in line.rstrip().split("\t")[1:]]
                if len(align) != align_length or any((c != 0 and c != 1) for c in align):
                    raise ValueError(f"Malformatted alignment in line {pos}:\n{gene_names[pos]}\t{line.rstrip()}")
                alignment.iloc[pos] = np.array([c == 1 for c in align])
        else:
            for pos, line in enumerate(f):
                alignment.iloc[pos] = np.array([c == "1" for c in line.rstrip().split("\t")[1:]])

    logging.info(f' LOAD_AL: Number of genes: {len(list(alignment.index.values))}')
    alignment = alignment.drop_duplicates()
    logging.info(f' LOAD_AL: Number of genes, after removing duplicates: {len(list(alignment.index.values))}')
    return alignment
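
As a usage sketch (hypothetical file name, assuming the tab-separated 0/1 alignment format produced by the "align" step, one gene identifier followed by 0/1 columns per row):

import logging

logging.basicConfig(level=logging.INFO)

# Hypothetical input file: gene_id<TAB>0<TAB>1<TAB>... per row
ali = load_alignment_from_file("example_alignment.tsv", safe_mode=True)
print(ali.shape)          # (number of unique genes, alignment length)
print(int(ali.iloc[0].sum()))  # count of aligned (1) positions for the first gene
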
@@ -1,75 +1,54 @@
"""
Scripts that find the taxonomy of an aligned sequence
"""

# Author: Alessio Milanese <[email protected]>

# Requirements:
# - numpy
# - h5py

import numpy as np
import sys
import time
import os
import h5py
import tempfile
import shutil
import contextlib

from stag.taxonomy import Taxonomy
from stag.databases import load_db
import numpy as np
import h5py

from . import __version__ as tool_version
from stag.taxonomy3 import Taxonomy
from stag.databases import load_db
import stag.align as align

#===============================================================================
# FUNCTION TO LOAD ALIGNED SEQUENCES
#===============================================================================
def alignment_reader(aligned_sequences):
    with open(aligned_sequences, "r") as align_in:
        for ali_line in align_in:
            gene_id, *aligned_seq = ali_line.rstrip().split("\t")
            yield gene_id, np.array(list(map(int, aligned_seq)), dtype=bool)

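A small sketch of how the alignment_reader generator would be consumed, assuming the same tab-separated 0/1 format and a hypothetical file name:

# Hypothetical file; each yielded item is (gene_id, boolean numpy array)
for gene_id, encoded_seq in alignment_reader("aligned_queries.tsv"):
    print(gene_id, int(encoded_seq.sum()), "aligned positions")
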
#===============================================================================
# TAXONOMICALLY ANNOTATE SEQUENCES
#===============================================================================
def run_logistic_prediction(seq, coeff_raw):
    # the first value of the coeff is the intercept
    coeff = coeff_raw[1:]
    intercept = coeff_raw[0]
    # calculate
    intercept, *coeff = coeff_raw
    sm = coeff * seq
    np_sum = (sm).sum() + intercept
    score = 1 / (1 + np.exp(-np_sum))
    return score
    return 1 / (1 + np.exp(-np_sum))

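To make the scoring concrete: the stored coefficient vector has the intercept first, and the score is the sigmoid of the intercept plus the dot product of the remaining coefficients with the boolean encoded sequence, mirroring run_logistic_prediction. A toy calculation with invented numbers:

import numpy as np

# Invented coefficients: intercept first, then one weight per alignment column
coeff_raw = np.array([-0.5, 1.2, -0.3, 0.8])
test_seq = np.array([True, False, True])   # encoded aligned sequence

intercept, coeff = coeff_raw[0], coeff_raw[1:]
score = 1 / (1 + np.exp(-(intercept + (coeff * test_seq).sum())))
print(round(score, 3))  # sigmoid(-0.5 + 1.2 + 0.8) = sigmoid(1.5) ~ 0.818
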
# given many taxa (all siblings) and a sequence, it finds taxa with the highest
# score. Returns the taxa name and the score
def find_best_score(test_seq, siblings, classifiers):
    best_score = -1
    best_taxa = ""
    # check that siblings is not empty:
    if len(siblings) < 1:
        sys.stderr.write("Error. no siblings")
    # if there is only one sibling:
    if len(siblings) == 1:
        best_score = 2  # if there are no siblings I put 2, it will be replaced after
        best_taxa = siblings[0]
    if len(siblings) > 1:
        for s in siblings:
            this_score = run_logistic_prediction(test_seq, classifiers[s])
    best_score, best_taxa = -1, ""
    if not siblings:
        pass
    elif len(siblings) == 1:
        # if there are no siblings I put 2, it will be replaced after
        best_score, best_taxa = 2, siblings[0]
    else:
        for sibling in siblings:
            this_score = run_logistic_prediction(test_seq, classifiers[sibling])
            if this_score > best_score:
                best_score = this_score
                best_taxa = s
                best_score, best_taxa = this_score, sibling
    return best_taxa, best_score

def predict_iter(test_seq, taxonomy, classifiers, tax, perc, arrived_so_far):
    # last iterative step
    if taxonomy.get(arrived_so_far) is not None:
        t, p = find_best_score(test_seq, taxonomy[arrived_so_far], classifiers)
        tax.append(t)
        perc.append(p)
        if t:
            tax.append(t)
            perc.append(p)
            predict_iter(test_seq, taxonomy, classifiers, tax, perc, t)
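
For intuition, the classification walks the taxonomy from the root: find_best_score scores all children of the current node with their logistic classifiers and keeps the best one, and predict_iter recurses until a node with no children is reached. A toy walk-through with invented taxa and coefficients (the real ones come from the database), assuming the refactored functions above:

import numpy as np

# Invented two-level taxonomy: node -> list of child taxa (siblings)
taxonomy = {
    "root": ["Bacteria", "Archaea"],
    "Bacteria": ["Firmicutes", "Proteobacteria"],
}
# Invented classifiers: intercept first, then one weight per alignment column
classifiers = {
    "Bacteria":       np.array([0.1,  2.0, -1.0]),
    "Archaea":        np.array([0.1, -2.0,  1.0]),
    "Firmicutes":     np.array([-0.2,  1.5,  0.5]),
    "Proteobacteria": np.array([-0.2, -1.5, -0.5]),
}
test_seq = np.array([True, False])

tax, perc = [], []
predict_iter(test_seq, taxonomy, classifiers, tax, perc, "root")
print(tax)   # ['Bacteria', 'Firmicutes'] with the refactored predict_iter
print(perc)  # the corresponding sigmoid scores
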
@@ -114,7 +93,7 @@ def find_n_aligned_characters(test_seq):

def classify_seq(gene_id, test_seq, taxonomy, tax_function, classifiers, threads, verbose):
    # test_seq is a boolean numpy array corresponding to the encoded aligned sequence

    #print("TAX", taxonomy)
    # number of characters that map to the internal states of the HMM
    n_aligned_characters = find_n_aligned_characters(test_seq)

@@ -153,7 +132,6 @@ def classify(database, fasta_input=None, protein_fasta_input=None, verbose=3, th
             long_out=False, current_tool_version=tool_version,
             aligned_sequences=None, save_ali_to_file=None, min_perc_state=0, internal_call=False):
    t0 = time.time()
    # load the database
    db = load_db(database, protein_fasta_input=protein_fasta_input, aligned_sequences=aligned_sequences)
    hmm_file_path, use_cmalign, taxonomy, tax_function, classifiers, db_tool_version = db
    if verbose > 2: