diff --git a/conda_env_stag.yaml b/conda_env_stag.yaml
index 291d957..6835899 100644
--- a/conda_env_stag.yaml
+++ b/conda_env_stag.yaml
@@ -1,4 +1,4 @@
-name: stag
+name: stag_0.8.1_tax3
 channels:
   - bioconda
   - defaults
@@ -11,7 +11,7 @@ dependencies:
   - easel
   - numpy
   - pandas
-  - scikit-learn
+  - scikit-learn<0.24
   - h5py
   - seqtk
   - regex
diff --git a/stag/__init__.py b/stag/__init__.py
index 4b5e07f..d46d35f 100644
--- a/stag/__init__.py
+++ b/stag/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.8.1"
+__version__ = "0.8.2"
 __title__ = "stag"
 __author__ = "Alessio Milanese"
 __license__ = 'GPLv3+'
diff --git a/stag/__main__.py b/stag/__main__.py
index 0411782..32ffab3 100755
--- a/stag/__main__.py
+++ b/stag/__main__.py
@@ -322,7 +322,7 @@ def main(argv=None):
         # call the function to create the database
         create_db.create_db(args.aligned_sequences, args.taxonomy, args.verbose, args.output, args.use_cm_align,
-                            args.template_al, args.intermediate_cross_val, tool_version, args.protein_fasta_input,
+                            args.template_al, args.intermediate_cross_val, args.protein_fasta_input,
                             args.penalty_logistic, args.solver_logistic, procs=args.threads)

     # --------------------------------------------------------------------------
@@ -368,7 +368,7 @@ def main(argv=None):
         # SECOND: CREATE_DB ----------------------------------------------------
         # call the function to create the database
         create_db.create_db(al_file.name, args.taxonomy, args.verbose, args.output, args.use_cm_align,
-                            args.template_al, args.intermediate_cross_val, tool_version, args.protein_fasta_input,
+                            args.template_al, args.intermediate_cross_val, args.protein_fasta_input,
                             args.penalty_logistic, args.solver_logistic, procs=args.threads)

         # what to do with intermediate alignment -------------------------------
diff --git a/stag/align.py b/stag/align.py
index 0b86032..e091838 100644
--- a/stag/align.py
+++ b/stag/align.py
@@ -118,36 +118,27 @@ def align_generator(seq_file, protein_file, hmm_file, use_cmalign, n_threads, ve
    n_pass, n_not_pass = 0, 0
    # check that the tools are available
    if use_cmalign and not is_tool("cmalign"):
-        sys.stderr.write("[E::align] Error: cmalign is not in the path. Please install Infernal.\n")
-        sys.exit(1)
+        raise ValueError("[E::align] Error: cmalign is not in the path. Please install Infernal.")
    elif not is_tool("hmmalign"):
-        sys.stderr.write("[E::align] Error: hmmalign is not in the path. Please install HMMER3.\n")
-        sys.exit(1)
+        raise ValueError("[E::align] Error: hmmalign is not in the path. Please install HMMER3.")
    if not is_tool("esl-reformat"):
-        sys.stderr.write("[E::align] Error: esl-reformat is not in the path. Please install Easel.\n")
-        sys.exit(1)
+        raise ValueError("[E::align] Error: esl-reformat is not in the path. Please install Easel.")

-    # prepare the command to run
-    cmd = "hmmalign "
-    if use_cmalign:
-        cmd = "cmalign --cpu "+str(n_threads)+" "
-
-    if not protein_file:
-        cmd = cmd + hmm_file +" "+ seq_file
-    else:
-        cmd = cmd + hmm_file +" "+ protein_file
+    aligner = f"cmalign --cpu {n_threads}" if use_cmalign else "hmmalign"
+    seq_input = protein_file if protein_file else seq_file
+    align_cmd = f"{aligner} {hmm_file} {seq_input}"

    if verbose > 4:
-        sys.stderr.write("Command used to align the sequences: "+cmd+"\n")
+        print(f"Command used to align the sequences: {align_cmd}", file=sys.stderr)

    # run the command
-    CMD = shlex.split(cmd)
-    align_cmd = subprocess.Popen(CMD,stdout=subprocess.PIPE,)
+    CMD = shlex.split(align_cmd)
+    align_cmd = subprocess.Popen(CMD, stdout=subprocess.PIPE,)

    # command to parse the alignment from STOCKHOLM to fasta format
    cmd2 = "esl-reformat a2m -"
    CMD2 = shlex.split(cmd2)
-    parse_cmd = subprocess.Popen(CMD2,stdin=align_cmd.stdout,stdout=subprocess.PIPE,)
+    parse_cmd = subprocess.Popen(CMD2, stdin=align_cmd.stdout, stdout=subprocess.PIPE,)

    if protein_file:
        seq_stream = zip(read_fasta(parse_cmd.stdout, head_start=1),
@@ -176,19 +167,17 @@ def align_generator(seq_file, protein_file, hmm_file, use_cmalign, n_threads, ve
    align_cmd.stdout.close()
    return_code = align_cmd.wait()
    if return_code:
-        sys.stderr.write("[E::align] Error. hmmalign/cmalign failed\n")
-        sys.exit(1)
+        raise ValueError("[E::align] Error. hmmalign/cmalign failed.")
    # check that converting the file worked correctly
    parse_cmd.stdout.close()
    return_code = parse_cmd.wait()
    if return_code:
-        sys.stderr.write("[E::align] Error. esl-reformat failed\n")
-        sys.exit(1)
+        raise ValueError("[E::align] Error. esl-reformat failed.")

    # print the number of sequences that were filtered
    if verbose > 3:
-        sys.stderr.write(" Number of sequences that pass the filter: "+str(n_pass)+"\n")
-        sys.stderr.write(" Number of sequences that do not pass the filter: "+str(n_not_pass)+"\n")
+        print(f" Number of sequences that pass the filter: {n_pass}", file=sys.stderr)
+        print(f" Number of sequences that do not pass the filter: {n_not_pass}", file=sys.stderr)

 # ------------------------------------------------------------------------------
 # main function
@@ -210,7 +199,6 @@ def align_file(seq_file, protein_file, hmm_file, use_cmalign, n_threads, verbose

    It will save the aligned sequences to the specified file.
    """
-    # open the temporary file where to save the result
    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w")
    os.chmod(temp_file.name, 0o644)
    with temp_file:
@@ -218,18 +206,13 @@ def align_file(seq_file, protein_file, hmm_file, use_cmalign, n_threads, verbose
                                         n_threads, verbose, False, min_perc_state):
            print(gid, *map(int, ali), sep="\t", file=temp_file)

-        # if we save the result to a file, then we close it now
        try:
            temp_file.flush()
            os.fsync(temp_file.fileno())
        except:
-            if verbose>4: sys.stderr.write("[E::align] Error when saving the resulting file\n")
-            sys.exit(1)
+            raise ValueError("[E::align] Error when saving the resulting file.")

-    # move temp file to the final destination
    try:
-        #os.rename(bam_temp_file.name,args.profile_bam_file) # atomic operation
-        shutil.move(temp_file.name,res_file) #It is not atomic if the files are on different filsystems.
+        shutil.move(temp_file.name, res_file)
    except:
-        sys.stderr.write("[E::align] The resulting file couldn't be save in the final destination. You can find the file here:\n"+temp_file.name+"\n")
-        sys.exit(1)
+        raise ValueError(f"[E::align] The resulting file couldn't be saved. You can find the file here:\n{temp_file.name}.")
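The refactored align_generator keeps the two-process pipe: the aligner's stdout is streamed into esl-reformat's stdin, and both return codes are checked once the stream is exhausted. A minimal, self-contained sketch of the same pattern (command strings are illustrative placeholders, not stag's exact invocations):

    import shlex
    import subprocess

    def run_pipeline(align_cmd, reformat_cmd="esl-reformat a2m -"):
        # first stage: the aligner writes STOCKHOLM to its stdout
        aligner = subprocess.Popen(shlex.split(align_cmd), stdout=subprocess.PIPE)
        # second stage: esl-reformat converts the stream to a2m/fasta
        reformat = subprocess.Popen(shlex.split(reformat_cmd),
                                    stdin=aligner.stdout, stdout=subprocess.PIPE)
        aligner.stdout.close()  # so the aligner gets SIGPIPE if the reader dies
        for line in reformat.stdout:
            yield line.decode().rstrip()
        reformat.stdout.close()
        # check both exit codes, as align_generator does after consuming the stream
        for proc, name in ((aligner, "hmmalign/cmalign"), (reformat, "esl-reformat")):
            if proc.wait():
                raise ValueError(f"[E::align] Error. {name} failed.")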
diff --git a/stag/alignment.py b/stag/alignment.py
new file mode 100644
index 0000000..f39df75
--- /dev/null
+++ b/stag/alignment.py
@@ -0,0 +1,38 @@
+import logging
+
+import pandas as pd
+import numpy as np
+
+# Function to identify the rownames and number of columns in an alignment
+def find_raw_names_ncol(file_name):
+    gene_names = list()
+    with open(file_name) as f:
+        for line in f:
+            gene_names.append(line[:line.find("\t")].replace("/", "-"))
+    return gene_names, line.count("\t")
+
+# function to load an alignment produced by the "align" option =================
+# Input:
+#  - a file created by "align"
+# Output:
+#  - a pandas DataFrame
+# as a note, numpy.loadtxt is way slower than pandas read.csv
+# It works also on .gz files
+def load_alignment_from_file(file_name, safe_mode=False):
+    gene_names, align_length = find_raw_names_ncol(file_name)
+    alignment = pd.DataFrame(False, index=gene_names, columns=range(align_length))
+    with open(file_name) as f:
+        if safe_mode:
+            for pos, line in enumerate(f):
+                align = [int(c) for c in line.split("\t")[1:]]
+                if len(align) != align_length or any((c != 0 and c != 1) for c in align):
+                    raise ValueError(f"Malformatted alignment in line {pos}:\n{line.rstrip()}")
+                alignment.iloc[pos] = np.array([c == 1 for c in align])
+        else:
+            for pos, line in enumerate(f):
+                alignment.iloc[pos] = np.array([c == "1" for c in line.split("\t")[1:]])
+
+    logging.info(f' LOAD_AL: Number of genes: {len(list(alignment.index.values))}')
+    alignment = alignment.drop_duplicates()
+    logging.info(f' LOAD_AL: Number of genes, after removing duplicates: {len(list(alignment.index.values))}')
+    return alignment
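For reference, the file parsed above is the tab-separated encoding written by stag align: a gene id followed by one 0/1 column per alignment position. A toy round-trip of that format (gene ids and values invented):

    import numpy as np
    import pandas as pd

    rows = ["geneA\t1\t0\t1", "geneB\t0\t1\t1"]  # hypothetical align output

    records = {}
    for line in rows:
        gene, *bits = line.split("\t")
        records[gene] = np.array([b == "1" for b in bits], dtype=bool)

    alignment = pd.DataFrame.from_dict(records, orient="index")
    print(alignment)              # boolean matrix: genes x alignment columns
    print(alignment.sum(axis=1))  # aligned positions per gene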
diff --git a/stag/classify.py b/stag/classify.py
index 87b18d9..6662954 100644
--- a/stag/classify.py
+++ b/stag/classify.py
@@ -1,75 +1,54 @@
-"""
-Scripts that find the taxonomy of an aligned sequence
-"""
-
-# Author: Alessio Milanese
-
-# Requirements:
-# - numpy
-# - h5py
-
-import numpy as np
 import sys
 import time
 import os
-import h5py
 import tempfile
 import shutil
 import contextlib

-from stag.taxonomy import Taxonomy
-from stag.databases import load_db
+import numpy as np
+import h5py
+
 from . import __version__ as tool_version
+from stag.taxonomy3 import Taxonomy
+from stag.databases import load_db
 import stag.align as align

-#===============================================================================
-# FUNCTION TO LOAD ALIGNED SEQUENCES
-#===============================================================================
 def alignment_reader(aligned_sequences):
     with open(aligned_sequences,"r") as align_in:
         for ali_line in align_in:
             gene_id, *aligned_seq = ali_line.rstrip().split("\t")
             yield gene_id, np.array(list(map(int, aligned_seq)), dtype=bool)

-#===============================================================================
-# TAXONOMICALLY ANNOTATE SEQUENCES
-#===============================================================================
 def run_logistic_prediction(seq, coeff_raw):
     # the first value of the coeff is the intercept
-    coeff = coeff_raw[1:]
-    intercept = coeff_raw[0]
-    # calculate
+    intercept, *coeff = coeff_raw
     sm = coeff * seq
     np_sum = (sm).sum() + intercept
-    score = 1 / (1 + np.exp(-np_sum))
-    return score
+    return 1 / (1 + np.exp(-np_sum))

 # given many taxa (all siblings) and a sequence, it finds taxa with the highest
 # score. Returns the taxa name and the score
 def find_best_score(test_seq, siblings, classifiers):
-    best_score = -1
-    best_taxa = ""
-    # check that siblings is not empty:
-    if len(siblings) < 1:
-        sys.stderr.write("Error. no siblings")
-    # if there is only one sibiling:
-    if len(siblings) == 1:
-        best_score = 2 # if there are no siblings I put 2, it will be replaced after
-        best_taxa = siblings[0]
-    if len(siblings) > 1:
-        for s in siblings:
-            this_score = run_logistic_prediction(test_seq, classifiers[s])
+    best_score, best_taxa = -1, ""
+    if not siblings:
+        pass
+    elif len(siblings) == 1:
+        # if there are no siblings I put 2, it will be replaced after
+        best_score, best_taxa = 2, siblings[0]
+    else:
+        for sibling in siblings:
+            this_score = run_logistic_prediction(test_seq, classifiers[sibling])
             if this_score > best_score:
-                best_score = this_score
-                best_taxa = s
+                best_score, best_taxa = this_score, sibling
     return best_taxa, best_score

 def predict_iter(test_seq, taxonomy, classifiers, tax, perc, arrived_so_far):
     # last iterative step
     if taxonomy.get(arrived_so_far) is not None:
         t, p = find_best_score(test_seq, taxonomy[arrived_so_far], classifiers)
-        tax.append(t)
-        perc.append(p)
+        if t:
+            tax.append(t)
+            perc.append(p)
         predict_iter(test_seq, taxonomy, classifiers, tax, perc, t)
@@ -114,7 +93,7 @@ def find_n_aligned_characters(test_seq):

 def classify_seq(gene_id, test_seq, taxonomy, tax_function, classifiers, threads, verbose):
     # test_seq is a boolean numpy array corresponding to the encoded aligned sequence
-
+    #print("TAX", taxonomy)
     # number of characters that map to the internal states of the HMM
     n_aligned_characters = find_n_aligned_characters(test_seq)
@@ -153,7 +132,6 @@ def classify(database, fasta_input=None, protein_fasta_input=None, verbose=3, th
              long_out=False, current_tool_version=tool_version, aligned_sequences=None,
              save_ali_to_file=None, min_perc_state=0, internal_call=False):
     t0 = time.time()
-    # load the database
     db = load_db(database, protein_fasta_input=protein_fasta_input, aligned_sequences=aligned_sequences)
     hmm_file_path, use_cmalign, taxonomy, tax_function, classifiers, db_tool_version = db
     if verbose > 2:
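run_logistic_prediction relies on the flattened layout used throughout the databases: element 0 is the intercept, the remaining elements are per-column weights, so the score is a plain logistic sigmoid of the dot product. A worked toy example (all numbers invented):

    import numpy as np

    def score(seq, coeff_raw):
        # assumed layout from the patch: [intercept, w_0, ..., w_n]
        intercept, coeff = coeff_raw[0], np.asarray(coeff_raw[1:])
        return 1.0 / (1.0 + np.exp(-(coeff @ seq + intercept)))

    seq = np.array([1, 0, 1, 1], dtype=bool)           # encoded aligned sequence
    coeff_raw = np.array([-0.5, 1.2, -0.3, 0.8, 0.1])  # toy flattened classifier
    print(round(score(seq, coeff_raw), 3))  # sigmoid(2.1 - 0.5) = sigmoid(1.6) ~ 0.832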
diff --git a/stag/create_db.py b/stag/create_db.py
index 90f9c3f..3a61877 100644
--- a/stag/create_db.py
+++ b/stag/create_db.py
@@ -1,9 +1,3 @@
-"""
-Scripts that creates the database of classifiers
-"""
-
-# Author: Alessio Milanese
-
 # Input:
 #  - one multiple sequence alignment (MSA) per marker gene. The MSA is obtained
 #    from the function stag align, like:
@@ -15,65 +9,42 @@
 # Output:
 #  - a database file (hdf5) that can be used by stag classify

-import numpy as np
 import sys
 import random
-import pandas as pd
 import logging
 import os
-from sklearn.linear_model import LogisticRegression
-import h5py
+import time
 import tempfile
 import shutil
+from collections import Counter

-from stag.taxonomy import Taxonomy
-
-# Function to identify the rownames and number of columns in an alignment
-def find_raw_names_ncol(file_name):
-    gene_names = list()
-    with open(file_name) as f:
-        for line in f:
-            gene_names.append(line[:line.find("\t")].replace("/", "-"))
-    return gene_names, line.count("\t")
+import numpy as np
+import pandas as pd
+from sklearn.linear_model import LogisticRegression
+import h5py

-# function to load an alignment produced by the "align" option =================
-# Input:
-#  - a file created by "align"
-# Output:
-#  - a panda object
-# as a note, numpy.loadtxt is way slower than pandas read.csv
-# It works also on .gz files
-def load_alignment_from_file(file_name):
-    gene_names, align_length = find_raw_names_ncol(file_name)
-    alignment = pd.DataFrame(False, index=gene_names, columns=range(align_length))
-    with open(file_name) as f:
-        for pos, line in enumerate(f):
-            alignment.iloc[pos] = np.array([c == "1" for c in line.split("\t")[1:]])
-
-    logging.info(f' LOAD_AL: Number of genes: {len(list(alignment.index.values))}')
-
-    # we remove duplicates
-    alignment = alignment.drop_duplicates()
-    logging.info(f' LOAD_AL: Number of genes, after removing duplicates: {len(list(alignment.index.values))}')
-    return alignment
+from stag.taxonomy3 import Taxonomy
+from stag.databases import save_to_file
+from stag.alignment import load_alignment_from_file

-#===============================================================================
-# FUNCTIONS TO TRAIN THE CLASSIFIERS
-#===============================================================================
 # function that finds positive and negative examples ===========================
 def find_training_genes(node, siblings, full_taxonomy, alignment):
+    t00 = time.time()
     # "positive_examples" and "negative_examples" are list of gene ids
     positive_examples = full_taxonomy.find_gene_ids(node)
+    t_pos = time.time() - t00
+    t0 = time.time()
     negative_examples = list()
     for s in siblings:
         negative_examples.extend(full_taxonomy.find_gene_ids(s))
+    t_neg = time.time() - t0
     if not negative_examples:
         # it means that there was only one child, and we cannot do anything
         return positive_examples, negative_examples

-    # From here, it means that there is at least one sibiling ==================
+    # From here, it means that there is at least one sibling ==================
     # We make classes more balanced
     positive_examples_subsample = list(positive_examples)
     negative_examples_subsample = list(negative_examples)
@@ -84,7 +55,7 @@ def find_training_genes(node, siblings, full_taxonomy, alignment):
     if len(negative_examples_subsample) > 1000:
         negative_examples_subsample = random.sample(negative_examples_subsample, 1000)
     # 3. max 20 times more negative than positive ------------------------------
-    # but if there is only one other sibiling, we choose only 3 times more negative
+    # but if there is only one other sibling, we choose only 3 times more negative
     max_negative_samples = len(positive_examples_subsample) * (20 if len(siblings) > 1 else 3)
     if len(negative_examples_subsample) > max_negative_samples:
         negative_examples_subsample = random.sample(negative_examples_subsample, max_negative_samples)
@@ -103,67 +74,38 @@ def find_training_genes(node, siblings, full_taxonomy, alignment):
         rr = random.choice(range(n_positive_class))
         X_clade = np.vstack((X_clade, X_clade[rr,]))

-    # find possible genes to add additionaly to negarives
+    # find possible genes to add additionally to negatives
     possible_neg = list(set(alignment.index.values).difference(set(positive_examples + negative_examples)))
     if possible_neg:
         # if it is possible to add negatives
         # note that at the highest level, it's not possible
         X_poss_na = alignment.loc[possible_neg, : ].to_numpy()
-        len_poss_na = len(X_poss_na)
         # choose 5 random positive clades
-        X_check_sim = X_clade[random.sample(range(len(X_clade)),5),]
-
-        m_for_diff_0 = np.tile(X_check_sim[0,],(len_poss_na,1))
-        m_for_diff_1 = np.tile(X_check_sim[1,],(len_poss_na,1))
-        m_for_diff_2 = np.tile(X_check_sim[2,],(len_poss_na,1))
-        m_for_diff_3 = np.tile(X_check_sim[3,],(len_poss_na,1))
-        m_for_diff_4 = np.tile(X_check_sim[4,],(len_poss_na,1))
-
-        differences_0 = np.sum(np.bitwise_xor(m_for_diff_0, X_poss_na), axis=1)
-        differences_1 = np.sum(np.bitwise_xor(m_for_diff_1, X_poss_na), axis=1)
-        differences_2 = np.sum(np.bitwise_xor(m_for_diff_2, X_poss_na), axis=1)
-        differences_3 = np.sum(np.bitwise_xor(m_for_diff_3, X_poss_na), axis=1)
-        differences_4 = np.sum(np.bitwise_xor(m_for_diff_4, X_poss_na), axis=1)
-
-        non_zero_0 = np.sum(differences_0 != 0)
-        differences_0 = np.where(differences_0 == 0, np.nan, differences_0)
-        corr_ord_0 = np.argsort(differences_0)[0:non_zero_0+1]
-
-        non_zero_1 = np.sum(differences_1 != 0)
-        differences_1 = np.where(differences_1 == 0, np.nan, differences_1)
-        corr_ord_1 = np.argsort(differences_1)[0:non_zero_1+1]
-
-        non_zero_2 = np.sum(differences_2 != 0)
-        differences_2 = np.where(differences_2 == 0, np.nan, differences_2)
-        corr_ord_2 = np.argsort(differences_2)[0:non_zero_2+1]
-
-        non_zero_3 = np.sum(differences_3 != 0)
-        differences_3 = np.where(differences_3 == 0, np.nan, differences_3)
-        corr_ord_3 = np.argsort(differences_3)[0:non_zero_3+1]
-
-        non_zero_4 = np.sum(differences_4 != 0)
-        differences_4 = np.where(differences_4 == 0, np.nan, differences_4)
-        corr_ord_4 = np.argsort(differences_4)[0:non_zero_4+1]
-
-        to_add = list()
-        for (a,b,c,d,e) in zip(list(corr_ord_0),list(corr_ord_1),list(corr_ord_2),list(corr_ord_3),list(corr_ord_4)):
-            if not(a in to_add): to_add.append(a)
-            if not(b in to_add): to_add.append(b)
-            if not(c in to_add): to_add.append(c)
-            if not(d in to_add): to_add.append(d)
-            if not(e in to_add): to_add.append(e)
-            if len(to_add) > missing_neg:
-                break # stop if we have enough similar genes
-
-        # add list_genomes_to_add to the X_na
-        for i in to_add:
-            negative_examples_subsample.append(possible_neg[i])
+        X_check_sim = X_clade[random.sample(range(len(X_clade)), 5), ]
+
+        random_clades = list()
+        for clade_i in range(5):
+            m_for_diff = np.tile(X_check_sim[clade_i,], (len(X_poss_na), 1))
+            differences = np.sum(np.bitwise_xor(m_for_diff, X_poss_na), axis=1)
+            non_zero = np.sum(differences != 0)
+            differences = np.where(differences == 0, np.nan, differences)
+            corr_ord = np.argsort(differences)[:non_zero + 1]
+            random_clades.append(list(corr_ord))
+        clade_indices = set()
+        for indices in zip(*random_clades):
+            clade_indices.update(indices)
+            if len(clade_indices) > missing_neg:
+                break
+        negative_examples_subsample.extend(possible_neg[i] for i in clade_indices)
+
+    t_total = time.time() - t00
+    logging.info(f"find_training_genes\t{node}\t{len(positive_examples)}\t{len(negative_examples)}\t{t_pos:.3f}\t{t_neg:.3f}\t{t_total:.3f}\t{os.getpid()}")

     return positive_examples_subsample, negative_examples_subsample

 def get_classification_input(taxonomy, alignment):
-    for node, siblings in taxonomy.get_all_nodes(mode="bfs"):
+    for node, siblings in taxonomy.get_all_nodes(get_root=True):
         logging.info(f' TRAIN:"{node}":Find genes')
         positive_examples, negative_examples = find_training_genes(node, siblings, taxonomy, alignment)
         logging.info(f' SEL_GENES:"{node}": {len(positive_examples)} positive, {len(negative_examples)} negative')
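The loop above collapses five copy-pasted blocks into one: each of five randomly chosen positive rows ranks the leftover genes by Hamming distance (XOR over the boolean rows, with zero distances masked out), and the rankings are merged round-robin until enough hard negatives are collected. A compact sketch of that idea on toy data:

    import numpy as np

    rng = np.random.default_rng(0)
    X_clade_sample = rng.integers(0, 2, size=(5, 12), dtype=np.int8)  # 5 positives
    X_candidates = rng.integers(0, 2, size=(8, 12), dtype=np.int8)    # possible negatives

    rankings = []
    for row in X_clade_sample:
        dist = np.bitwise_xor(row, X_candidates).sum(axis=1).astype(float)
        dist[dist == 0] = np.nan           # identical rows are never picked
        rankings.append(np.argsort(dist))  # nearest (hardest) negatives first

    missing_neg = 4
    picked = set()
    for indices in zip(*rankings):         # round-robin across the five rankings
        picked.update(indices)
        if len(picked) > missing_neg:
            break
    print(sorted(picked))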
@@ -171,18 +113,19 @@ def get_classification_input(taxonomy, alignment):
         # check that we have at least 1 example for each class:
         if not negative_examples:
             # when the node is the only child, then there are no negative examples
-            logging.info(' Warning: no negative examples for "%s', node)
-            yield node, "no_negative_examples", None
+            logging.info(' Warning: no negative examples for "%s"', node)
+            X, y = "no_negative_examples", None
         elif not positive_examples:
             # There should be positive examples
-            logging.info(' Error: no positive examples for "%s', node)
-            yield node, "ERROR_no_positive_examples", None
+            logging.info(' Error: no positive examples for "%s"', node)
+            X, y = "ERROR_no_positive_examples", None
         else:
             X = alignment.loc[ negative_examples + positive_examples , : ].to_numpy()
             y = np.asarray(["no"] * len(negative_examples) + ["yes"] * len(positive_examples))
-            yield node, X, y
+        yield node, X, y

 def train_all_classifiers_nonmp(alignment, full_taxonomy, penalty_v, solver_v, procs=None):
+    print("train_all_classifiers_nonmp - single-proc")
     all_classifiers = dict()
     for node, X, y in get_classification_input(full_taxonomy, alignment):
         if y is not None:
@@ -196,31 +139,94 @@ def perform_training(X, y, penalty_v, solver_v, node):
     if y is None:
         return node, X
+    # logging.info(' TRAIN:"%s":Train classifier', node)
     clf = LogisticRegression(random_state=0, penalty=penalty_v, solver=solver_v)
     clf.fit(X, y)
     return node, clf

+def get_classification_input_mp(node, siblings, taxonomy, alignment, penalty_v, solver_v):
+    logging.info(f' TRAIN:"{node}":Find genes (proc={os.getpid()})')
+    positive_examples, negative_examples = find_training_genes(node, siblings, taxonomy, alignment)
+    logging.info(f' SEL_GENES:"{node}": {len(positive_examples)} positive, {len(negative_examples)} negative')
+
+    # check that we have at least 1 example for each class:
+    if not negative_examples:
+        # when the node is the only child, then there are no negative examples
+        logging.info(' Warning: no negative examples for "%s"', node)
+        X, y = "no_negative_examples", None
+    elif not positive_examples:
+        # There should be positive examples
+        logging.info(' Error: no positive examples for "%s"', node)
+        X, y = "ERROR_no_positive_examples", None
+    else:
+        X = alignment.loc[ negative_examples + positive_examples , : ].to_numpy()
+        y = np.asarray(["no"] * len(negative_examples) + ["yes"] * len(positive_examples))
+
+    return perform_training(X, y, penalty_v, solver_v, node)
+
+def get_classification_input_mp2(nodes, taxonomy, alignment, penalty_v, solver_v):
+    results = list()
+    for node, siblings in nodes:
+        #logging.info(f' TRAIN:"{node}":Find genes')
+        t00 = time.time()
+        positive_examples, negative_examples = find_training_genes(node, siblings, taxonomy, alignment)
+        t_select = time.time() - t00
+        #logging.info(f' SEL_GENES:"{node}": {len(positive_examples)} positive, {len(negative_examples)} negative')
+
+        # check that we have at least 1 example for each class:
+        if not negative_examples:
+            # when the node is the only child, then there are no negative examples
+            logging.info(' Warning: no negative examples for "%s"', node)
+            X, y = "no_negative_examples", None
+        elif not positive_examples:
+            # There should be positive examples
+            logging.info(' Error: no positive examples for "%s"', node)
+            X, y = "ERROR_no_positive_examples", None
+        else:
+            X = alignment.loc[ negative_examples + positive_examples , : ].to_numpy()
+            y = np.asarray(["no"] * len(negative_examples) + ["yes"] * len(positive_examples))
+        t0 = time.time()
+        results.append(perform_training(X, y, penalty_v, solver_v, node))
+        t1 = time.time()
+        t_train, t_total = t1 - t0, t1 - t00
+
+        # logging.info(f' "{node}": {len(positive_examples)} positive, {len(negative_examples)} negative\tselection: {t_select:.3f}s, training: {t_train:.3f}s\tpid={os.getpid()}')
+        logging.info("\t".join(map(str, [node, len(positive_examples), len(negative_examples), f"{t_select:.3f}s", f"{t_train:.3f}s", f"{t_total:.3f}s", os.getpid()])))
+    return results
+
 def train_all_classifiers_mp(alignment, full_taxonomy, penalty_v, solver_v, procs=2):
     import multiprocessing as mp
+    print(f"train_all_classifiers_mp with {procs} processes.")
+    logging.info("\t".join([" node", "positive", "negative", "t_select", "t_train", "t_total", "pid"]))
     with mp.Pool(processes=procs) as pool:
-        results = (
-            pool.apply_async(perform_training, args=(X, y, penalty_v, solver_v, node))
-            for node, X, y in get_classification_input(full_taxonomy, alignment)
-        )
-
-        return dict(p.get() for p in results)
+        nodes = list(full_taxonomy.get_all_nodes(get_root=False))
+        step = max(1, len(nodes) // procs)
+        results = [
+            pool.apply_async(get_classification_input_mp2, args=(nodes[i:i+step], full_taxonomy, alignment, penalty_v, solver_v))
+            for i in range(0, len(nodes), step)
+        ]
+
+        res_d = dict()
+        for res in results:
+            res_d.update(res.get())
+        return res_d
+
+        #results = [
+        #    pool.apply_async(get_classification_input_mp, args=(node, siblings, full_taxonomy, alignment, penalty_v, solver_v))
+        #    for node, siblings in full_taxonomy.get_all_nodes(get_root=True)
+        #]
+
+        #results = [
+        #    pool.apply_async(perform_training, args=(X, y, penalty_v, solver_v, node,))
+        #    for node, X, y in list(get_classification_input(full_taxonomy, alignment))
+        #]
+
+        # return dict(p.get() for p in results)

 def train_all_classifiers(*args, procs=None):
-    train_f = train_all_classifiers_mp if procs else train_all_classifiers_nonmp
+    train_f = train_all_classifiers_mp if (procs and procs > 1) else train_all_classifiers_nonmp
     return train_f(*args, procs=procs)

-    results = (
-        pool.apply_async(train_classifier, args=(X, y, penalty_v, solver_v, node,))
-        for node, siblings, X, y in get_training_genes(taxonomy, alignment)
-    )
-
-    return dict(p.get() for p in results)
-
 #===============================================================================
 # FUNCTIONS TO LEARN THE FUNCTION FOR THE TAX. LEVEL
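train_all_classifiers_mp now sends each worker a whole slice of nodes instead of one apply_async per node, so the per-task pickling of the shared alignment is paid once per chunk. A minimal model of that chunking (the worker body is a stand-in for get_classification_input_mp2):

    import multiprocessing as mp

    def train_chunk(nodes):
        # stand-in worker: one (node, classifier) pair per node
        return [(node, f"clf_{node}") for node in nodes]

    def train_all(nodes, procs=2):
        step = max(1, len(nodes) // procs)  # guard against empty chunks
        chunks = [nodes[i:i + step] for i in range(0, len(nodes), step)]
        with mp.Pool(processes=procs) as pool:
            results = [pool.apply_async(train_chunk, args=(chunk,)) for chunk in chunks]
            merged = {}
            for res in results:
                merged.update(res.get())
            return merged

    if __name__ == "__main__":
        print(train_all([f"node{i}" for i in range(7)], procs=2))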
@@ -254,7 +260,7 @@ def predict_one_gene(test_seq, training_tax, classifiers_train):
     perc = list()
     # we arrived at the root, and now we classify from there
     predict_iter(test_seq, training_tax, classifiers_train, tax, perc, training_tax.get_root())
-    # we change the predictions that came from having only one sibiling --------
+    # we change the predictions that came from having only one sibling --------
     if perc[0] == 2:
         perc[0] = 1
     for i in range(len(perc)):
@@ -269,44 +275,42 @@ def predict(test_al, training_tax, classifiers_train):
             for gene in test_al.index.values
     ]

-def learn_function_one_level(level_to_learn, alignment, full_taxonomy, penalty_v, solver_v, procs=None):
-    logging.info(' TEST:"%s" taxonomic level', str(level_to_learn))
-    # 1. Identify which clades we want to remove (test set) and which to keep
-    #    (training set)
-    this_level_clades = full_taxonomy.find_node_level(level_to_learn)
-    # we use 33% of the clades for testing
-    perc_test_set = 0.33 # this cannot be higher than 0.5.
-    test_set = set()
-    training_set = set()
-    for c in this_level_clades:
+def learn_function(level_to_learn, alignment, full_taxonomy, penalty_v, solver_v, perc_test_set=0.33, gene_level=False, procs=None):
+    # perc_test_set <= 0.5 !
+    logging.info(' TEST:"%s" taxonomic level' % level_to_learn)
+    # 1. Identify which clades we want to remove (test set) and which to keep (training set)
+    test_set, training_set = set(), set()
+    clades = full_taxonomy.get_last_level_to_genes() if gene_level else full_taxonomy.find_node_level(level_to_learn)
+    for node, children in clades.items():
+        aval_clades = set(children)
         # find how many to use for the test set:
-        aval_clades = set(this_level_clades[c])
-        n_test = round(len(aval_clades) * perc_test_set)
-        if len(aval_clades) == 2:
-            n_test = 0
-        # add to test set
-        for i in range(n_test):
-            test_set.add(aval_clades.pop())
-        # the remaining clades in aval_clades go to the trainig set
+        n_test = 0 if (not gene_level and len(aval_clades) == 2) else round(len(aval_clades) * perc_test_set)
+        test_set.update(aval_clades.pop() for _ in range(n_test))
         training_set.update(aval_clades)
-    logging.info(' TEST:"%s" level:test_set (%s):%s', str(level_to_learn),str(len(test_set)),str(test_set))
-    logging.info(' TEST:"%s" level:trai_set (%s):%s', str(level_to_learn),str(len(training_set)),str(training_set))
+    logging.info(f' TEST:"{level_to_learn}" level:test_set ({len(test_set)}):{test_set}')
+    logging.info(f' TEST:"{level_to_learn}" level:train_set ({len(training_set)}):{training_set}')

     # 2. Create new taxonomy and alignment file & train the classifiers
     training_tax = full_taxonomy.copy()
-    removed_genes = training_tax.remove_clades(list(test_set))
-    training_al = alignment.loc[ training_tax.find_gene_ids(training_tax.get_root()) , : ]
+    if gene_level:
+        training_tax.remove_genes(list(test_set))
+        training_filter = training_set
+        test_filter = test_set
+    else:
+        test_filter = training_tax.remove_clades(list(test_set))
+        training_filter = training_tax.find_gene_ids(training_tax.get_root())
+
+    training_al = alignment.loc[ training_filter, : ]
     classifiers_train = train_all_classifiers(training_al, training_tax, penalty_v, solver_v, procs=procs)

     # 3. Classify the test set
-    test_al = alignment.loc[ removed_genes , : ]
+    test_al = alignment.loc[ test_filter , : ]
     pr = predict(test_al, training_tax, classifiers_train)
     for g in pr:
         # g is:
         # ["geneB",["A","B","D","species8"],[0.99,0.96,0.96,0.07]]
         correct_tax = full_taxonomy.extract_full_tax_from_gene(g[0])
-        g.append(correct_tax)
-        g.append(level_to_learn)
+        g.extend([correct_tax, level_to_learn])

     return pr

 # return:
 # # GENE_ID        PREDICTED       PROB_PREDICTED  CORRECT REMOVED_LEVEL
 # [["geneA",["A","B","C","species2"],[0.99,0.96,0.96,0.07],["A","B","C","species2"],0],
 #  ["geneB",["A","B","D","species8"],[0.99,0.96,0.10,0.07],["A","B","U","speciesZ"],2]
 # ..... ]

@@ -315,45 +319,6 @@
-def learn_function_genes_level(level_to_learn, alignment, full_taxonomy, penalty_v, solver_v, procs=None):
-    logging.info(' TEST:"%s" taxonomic level', str(level_to_learn))
-    # 1. Identify which clades we want to remove (test set) and which to keep
-    #    (training set)
-    this_level_clades = full_taxonomy.get_last_level_to_genes() # now there are genes
-    # we use 33% of the genes for testing
-    perc_test_set = 0.33 # this cannot be higher than 0.5.
-    test_set = set()
-    training_set = set()
-    for c in this_level_clades:
-        # find how many to use for the test set:
-        aval_clades = set(this_level_clades[c])
-        n_test = round(len(aval_clades) * perc_test_set)
-        # add to test set
-        for i in range(n_test):
-            test_set.add(aval_clades.pop())
-        # the remaining clades in aval_clades go to the trainig set
-        training_set.update(aval_clades)
-    logging.info(' TEST:"%s" level:test_set (%s):%s', str(level_to_learn),str(len(test_set)),str(test_set))
-    logging.info(' TEST:"%s" level:trai_set (%s):%s', str(level_to_learn),str(len(training_set)),str(training_set))
-
-    # 2. Create new taxonomy and alignment file & train the classifiers
-    training_tax = full_taxonomy.copy()
-    training_tax.remove_genes(list(test_set))
-    training_al = alignment.loc[ training_set , : ]
-    classifiers_train = train_all_classifiers(training_al, training_tax, penalty_v, solver_v, procs=procs)
-
-    # 3. Classify the test set
-    test_al = alignment.loc[ test_set , : ]
-    pr = predict(test_al, training_tax, classifiers_train)
-    for g in pr:
-        # g is:
-        # ["geneB",["A","B","D","species8"],[0.99,0.96,0.96,0.07]]
-        correct_tax = full_taxonomy.extract_full_tax_from_gene(g[0])
-        g.append(correct_tax)
-        g.append(level_to_learn)
-
-    return pr
-
 def estimate_function(all_calc_functions):
     # The all_calc_functions looks like:
     # GENE_ID PREDICTED PROB_PREDICTED CORRECT REMOVED_LEVEL
@@ -367,74 +332,43 @@ def estimate_function(all_calc_functions):
     # refers to the fact that we removed the genes

     # we remove duplicates with the same predicted probability -----------------
-    all_uniq = dict()
-    for line in all_calc_functions:
-        v = ""
-        for j in line[2]:
-            v = v+str(j)
-        all_uniq[v] = line
-    logging.info(' LEARN_FUNCTION:Number of lines: %s (before removing duplicates: %s)',
-                 str(len(all_uniq)),str(len(all_calc_functions)))
-    # change all_calc_functions
-    all_calc_functions = list()
-    for j in all_uniq:
-        all_calc_functions.append(all_uniq[j])
+    all_uniq = {tuple(round(v, 2) for v in item[2]): item for item in all_calc_functions}
+    logging.info(f' LEARN_FUNCTION:Number of lines: {len(all_uniq)}/{len(all_calc_functions)}')

-    # we find what is the correct value for the prediction level ---------------
     correct_level = list()
-    for line in all_calc_functions:
+    for _, predicted, _, ground_truth, _ in all_uniq.values():
         corr_level_this = -1
-        cont = 0
-        for p,c in zip(line[1],line[3]):
-            cont = cont + 1
+        for cont, (p, c) in enumerate(zip(predicted, ground_truth)):
             if p == c:
-                corr_level_this = cont-1 # we select to what level to predict
+                corr_level_this = cont # we select to what level to predict
         correct_level.append(corr_level_this)
+
     # now in correct_level there is to which level to predict to. Example:
     # "A","B","C","species2"
     # with corr_level_this = 0, we should assign "A"
     # with corr_level_this = 2, we should assign "A","B","C"
     # with corr_level_this = -1, we should assign "" (no taxonomy)
+
+    level_counter = Counter(correct_level)
+    for level, count in sorted(level_counter.items()):
+        logging.info(f' LEARN_FUNCTION:Number of lines: level {level}: {count}')

-    # check how many lines there are per correct level -------------------------
-    for l in set(correct_level):
-        cont = 0
-        for j in correct_level:
-            if j == l:
-                cont = cont + 1
-        logging.info(' LEARN_FUNCTION:Number of lines: level %s: %s',
-                     str(l),str(cont))
-
-
-    # we train the classifiers -------------------------------------------------
     all_classifiers = dict()
-    for l in set(correct_level):
-        # we create the feature matrix
+    for uniq_level in sorted(level_counter):
         # NOTE: we always need the negative class to be first
-        correct_order_lines = list()
-        correct_order_labels = list()
-        cont = 0
-        for i in range(len(all_calc_functions)):
-            if correct_level[i] != l:
-                correct_order_lines.append(all_calc_functions[i][2])
-                correct_order_labels.append(0)
-                cont = cont + 1
-        cont = 0
-        for i in range(len(all_calc_functions)):
-            if correct_level[i] == l:
-                correct_order_lines.append(all_calc_functions[i][2])
-                correct_order_labels.append(1)
-                cont = cont + 1
-
-        X = np.array([np.array(xi) for xi in correct_order_lines])
-        y = np.asarray(correct_order_labels)
-        # train classifier
-        clf = LogisticRegression(random_state=0, penalty = "none", solver='saga',max_iter = 5000)
+        correct_order = [[], []]
+        for level, (_, _, prob, *_) in zip(correct_level, all_uniq.values()):
+            correct_order[int(uniq_level == level)].append(prob)
+
+        X = np.array([np.array(xi) for xi in correct_order[0] + correct_order[1]])
+        y = np.asarray([0] * len(correct_order[0]) + [1] * len(correct_order[1]))
+        clf = LogisticRegression(random_state=0, penalty="none", solver='saga', max_iter=5000)
         clf.fit(X, y)
-        all_classifiers[str(l)] = clf
+        all_classifiers[str(uniq_level)] = clf

     return all_classifiers

 # create taxonomy selection function ===========================================
 # This function define a function that is able to identify to which taxonomic
 # level a new gene should be assigned to.
@@ -444,91 +378,37 @@ def learn_taxonomy_selection_function(alignment, full_taxonomy, save_cross_val_d

     # do the cross validation for each level
     all_calc_functions = list()
-    for i in range(n_levels):
-        all_calc_functions = all_calc_functions + learn_function_one_level(i, alignment, full_taxonomy, penalty_v, solver_v, procs=procs)
+    for level in range(n_levels):
+        all_calc_functions.extend(learn_function(level, alignment, full_taxonomy, penalty_v, solver_v, procs=procs))
+        #all_calc_functions.extend(learn_function_one_level(level, alignment, full_taxonomy, penalty_v, solver_v, procs=procs))
     # do the cross val. for the last level (using the genes)
-    all_calc_functions = all_calc_functions + learn_function_genes_level(n_levels, alignment, full_taxonomy, penalty_v, solver_v, procs=procs)
+    # all_calc_functions.extend(learn_function_genes_level(n_levels, alignment, full_taxonomy, penalty_v, solver_v, procs=procs))
+    all_calc_functions.extend(learn_function(n_levels, alignment, full_taxonomy, penalty_v, solver_v, gene_level=True, procs=procs))

     # save all_calc_functions if necessary -------------------------------------
-    if not (save_cross_val_data is None):
+    if save_cross_val_data:
         outfile = tempfile.NamedTemporaryFile(delete=False, mode="w")
-        outfile.write("gene\tpredicted\tprob\tground_truth\tremoved_level\n")
-        os.chmod(outfile.name, 0o644)
-        for vals in all_calc_functions:
-            to_print = vals[0] + "\t" + "/".join(vals[1]) + "\t" # "geneB",["D","E","F","species8"]
-            to_print = to_print + "/".join(str(x) for x in vals[2]) + "\t" # [0.99,0.96,0.95,0.07]
-            to_print = to_print + "/".join(vals[3]) + "\t" # ["G","H","I","species9"]
-            to_print = to_print + str(vals[4]) # removed level
-            outfile.write(to_print+"\n")
-        # save
+        with outfile:
+            os.chmod(outfile.name, 0o644)
+            print("gene", "predicted", "prob", "ground_truth", "removed_level", sep="\t", file=outfile)
+            for gene, predicted, prob, ground_truth, removed_level in all_calc_functions:
+                predicted, prob, ground_truth = ("/".join(s) for s in (predicted, ["{:.2f}".format(pr) for pr in prob], ground_truth))
+                print(gene, predicted, prob, ground_truth, removed_level, sep="\t", file=outfile)
+            try:
+                outfile.flush()
+                os.fsync(outfile.fileno())
+            except:
+                print("[E::main] Error: failed to save the cross validation results", file=sys.stderr)
        try:
-            outfile.flush()
-            os.fsync(outfile.fileno())
-            outfile.close()
+            shutil.move(outfile.name, save_cross_val_data)
        except:
-            sys.stderr.write("[E::main] Error: failed to save the cross validation results\n")
-        try:
-            #os.rename(outfile.name,output) # atomic operation
-            shutil.move(outfile.name,save_cross_val_data) #It is not atomic if the files are on different filsystems.
-        except:
-            sys.stderr.write("[E::main] Error: failed to save the cross validation results\n")
-            sys.stderr.write("[E::main] you can find the file here:\n"+outfile.name+"\n")
-            sys.exit(1)
-
-    # estimate the function ----------------------------------------------------
-    f = estimate_function(all_calc_functions)
-    return f
+            print("[E::main] Error: failed to save the cross validation results\n" + \
+                  f"[E::main] you can find the file here: \n{outfile.name}\n", file=sys.stderr)

+    return estimate_function(all_calc_functions)

-#===============================================================================
-# FUNCTIONS TO SAVE TO A DATABASE
-#===============================================================================
-def save_to_file(classifiers, full_taxonomy, tax_function, use_cmalign, tool_version, output, hmm_file_path=None, protein_fasta_input=None):
-
-    string_dt = h5py.special_dtype(vlen=str)
-
-    with h5py.File(output, "w") as h5p_out:
-        # zero: tool version --------------------------------------------------
-        h5p_out.create_dataset('tool_version', data=np.array([str(tool_version)], "S100"), dtype=string_dt)
-        # and type of database
-        h5p_out.create_dataset('db_type', data=np.array(["single_gene"], "S100"), dtype=string_dt)
-        # was the alignment done at the protein level?
-        h5p_out.create_dataset('align_protein', data=np.array([bool(protein_fasta_input)]), dtype=bool)
-        # first we save the hmm file -------------------------------------------
-        hmm_string = "".join(line for line in open(hmm_file_path)) if hmm_file_path else "NA"
-        h5p_out.create_dataset('hmm_file', data=np.array([hmm_string], "S" + str(len(hmm_string) + 100)), dtype=string_dt, compression="gzip")
-        # second, save the use_cmalign info ------------------------------------
-        h5p_out.create_dataset('use_cmalign', data=np.array([use_cmalign]), dtype=bool)
-        # third, we save the taxonomy -------------------------------------------
-        h5p_out.create_group("taxonomy")
-        for node in full_taxonomy.child_nodes:
-            h5p_out.create_dataset("taxonomy/" + node, data=np.array(list(full_taxonomy.child_nodes[node]), "S10000"), dtype=string_dt, compression="gzip")
-        # fourth, the taxonomy function ------------------------------------------
-        h5p_out.create_group("tax_function")
-        for c in tax_function:
-            # we append the intercept at the head (will have position 0)
-            vals = np.append(tax_function[c].intercept_, tax_function[c].coef_)
-            h5p_out.create_dataset("tax_function/" + str(c), data=vals, dtype=np.float64, compression="gzip")
-        # fifth, save the classifiers ---------------------------------------------
-        h5p_out.create_group("classifiers")
-        for c in classifiers:
-            if classifiers[c] != "no_negative_examples":
-                vals = np.append(classifiers[c].intercept_, classifiers[c].coef_)
-                h5p_out.create_dataset("classifiers/" + c, data=vals, dtype=np.float64, compression="gzip", compression_opts=8)
-            else:
-                # in this case, it always predict 1, we save it as an array of
-                # with the string "no_negative_examples"
-                h5p_out.create_dataset("classifiers/" + c, data=np.array(["no_negative_examples"], "S40"), dtype=string_dt, compression="gzip")
-
-        h5p_out.flush()
-
-#===============================================================================
-# MAIN
-#===============================================================================
-
-def create_db(aligned_seq_file, tax_file, verbose, output, use_cmalign, hmm_file_path, save_cross_val_data, tool_version, protein_fasta_input, penalty_v, solver_v, procs=None):
-    # set log file
+def create_db(aligned_seq_file, tax_file, verbose, output, use_cmalign, hmm_file_path, save_cross_val_data, protein_fasta_input, penalty_v, solver_v, procs=None):
     filename_log = os.path.realpath(output)+'.log'
     logging.basicConfig(filename=filename_log,
                         filemode='w',
@@ -564,7 +444,7 @@ def create_db(aligned_seq_file, tax_file, verbose, output, use_cmalign, hmm_file

     # 6. save the result
     logging.info('MAIN:Save to file')
-    save_to_file(classifiers, full_taxonomy, tax_function, use_cmalign, tool_version, output, hmm_file_path=hmm_file_path, protein_fasta_input=protein_fasta_input)
+    save_to_file(classifiers, full_taxonomy, tax_function, use_cmalign, output, hmm_file_path=hmm_file_path, protein_fasta_input=protein_fasta_input)
     logging.info('TIME:Finish save to file')

     logging.info('MAIN:Finished')
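The merged learn_function drives the cross-validation: at each taxonomic level roughly a third of the clades (or genes, with gene_level=True) are held out, the classifiers are retrained without them, and the held-out genes are reclassified to see how deep the predictions stay correct. A sketch of just the splitting step, under the same two-children exception (clade names invented):

    def split_clades(level_to_children, perc_test=0.33):
        test_set, training_set = set(), set()
        for children in level_to_children.values():
            pool = set(children)
            # clades with exactly two children stay whole, so every node
            # keeps at least one sibling for training
            n_test = 0 if len(pool) == 2 else round(len(pool) * perc_test)
            test_set.update(pool.pop() for _ in range(n_test))
            training_set.update(pool)
        return test_set, training_set

    clades = {"Bacteria": {"Firmicutes", "Proteobacteria", "Actinobacteria"},
              "Archaea": {"Euryarchaeota", "Crenarchaeota"}}
    test, train = split_clades(clades)
    print("test:", sorted(test), "train:", sorted(train))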
diff --git a/stag/databases.py b/stag/databases.py
index 42695c6..a9cdc9b 100644
--- a/stag/databases.py
+++ b/stag/databases.py
@@ -8,8 +8,10 @@
 import numpy as np
 import h5py

+from . import __version__ as tool_version
 import stag.align as align

+
 def load_genome_DB(database, tool_version, verbose):
     dirpath = tempfile.mkdtemp()
     shutil.unpack_archive(database, dirpath, "gztar")
@@ -104,3 +106,42 @@ def load_db(hdf5_DB_path, protein_fasta_input=None, aligned_sequences=None, dir_
                 print(key, *classifiers[key], sep="\t", file=class_out)

     return hmm_file.name, use_cmalign, taxonomy, tax_function, classifiers, db_tool_version
+
+
+def save_to_file(classifiers, full_taxonomy, tax_function, use_cmalign, output, hmm_file_path=None, protein_fasta_input=None):
+
+    string_dt = h5py.special_dtype(vlen=str)
+    with h5py.File(output, "w") as h5p_out:
+        # zero: tool version --------------------------------------------------
+        h5p_out.create_dataset('tool_version', data=np.array([str(tool_version)], "S100"), dtype=string_dt)
+        # and type of database
+        h5p_out.create_dataset('db_type', data=np.array(["single_gene"], "S100"), dtype=string_dt)
+        # was the alignment done at the protein level?
+        h5p_out.create_dataset('align_protein', data=np.array([bool(protein_fasta_input)]), dtype=bool)
+        # first we save the hmm file -------------------------------------------
+        hmm_string = "".join(line for line in open(hmm_file_path)) if hmm_file_path else "NA"
+        h5p_out.create_dataset('hmm_file', data=np.array([hmm_string], "S" + str(len(hmm_string) + 100)), dtype=string_dt, compression="gzip")
+        # second, save the use_cmalign info ------------------------------------
+        h5p_out.create_dataset('use_cmalign', data=np.array([use_cmalign]), dtype=bool)
+        # third, we save the taxonomy -------------------------------------------
+        h5p_out.create_group("taxonomy")
+        for node, _ in full_taxonomy.get_all_nodes(get_root=True):
+            h5p_out.create_dataset(f"taxonomy/{node}", data=np.array(list(full_taxonomy[node].children.keys()), "S10000"), dtype=string_dt, compression="gzip")
+        # fourth, the taxonomy function ------------------------------------------
+        h5p_out.create_group("tax_function")
+        for c in tax_function:
+            # we append the intercept at the head (will have position 0)
+            vals = np.append(tax_function[c].intercept_, tax_function[c].coef_)
+            h5p_out.create_dataset("tax_function/" + str(c), data=vals, dtype=np.float64, compression="gzip")
+        # fifth, save the classifiers ---------------------------------------------
+        h5p_out.create_group("classifiers")
+        for c in classifiers:
+            if classifiers[c] != "no_negative_examples":
+                vals = np.append(classifiers[c].intercept_, classifiers[c].coef_)
+                h5p_out.create_dataset("classifiers/" + c, data=vals, dtype=np.float64, compression="gzip", compression_opts=8)
+            else:
+                # in this case, it always predicts 1; we save it as an array
+                # with the string "no_negative_examples"
+                h5p_out.create_dataset("classifiers/" + c, data=np.array(["no_negative_examples"], "S40"), dtype=string_dt, compression="gzip")
+
+        h5p_out.flush()
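save_to_file stores every LogisticRegression as one flat float array with the intercept at position 0, which is the layout load_db and run_logistic_prediction expect on the way back. A round-trip sketch with a toy classifier (file and group names illustrative):

    import numpy as np
    import h5py

    intercept, coef = np.array([-0.5]), np.array([[1.2, -0.3, 0.8]])  # toy fit

    with h5py.File("toy_db.h5", "w") as h5p_out:
        # flatten to [intercept, w_0, ..., w_n], as save_to_file does
        h5p_out.create_dataset("classifiers/Bacteria",
                               data=np.append(intercept, coef), dtype=np.float64)

    with h5py.File("toy_db.h5") as h5p_in:
        stored = h5p_in["classifiers/Bacteria"][:]
        b, w = stored[0], stored[1:]  # split back into intercept and weights
        print(b, w)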
diff --git a/stag/stag_test.py b/stag/stag_test.py
index 65c557e..7137bd1 100644
--- a/stag/stag_test.py
+++ b/stag/stag_test.py
@@ -401,7 +401,7 @@ def main(argv=None):
     temp_file_db = tempfile.NamedTemporaryFile(delete=False, mode="w")

     t0 = time.time()
-    stag_command = "stag train -f -o "+trained_db+" -i "+seq_file+" -p "+protein_file+" -x "+tax_file+" -a "+hmm_file
+    stag_command = "stag train -f -o "+trained_db+" -i "+seq_file+" -p "+protein_file+" -x "+tax_file+" -a "+hmm_file + " -t 2"
     process = subprocess.run(stag_command.split())
     runtime = time.time() - t0
diff --git a/stag/taxonomy3.py b/stag/taxonomy3.py
new file mode 100644
index 0000000..1ccce0d
--- /dev/null
+++ b/stag/taxonomy3.py
@@ -0,0 +1,157 @@
+import csv
+import logging
+
+class Taxon:
+    def __init__(self, parent=None, label=None):
+        self.label = label if label else Taxonomy.TREE_ROOT
+        self.children = dict()
+        self.genes = set()
+        self.parent = parent
+    def add_child(self, child):
+        self.children.setdefault(child.label, child)
+    def add_gene(self, gene):
+        self.genes.add(gene)
+    def is_leaf(self):
+        return not self.children
+
+class Taxonomy(dict):
+    TREE_ROOT = "tree_root"
+    def __init__(self, fn=None):
+        self[self.TREE_ROOT] = Taxon()
+        self.n_taxlevels = 0
+        self.gene_lineages = dict()
+        self.fn = fn
+
+    def load_from_file(self):
+        self._read_taxonomy(self.fn)
+
+    def _check_lineage_depth(self, lineage, line_no):
+        lineage = lineage.replace("/", "-").split(";") # issue10
+        if len(lineage) < self.n_taxlevels:
+            raise ValueError(f"Line {line_no}: Taxonomy record does not have the expected number of taxonomic levels\n{lineage}")
+        self.n_taxlevels = len(lineage)
+        return lineage
+
+    def _read_taxonomy(self, fn):
+        for line_no, (gene, lineage) in enumerate(csv.reader(open(fn), delimiter="\t"), start=1):
+            parent = self[self.TREE_ROOT]
+            lineage = self._check_lineage_depth(lineage, line_no)
+            for i, taxon in enumerate(lineage):
+                if i > 0:
+                    parent = node
+                node = self.setdefault(taxon, Taxon(parent=parent, label=taxon))
+                parent.add_child(node)
+            node.add_gene(gene)
+            self.gene_lineages[gene] = lineage
+
+    def copy(self):
+        from copy import deepcopy
+        return deepcopy(self)
+    def extract_full_tax_from_gene(self, gene):
+        return self.gene_lineages.get(gene)
+    def get_n_levels(self):
+        return self.n_taxlevels
+    def get_root(self):
+        return self.TREE_ROOT
+    def find_children_node(self, node):
+        return list(self.get(node, Taxon()).children.keys())
+    def get_last_level_to_genes(self):
+        return {node: set(node.genes) for node in self.values()}
+    def is_last_node(self, node):
+        return self.get(node, Taxon()).is_leaf()
+    def find_gene_ids(self, node=None):
+        genes = set()
+        nodes = [self[node if node else self.TREE_ROOT]]
+        while nodes:
+            node = nodes.pop(0)
+            nodes.extend(node.children.values())
+            genes.update(node.genes)
+        return list(genes)
+    def remove_clades(self, nodes):
+        removed_genes = set()
+        for node in nodes:
+            stack = [node]
+            while stack:
+                node2 = self[stack.pop()]
+                removed_genes.update(node2.genes)
+                stack.extend(node2.children)
+                if node2.parent:
+                    node2.parent.children.pop(node2.label, None)
+                    self._clean_branch(node2.parent)
+                self.pop(node2.label, None)
+        for gene in removed_genes:
+            self.gene_lineages.pop(gene, None)
+        return list(removed_genes)
+    def _clean_branch(self, node):
+        while True:
+            if node.children or not node.parent:
+                break
+            try:
+                self.pop(node.label)
+                node.parent.children.pop(node.label)
+            except:
+                pass
+            node = node.parent
+    def remove_genes(self, genes):
+        empty_nodes = set()
+        for gene in genes:
+            node = self[self.gene_lineages[gene][-1]]
+            node.genes.discard(gene)
+            if not node.genes:
+                empty_nodes.add(node.label)
+        self.remove_clades(empty_nodes)
+    def find_node_level(self, tax_level):
+        nodes = dict()
+        queue = [(self[self.TREE_ROOT], 0)]
+        while queue:
+            node, level = queue.pop(0)
+            if level + 1 == tax_level:
+                for child in node.children.values():
+                    nodes[child.label] = set(child.children)
+            else:
+                queue.extend((child, level + 1) for child in node.children.values())
+        return nodes
+    def get_all_nodes(self, mode=None, get_root=False):
+        queue = [(self[self.TREE_ROOT], set())]
+        while queue:
+            node, siblings = queue.pop(0)
+            if node.label != self.get_root() or get_root:
+                yield node.label, siblings
+
+            children = set(node.children)
+            for child in children:
+                siblings = children.difference({child})
+                queue.append((self[child], siblings))
+    def ensure_geneset_consistency(self, genes):
+        genes_in_tree = set(self.find_gene_ids())
+        logging.info(f" CHECK: genes in geneset: {len(genes)}")
+        logging.info(f" CHECK: genes in taxonomy: {len(genes_in_tree)}")
+
+        # check that all genes in the geneset are in the taxonomy
+        missing_genes = set(genes).difference(genes_in_tree)
+        if missing_genes:
+            logging.info(" Error: some genes in the alignment have no taxonomy.")
+            for gene in missing_genes:
+                logging.info(f" {gene}")
+            raise ValueError("Some genes in the alignment have no taxonomy.\n"
+                             "Use the command 'check_input' to find more information.\n")
+        else:
+            logging.info(" CHECK: check all genes in the alignment have a taxonomy: correct")
+
+        # the taxonomy can have more genes than the geneset, but these need to be removed,
+        # since the selection of the genes for training and testing is done at taxonomy level
+        drop_genes = genes_in_tree.difference(genes)
+        if drop_genes:
+            n_drop_genes = len(drop_genes)
+            self.remove_genes(drop_genes)
+        else:
+            n_drop_genes = None
+        logging.info(f" CHECK: check genes that we need to remove from the taxonomy: {n_drop_genes}")
+
+        # verify number of genes is consistent between set and taxonomy tree
+        genes_in_tree = self.find_gene_ids()
+        if len(genes_in_tree) != len(genes):
+            msg = "Even after correction, the genes in the taxonomy and the alignment do not agree."
+            logging.info(f" Error: {msg.lower()}")
+            raise ValueError(msg)
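get_all_nodes is a breadth-first walk that yields every node together with its siblings, which is exactly the pairing the training code uses to pick negative examples. The same traversal over a plain-dict toy tree:

    from collections import deque

    tree = {"tree_root": ["Bacteria", "Archaea"],
            "Bacteria": ["Firmicutes", "Proteobacteria"],
            "Archaea": ["Euryarchaeota"],
            "Firmicutes": [], "Proteobacteria": [], "Euryarchaeota": []}

    queue = deque([("tree_root", set())])
    while queue:
        label, siblings = queue.popleft()
        if label != "tree_root":
            print(label, "vs siblings", sorted(siblings))
        children = set(tree[label])
        for child in children:
            queue.append((child, children - {child}))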
diff --git a/stag/train_genome.py b/stag/train_genome.py
index 45c318a..ef0e622 100644
--- a/stag/train_genome.py
+++ b/stag/train_genome.py
@@ -1,121 +1,72 @@
-"""
-Scripts that trains the database for the genome
-"""
-
-# Author: Alessio Milanese
-
-import sys
-import time
 import os
 import tempfile
 import shutil
-import subprocess
-import shlex
-import errno
-import h5py
-import re
 import tarfile

 from stag.helpers import check_file_exists
 from stag.classify import classify

+# cschu 2021-04-10: we need to change the alignment format!! -> this is too hacky.
 # find the length of the alignments --------------------------------------------
 def find_length_ali(gene_db, fasta_input, protein_fasta_input):
     return classify(gene_db, fasta_input=fasta_input, protein_fasta_input=protein_fasta_input, internal_call=True)[0]

-#===============================================================================
-# MAIN
-#===============================================================================
-def train_genome(output, list_genes, gene_thresholds, threads, verbose, concat_stag_db):
-    # temp file where to save the result ---------------------------------------
+def get_dummy_fastas():
+    fasta_files = list()
+    for seq in ("AAA", "A"):
+        with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmp_fasta:
+            os.chmod(tmp_fasta.name, 0o644)
+            print(">test", seq, sep="\n", file=tmp_fasta, flush=True)
+        fasta_files.append(tmp_fasta.name)
+    return fasta_files
+
+def get_alignment_lengths(list_genes):
+    fna, faa = get_dummy_fastas()
+    with tempfile.NamedTemporaryFile(delete=False, mode="w") as length_file:
+        os.chmod(length_file.name, 0o644)
+        for gene_db in list_genes:
+            print(os.path.basename(gene_db), find_length_ali(gene_db, fna, faa), sep="\t", flush=True, file=length_file)
+    [os.remove(f) for f in (fna, faa)]
+    return length_file.name
+
+
+def train_genome(output, list_genes, gene_threshold_file, threads, verbose, concat_stag_db):
+    check_file_exists(gene_threshold_file, isfasta=False)
+    with open(gene_threshold_file) as f:
+        gene_thresholds = set(line.strip().split("\t")[0] for line in f if line)
+
+    list_genes = list_genes.split(",")
+    missing_thresholds = set(os.path.basename(fn) for fn in list_genes).difference(gene_thresholds)
+    if missing_thresholds:
+        raise ValueError(f"[E::main] Error: gene {list(missing_thresholds)[0]} is missing from the threshold file (-T)")
+
     outfile = tempfile.NamedTemporaryFile(delete=False, mode="w")
     os.chmod(outfile.name, 0o644)
+    core_db_files = ("threshold_file.tsv", "hmm_lengths_file.tsv", "concatenated_genes_STAG_database.HDF5")
+    with tarfile.open(outfile.name, "w:gz", dereference=True) as genome_tar:
+        for fn in list_genes:
+            check_file_exists(fn)
+            base_fn = os.path.basename(fn)
+            if base_fn in core_db_files:
+                raise ValueError(f"[E::main] Error: gene databases cannot be named '{base_fn}'. Please choose another name.")
+            if "##" in base_fn:
+                raise ValueError(f"Error with: {base_fn}\n[E::main] Error: gene database file names cannot contain '##'. Please choose another name.")
+            try:
+                genome_tar.add(fn, base_fn)
+            except:
+                raise ValueError(f"[E::main] Error: when adding {fn} to the database")
+        for source, target in zip((gene_threshold_file, get_alignment_lengths(list_genes), concat_stag_db), core_db_files):
+            genome_tar.add(source, target)

-    # we need a file with the thresholds ---------------------------------------
-    check_file_exists(gene_thresholds,isfasta = False)
-    genes_threhold_file = list()
-    o = open(gene_thresholds)
-    for i in o:
-        vals = i.rstrip().split("\t")
-        genes_threhold_file.append(vals[0])
-    o.close()
-    for name in list_genes.split(","):
-        if not name.split("/")[-1] in genes_threhold_file:
-            sys.stderr.write("[E::main] Error: ")
-            sys.stderr.write("gene "+name.split("/")[-1]+" is missing from the threshold file (-T)\n")
-            sys.exit(1)
-
-
-    # we create a tar.gz with all the genes -------------------------------------
-    tar = tarfile.open(outfile.name, "w:gz")
-    for name in list_genes.split(","):
-        check_file_exists(name,isfasta = False)
-        try:
-            name_file = os.path.basename(name)
-            if name_file == "threshold_file.tsv":
-                sys.stderr.write("[E::main] Error: gene databases cannot have name 'threshold_file.tsv'. Please, choose another name.\n")
-                sys.exit(1)
-            if name_file == "hmm_lengths_file.tsv":
-                sys.stderr.write("[E::main] Error: gene databases cannot have name 'hmm_lengths_file.tsv'. Please, choose another name.\n")
-                sys.exit(1)
-            if name_file == "concatenated_genes_STAG_database.HDF5":
-                sys.stderr.write("[E::main] Error: gene databases cannot have name 'concatenated_genes_STAG_database.HDF5'. Please, choose another name.\n")
-                sys.exit(1)
-            if len(name_file.split("##")) > 1:
-                sys.stderr.write("Error with: "+name_file+"\n")
-                sys.stderr.write("[E::main] Error: gene databases cannot have in the name '##'. Please, choose another name.\n")
-                sys.exit(1)
-            tar.add(name, name_file)
-        except:
-            sys.stderr.write("[E::main] Error: when adding "+name+" to the database\n")
-            sys.exit(1)
-
-    # we add the file with the thresholds to the tar.gz
-    tar.add(gene_thresholds, "threshold_file.tsv")
-
-
-    # we need to find the length of the alignments ------------------------------
-    len_f = tempfile.NamedTemporaryFile(delete=False, mode="w")
-    os.chmod(len_f.name, 0o644)
-    # temp fasta file
-    temp_fasta = tempfile.NamedTemporaryFile(delete=False, mode="w")
-    os.chmod(temp_fasta.name, 0o644)
-    temp_fasta.write(">test\nAAA\n")
-    temp_fasta.flush()
-    # protein
-    temp_fasta2 = tempfile.NamedTemporaryFile(delete=False, mode="w")
-    os.chmod(temp_fasta2.name, 0o644)
-    temp_fasta2.write(">test\nA\n")
-    temp_fasta2.flush()
-    for gene_db in list_genes.split(","):
-        len_this = find_length_ali(gene_db,temp_fasta.name,temp_fasta2.name)
-        len_f.write(os.path.basename(gene_db) + "\t" + str(len_this) + "\n")
-        len_f.flush()
-
-    os.remove(temp_fasta.name)
-    # we add the file with the lengths to the tar.gz
-    tar.add(len_f.name, "hmm_lengths_file.tsv")
-
-
-    # add file with stag DB of the concatenated alis ----------------------------
-    tar.add(concat_stag_db, "concatenated_genes_STAG_database.HDF5")
-
-
-    # close tar file -------------------------------------------------------------
-    tar.close()
-    # close
     try:
         outfile.flush()
         os.fsync(outfile.fileno())
         outfile.close()
     except:
-        sys.stderr.write("[E::main] Error: failed to save the result\n")
-        sys.exit(1)
+        raise ValueError("[E::main] Error: failed to save the result.")

     try:
-        #os.rename(outfile.name,output) # atomic operation
-        shutil.move(outfile.name,output) #It is not atomic if the files are on different filsystems.
+        shutil.move(outfile.name, output)
     except:
-        sys.stderr.write("[E::main] Error: failed to save the resulting database\n")
-        sys.stderr.write("[E::main] you can find the file here:\n"+outfile.name+"\n")
-        sys.exit(1)
+        raise ValueError("[E::main] Error: failed to save the resulting database\n" + \
+                         f"[E::main] you can find the file here:\n{outfile.name}")
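The rewritten train_genome packs the per-gene stag databases plus three fixed bookkeeping files into a single gzipped tar. A sketch of how a consumer might inspect such a bundle (the archive path is hypothetical; member names are taken from core_db_files above):

    import tarfile

    core = {"threshold_file.tsv", "hmm_lengths_file.tsv",
            "concatenated_genes_STAG_database.HDF5"}

    with tarfile.open("stag_genome_db.tar.gz", "r:gz") as genome_tar:
        gene_dbs = sorted(set(genome_tar.getnames()) - core)
        print("gene databases:", gene_dbs)
        with genome_tar.extractfile("hmm_lengths_file.tsv") as fh:
            for line in fh:
                gene, length = line.decode().rstrip().split("\t")
                print(gene, "->", length)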