Skip to content

Commit

Permalink
Merge pull request #31 from zellerlab/refactor/next_big_cleanup_20210428
Browse files Browse the repository at this point in the history
Refactor/next big cleanup 20210428
Merging 0.8.2
  • Loading branch information
AlessioMilanese authored May 6, 2021
2 parents e420b3f + 0db1376 commit 349009f
Show file tree
Hide file tree
Showing 11 changed files with 515 additions and 487 deletions.
4 changes: 2 additions & 2 deletions conda_env_stag.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: stag
name: stag_0.8.1_tax3
channels:
- bioconda
- defaults
Expand All @@ -11,7 +11,7 @@ dependencies:
- easel
- numpy
- pandas
- scikit-learn
- scikit-learn<0.24
- h5py
- seqtk
- regex
Expand Down
2 changes: 1 addition & 1 deletion stag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.8.1"
__version__ = "0.8.2"
__title__ = "stag"
__author__ = "Alessio Milanese"
__license__ = 'GPLv3+'
Expand Down
4 changes: 2 additions & 2 deletions stag/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def main(argv=None):

# call the function to create the database
create_db.create_db(args.aligned_sequences, args.taxonomy, args.verbose, args.output, args.use_cm_align,
args.template_al, args.intermediate_cross_val, tool_version, args.protein_fasta_input,
args.template_al, args.intermediate_cross_val, args.protein_fasta_input,
args.penalty_logistic, args.solver_logistic, procs=args.threads)

# --------------------------------------------------------------------------
Expand Down Expand Up @@ -368,7 +368,7 @@ def main(argv=None):
# SECOND: CREATE_DB ----------------------------------------------------
# call the function to create the database
create_db.create_db(al_file.name, args.taxonomy, args.verbose, args.output, args.use_cm_align,
args.template_al, args.intermediate_cross_val, tool_version, args.protein_fasta_input,
args.template_al, args.intermediate_cross_val, args.protein_fasta_input,
args.penalty_logistic, args.solver_logistic, procs=args.threads)

# what to do with intermediate alignment -------------------------------
Expand Down
51 changes: 17 additions & 34 deletions stag/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,36 +118,27 @@ def align_generator(seq_file, protein_file, hmm_file, use_cmalign, n_threads, ve
n_pass, n_not_pass = 0, 0
# check that the tools are available
if use_cmalign and not is_tool("cmalign"):
sys.stderr.write("[E::align] Error: cmalign is not in the path. Please install Infernal.\n")
sys.exit(1)
raise ValueError("[E::align] Error: cmalign is not in the path. Please install Infernal.")
elif not is_tool("hmmalign"):
sys.stderr.write("[E::align] Error: hmmalign is not in the path. Please install HMMER3.\n")
sys.exit(1)
raise ValueError("[E::align] Error: hmmalign is not in the path. Please install HMMER3.")
if not is_tool("esl-reformat"):
sys.stderr.write("[E::align] Error: esl-reformat is not in the path. Please install Easel.\n")
sys.exit(1)
raise ValueError("[E::align] Error: esl-reformat is not in the path. Please install Easel.")

# prepare the command to run
cmd = "hmmalign "
if use_cmalign:
cmd = "cmalign --cpu "+str(n_threads)+" "

if not protein_file:
cmd = cmd + hmm_file +" "+ seq_file
else:
cmd = cmd + hmm_file +" "+ protein_file
aligner = f"cmalign --cpu {n_threads}" if use_cmalign else "hmmalign"
seq_input = protein_file if protein_file else seq_file
align_cmd = f"{aligner} {hmm_file} {seq_input}"

if verbose > 4:
sys.stderr.write("Command used to align the sequences: "+cmd+"\n")
print(f"Command used to align the sequences: {align_cmd}", file=sys.stderr)

# run the command
CMD = shlex.split(cmd)
align_cmd = subprocess.Popen(CMD,stdout=subprocess.PIPE,)
CMD = shlex.split(align_cmd)
align_cmd = subprocess.Popen(CMD, stdout=subprocess.PIPE,)

# command to parse the alignment from STOCKHOLM to fasta format
cmd2 = "esl-reformat a2m -"
CMD2 = shlex.split(cmd2)
parse_cmd = subprocess.Popen(CMD2,stdin=align_cmd.stdout,stdout=subprocess.PIPE,)
parse_cmd = subprocess.Popen(CMD2, stdin=align_cmd.stdout, stdout=subprocess.PIPE,)

if protein_file:
seq_stream = zip(read_fasta(parse_cmd.stdout, head_start=1),
Expand Down Expand Up @@ -176,19 +167,17 @@ def align_generator(seq_file, protein_file, hmm_file, use_cmalign, n_threads, ve
align_cmd.stdout.close()
return_code = align_cmd.wait()
if return_code:
sys.stderr.write("[E::align] Error. hmmalign/cmalign failed\n")
sys.exit(1)
raise ValueError("[E::align] Error. hmmalign/cmalign failed.")
# check that converting the file worked correctly
parse_cmd.stdout.close()
return_code = parse_cmd.wait()
if return_code:
sys.stderr.write("[E::align] Error. esl-reformat failed\n")
sys.exit(1)
raise ValueError("[E::align] Error. esl-reformat failed.")

# print the number of sequences that were filtered
if verbose > 3:
sys.stderr.write(" Number of sequences that pass the filter: "+str(n_pass)+"\n")
sys.stderr.write(" Number of sequences that do not pass the filter: "+str(n_not_pass)+"\n")
print(f" Number of sequences that pass the filter: {n_pass}", file=sys.stderr)
print(f" Number of sequences that do not pass the filter: {n_not_pass}", file=sys.stderr)

# ------------------------------------------------------------------------------
# main function
Expand All @@ -210,26 +199,20 @@ def align_file(seq_file, protein_file, hmm_file, use_cmalign, n_threads, verbose
It will save the aligned sequences to the specified file.
"""

# open the temporary file where to save the result
temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w")
os.chmod(temp_file.name, 0o644)
with temp_file:
for gid, ali in align_generator(seq_file, protein_file, hmm_file, use_cmalign,
n_threads, verbose, False, min_perc_state):
print(gid, *map(int, ali), sep="\t", file=temp_file)

# if we save the result to a file, then we close it now
try:
temp_file.flush()
os.fsync(temp_file.fileno())
except:
if verbose>4: sys.stderr.write("[E::align] Error when saving the resulting file\n")
sys.exit(1)
raise ValueError("[E::align] Error when saving the resulting file.")

# move temp file to the final destination
try:
#os.rename(bam_temp_file.name,args.profile_bam_file) # atomic operation
shutil.move(temp_file.name,res_file) #It is not atomic if the files are on different filsystems.
shutil.move(temp_file.name, res_file)
except:
sys.stderr.write("[E::align] The resulting file couldn't be save in the final destination. You can find the file here:\n"+temp_file.name+"\n")
sys.exit(1)
raise ValueError(f"[E::align] The resulting file couldn't be saved. You can find the file here:\n{temp_file.name}.")
38 changes: 38 additions & 0 deletions stag/alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

import pandas as pd
import numpy as np

# Function to identify the rownames and number of columns in an alignment
def find_raw_names_ncol(file_name):
    """Return the row names and the number of columns of an alignment file.

    Each line is expected to be ``<gene id>\\t<v1>\\t<v2>...``. The gene id is
    the text before the first tab, with every "/" replaced by "-".

    Args:
        file_name: path to the tab-separated alignment file.

    Returns:
        A tuple ``(gene_names, n_col)``: the list of gene ids, one per line in
        file order, and the number of tab characters in the last line.
        An empty file gives ``([], 0)``.
    """
    gene_names = []
    n_col = 0  # stays 0 for an empty file instead of failing on an unbound name
    with open(file_name) as f:
        for line in f:
            # Drop the newline before splitting, so a line without any tab
            # keeps its complete name even when it is not newline-terminated.
            gene, _, _ = line.rstrip("\n").partition("\t")
            gene_names.append(gene.replace("/", "-"))
            # Only the value from the last line is returned.
            n_col = line.count("\t")
    return gene_names, n_col

# function to load an alignment produced by the "align" option =================
# Input:
# - a file created by "align"
# Output:
# - a pandas DataFrame (rows: genes, columns: alignment positions)
# as a note, numpy.loadtxt is way slower than pandas read.csv
# It works also on .gz files
def load_alignment_from_file(file_name, safe_mode=False):
    """Load an encoded alignment into a boolean pandas DataFrame.

    Each line of the file is ``<gene id>\\t<v1>\\t<v2>...`` where every value
    is "0" or "1". The file is parsed in a single pass.

    Args:
        file_name: path to the tab-separated alignment file.
        safe_mode: if True, every value is converted with ``int`` and must be
            0 or 1; otherwise a value is True exactly when it is the string "1".

    Returns:
        A DataFrame indexed by gene id ("/" replaced by "-") with one boolean
        column per alignment position, after ``drop_duplicates()``.

    Raises:
        ValueError: if a row has a different number of values than the last
            row, or (safe_mode only) if a value is not 0 or 1.
    """
    gene_names = []
    rows = []
    with open(file_name) as f:
        for pos, line in enumerate(f):
            # Strip the newline first: otherwise the last field is "1\n" and
            # never compares equal to "1".
            gene, *fields = line.rstrip("\n").split("\t")
            if safe_mode:
                try:
                    values = [int(c) for c in fields]
                except ValueError:
                    values = None
                if values is None or any(v not in (0, 1) for v in values):
                    raise ValueError(
                        f"Malformatted alignment in line {pos}:\n{gene}\t{''.join(fields)}"
                    )
                row = [v == 1 for v in values]
            else:
                row = [c == "1" for c in fields]
            gene_names.append(gene.replace("/", "-"))
            rows.append(np.array(row, dtype=bool))

    # The alignment length is defined by the last line of the file.
    align_length = len(rows[-1]) if rows else 0
    for pos, row in enumerate(rows):
        if len(row) != align_length:
            encoded = "".join("1" if v else "0" for v in row)
            raise ValueError(
                f"Malformatted alignment in line {pos}:\n{gene_names[pos]}\t{encoded}"
            )

    matrix = np.vstack(rows) if rows else np.zeros((0, 0), dtype=bool)
    alignment = pd.DataFrame(matrix, index=gene_names, columns=range(align_length))

    logging.info(f' LOAD_AL: Number of genes: {len(alignment.index)}')
    alignment = alignment.drop_duplicates()
    logging.info(f' LOAD_AL: Number of genes, after removing duplicates: {len(alignment.index)}')
    return alignment
64 changes: 21 additions & 43 deletions stag/classify.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,54 @@
"""
Scripts that find the taxonomy of an aligned sequence
"""

# Author: Alessio Milanese <[email protected]>

# Requirements:
# - numpy
# - h5py

import numpy as np
import sys
import time
import os
import h5py
import tempfile
import shutil
import contextlib

from stag.taxonomy import Taxonomy
from stag.databases import load_db
import numpy as np
import h5py

from . import __version__ as tool_version
from stag.taxonomy3 import Taxonomy
from stag.databases import load_db
import stag.align as align

#===============================================================================
# FUNCTION TO LOAD ALIGNED SEQUENCES
#===============================================================================
def alignment_reader(aligned_sequences):
    """Yield the encoded alignments stored in a tab-separated file.

    Each non-empty line is ``<gene id>\\t<v1>\\t<v2>...`` with integer values.

    Args:
        aligned_sequences: path to the alignment file.

    Yields:
        Tuples ``(gene_id, encoded)`` where ``encoded`` is a boolean numpy
        array built from the integer values of the line.
    """
    with open(aligned_sequences, "r") as align_in:
        for ali_line in align_in:
            ali_line = ali_line.rstrip()
            # Skip blank lines (e.g. a trailing empty line) instead of
            # yielding a record with an empty gene id and no positions.
            if not ali_line:
                continue
            gene_id, *aligned_seq = ali_line.split("\t")
            yield gene_id, np.array(list(map(int, aligned_seq)), dtype=bool)

#===============================================================================
# TAXONOMICALLY ANNOTATE SEQUENCES
#===============================================================================
def run_logistic_prediction(seq, coeff_raw):
# the first value of the coeff is the intercept
coeff = coeff_raw[1:]
intercept = coeff_raw[0]
# calculate
intercept, *coeff = coeff_raw
sm = coeff * seq
np_sum = (sm).sum() + intercept
score = 1 / (1 + np.exp(-np_sum))
return score
return 1 / (1 + np.exp(-np_sum))

# given many taxa (all siblings) and a sequence, it finds taxa with the highest
# score. Returns the taxa name and the score
def find_best_score(test_seq, siblings, classifiers):
best_score = -1
best_taxa = ""
# check that siblings is not empty:
if len(siblings) < 1:
sys.stderr.write("Error. no siblings")
# if there is only one sibiling:
if len(siblings) == 1:
best_score = 2 # if there are no siblings I put 2, it will be replaced after
best_taxa = siblings[0]
if len(siblings) > 1:
for s in siblings:
this_score = run_logistic_prediction(test_seq, classifiers[s])
best_score, best_taxa = -1, ""
if not siblings:
pass
elif len(siblings) == 1:
# if there are no siblings I put 2, it will be replaced after
best_score, best_taxa = 2, siblings[0]
else:
for sibling in siblings:
this_score = run_logistic_prediction(test_seq, classifiers[sibling])
if this_score > best_score:
best_score = this_score
best_taxa = s
best_score, best_taxa = this_score, sibling
return best_taxa, best_score

def predict_iter(test_seq, taxonomy, classifiers, tax, perc, arrived_so_far):
# last iterative step
if taxonomy.get(arrived_so_far) is not None:
t, p = find_best_score(test_seq, taxonomy[arrived_so_far], classifiers)
tax.append(t)
perc.append(p)
if t:
tax.append(t)
perc.append(p)
predict_iter(test_seq, taxonomy, classifiers, tax, perc, t)


Expand Down Expand Up @@ -114,7 +93,7 @@ def find_n_aligned_characters(test_seq):

def classify_seq(gene_id, test_seq, taxonomy, tax_function, classifiers, threads, verbose):
# test_seq is a boolean numpy array corresponding to the encoded aligned sequence

#print("TAX", taxonomy)
# number of characters that map to the internal states of the HMM
n_aligned_characters = find_n_aligned_characters(test_seq)

Expand Down Expand Up @@ -153,7 +132,6 @@ def classify(database, fasta_input=None, protein_fasta_input=None, verbose=3, th
long_out=False, current_tool_version=tool_version,
aligned_sequences=None, save_ali_to_file=None, min_perc_state=0, internal_call=False):
t0 = time.time()
# load the database
db = load_db(database, protein_fasta_input=protein_fasta_input, aligned_sequences=aligned_sequences)
hmm_file_path, use_cmalign, taxonomy, tax_function, classifiers, db_tool_version = db
if verbose > 2:
Expand Down
Loading

0 comments on commit 349009f

Please sign in to comment.