From 65248abc3b3a1342c2bde498b70b6cfaa8129648 Mon Sep 17 00:00:00 2001 From: Matt Rasmussen Date: Thu, 6 Mar 2014 09:54:04 -0800 Subject: [PATCH] update dependency python libs --- Makefile.dev | 12 +- argweaver/deps/compbio/arglib.py | 1121 +++++--------------------- argweaver/deps/compbio/coal.py | 258 +++--- argweaver/deps/compbio/phylo.py | 543 +++++++------ argweaver/deps/compbio/vis/argvis.py | 186 ++--- argweaver/deps/rasmus/stats.py | 609 ++++++-------- argweaver/deps/rasmus/tablelib.py | 1030 ++++++++--------------- argweaver/deps/rasmus/testing.py | 34 +- argweaver/deps/rasmus/treelib.py | 765 ++++++++---------- argweaver/deps/rasmus/util.py | 995 +++++++++++------------ 10 files changed, 2156 insertions(+), 3397 deletions(-) diff --git a/Makefile.dev b/Makefile.dev index f56cf324..19de955e 100644 --- a/Makefile.dev +++ b/Makefile.dev @@ -3,7 +3,7 @@ # -PY_SRC_PATH=../compbio/python +PY_SRC_PATH=../compbio RASMUS_SRC_FILES = \ __init__.py \ @@ -33,9 +33,9 @@ COMPBIO_SRC_FILES = \ vis/__init__.py \ vis/argvis.py -# copy subset of python modules for packaging with arghmm +# copy subset of python modules for packaging with argweaver includedep: - mkdir -p arghmm/deps - touch arghmm/deps/__init__.py - ./setup/cp-deps.py $(PY_SRC_PATH)/rasmus arghmm/deps/rasmus $(RASMUS_SRC_FILES) - ./setup/cp-deps.py $(PY_SRC_PATH)/compbio arghmm/deps/compbio $(COMPBIO_SRC_FILES) + mkdir -p argweaver/deps + touch argweaver/deps/__init__.py + ./setup/cp-deps.py $(PY_SRC_PATH)/rasmus argweaver/deps/rasmus $(RASMUS_SRC_FILES) + ./setup/cp-deps.py $(PY_SRC_PATH)/compbio argweaver/deps/compbio $(COMPBIO_SRC_FILES) diff --git a/argweaver/deps/compbio/arglib.py b/argweaver/deps/compbio/arglib.py index 127d6965..9b9baaa5 100644 --- a/argweaver/deps/compbio/arglib.py +++ b/argweaver/deps/compbio/arglib.py @@ -1,8 +1,7 @@ """ arglib.py - - Ancestral recombination graph (ARG) + Ancestral recombination graph (ARG) """ @@ -17,14 +16,13 @@ from itertools import izip, chain from collections import defaultdict import heapq -from math import * # compbio libs from . import fasta # rasmus libs -from rasmus import treelib, util, stats -from rasmus.intervals import iter_intersections +from rasmus import treelib +from rasmus import util #============================================================================= @@ -41,34 +39,46 @@ def __init__(self, name="n", age=0, event="gene", pos=0): self.children = [] self.event = event self.age = age - self.pos = pos # recomb position + self.pos = pos # recomb position self.data = {} def __repr__(self): return "" % self.name def get_dist(self, parent_index=0): - """Get branch length distance from node to parent_index'th parent""" + """Get branch length distance from node to parent_index'th parent.""" if len(self.parents) == 0: return 0.0 return self.parents[parent_index].age - self.age def get_dists(self): - """Get all branch length distances from node to parents""" + """Get all branch length distances from node to parents.""" return [p.age - self.age for p in self.parents] def copy(self): - """Returns a copy of this node""" + """Returns a copy of this node.""" node = ArgNode(self.name, age=self.age, event=self.event, - pos=self.pos) + pos=self.pos) node.data = dict(self.data) return node def is_leaf(self): - """Returns True if this node is a leaf""" + """Returns True if this node is a leaf.""" return len(self.children) == 0 - + def equal(self, other): + """ + Structural equality with another node. 
+ """ + return ( + self.name == other.name and + [parent.name for parent in self.parents] == + [parent.name for parent in other.parents] and + set(child.name for child in self.children) == + set(child.name for child in other.children) and + self.event == other.event and + self.age == other.age and + self.pos == other.pos) class ARG (object): @@ -83,51 +93,62 @@ def __init__(self, start=0.0, end=1.0): self.start = start self.end = end - def __iter__(self): - """Iterates over the nodes in the ARG""" + """Iterates over the nodes in the ARG.""" return self.nodes.itervalues() - def __len__(self): - """Returns number of nodes in the ARG""" + """Returns number of nodes in the ARG.""" return len(self.nodes) - def __getitem__(self, name): - """Returns node by name""" + """Returns node by name.""" return self.nodes[name] - def __setitem__(self, name, node): - """Adds a node to the ARG""" + """Adds a node to the ARG.""" node.name = name self.add(node) - def __contains__(self, name): """ - Returns True if node in ARG has name 'name' + Returns True if node in ARG has name 'name'. """ return name in self.nodes + def equal(self, other): + """ + Structural equality with another ARG. + """ + # Is the meta data equal? + if (self.start != other.start or + self.end != other.end): + return False + + # Is each node equal? + for node in self: + if node.name not in other: + return False + if not node.equal(other[node.name]): + return False + + return True #================================= # node manipulation methods def new_name(self): """ - Returns a new name for a node + Returns a new name for a node. """ name = self.nextname self.nextname += 1 return name - def new_node(self, name=None, parents=[], children=[], age=0, event="gene", pos=0): """ - Returns a new node + Returns a new node. """ if name is None: name = self.new_name() @@ -136,25 +157,23 @@ def new_node(self, name=None, parents=[], children=[], node.children = list(children) return node - def new_root(self, age=0, event="gene", pos=0): """ - Returns a new root + Returns a new root. """ self.root = self.new_node(age=age, event=event, pos=pos) return self.root - def add(self, node): """ - Adds a node to the ARG + Adds a node to the ARG. """ self.nodes[node.name] = node return node def remove(self, node): """ - Removes a node from the ARG + Removes a node from the ARG. """ for child in node.children: child.parents.remove(node) @@ -162,20 +181,18 @@ def remove(self, node): parent.children.remove(node) del self.nodes[node.name] - def rename(self, oldname, newname): """ - Renames a node in the ARG + Renames a node in the ARG. """ node = self.nodes[oldname] node.name = newname del self.nodes[oldname] self.nodes[newname] = node - def leaves(self, node=None): """ - Iterates over the leaves of the ARG + Iterates over the leaves of the ARG. """ if node is None: for node in self: @@ -186,10 +203,9 @@ def leaves(self, node=None): if len(node.children) == 0: yield node - def leaf_names(self, node=None): """ - Iterates over the leaf names of the ARG + Iterates over the leaf names of the ARG. """ if node is None: for node in self: @@ -200,12 +216,10 @@ def leaf_names(self, node=None): if len(node.children) == 0: yield node.name - def copy(self): """ - Returns a copy of this ARG + Returns a copy of this ARG. 
""" - arg = ARG(start=self.start, end=self.end) arg.nextname = self.nextname @@ -225,16 +239,14 @@ def copy(self): arg.root = arg[self.root.name] return arg - #================================ # iterator methods - + def postorder(self, node=None): """ - Iterates through nodes in postorder traversal + Iterates through nodes in postorder traversal. """ - visit = defaultdict(lambda: 0) queue = list(self.leaves(node)) @@ -247,12 +259,10 @@ def postorder(self, node=None): if visit[parent] == len(parent.children): queue.append(parent) - def preorder(self, node=None): """ - Iterates through nodes in preorder traversal + Iterates through nodes in preorder traversal. """ - visit = set() if node is None: node = self.root @@ -263,11 +273,10 @@ def preorder(self, node=None): continue yield node visit.add(node) - + for child in node.children: queue.append(child) - def postorder_marginal_tree(self, pos, nodes=None): """ Iterate postorder over the nodes in the marginal tree at position 'pos' @@ -276,7 +285,6 @@ def postorder_marginal_tree(self, pos, nodes=None): NOTE: nodes are iterated in order of age """ - # initialize heap heap = [(node.age, node) for node in self.leaves()] seen = set([None]) @@ -286,7 +294,7 @@ def postorder_marginal_tree(self, pos, nodes=None): def get_local_children(node, pos): return [child for child in self.get_local_children(node, pos) if child in nodes] - + def reachable(node): # returns True if node is unreachable from leaves if node in visited or node.is_leaf(): @@ -310,7 +318,6 @@ def ready(node): if child not in visited and reachable(child): return False return True - # add all ancestor of lineages unready = [] @@ -330,7 +337,7 @@ def ready(node): if len(heap) == 0: # MRCA reached return - + # find correct marginal parent # add parent to lineages if it has not been seen before parent = self.get_local_parent(node, pos) @@ -338,7 +345,6 @@ def ready(node): heapq.heappush(heap, (parent.age, parent)) seen.add(parent) - def preorder_marginal_tree(self, pos, node=None): """ Iterate preorder over the nodes in the marginal tree at position 'pos' @@ -347,12 +353,12 @@ def preorder_marginal_tree(self, pos, node=None): """ if node is None: - node = arg.root + node = self.root # initialize heap heap = [node] seen = set([node]) - + # add all ancestor of lineages while len(heap) > 0: node = heap.pop() @@ -366,9 +372,8 @@ def preorder_marginal_tree(self, pos, node=None): # NOTE: this prevents error when # children[0] == children[1] - def get_local_parent(self, node, pos): - """Returns the local parent of 'node' for position 'pos'""" + """Returns the local parent of 'node' for position 'pos'.""" if node.event == "gene" or node.event == "coal": if len(node.parents) > 0: return node.parents[0] @@ -384,19 +389,14 @@ def get_local_parent(self, node, pos): return None elif len(node.parents) > 1: return node.parents[0 if pos < node.pos else 1] - - ''' - if len(node.parents) > 0: - return node.parents[0 if pos < node.pos else 1] - else: - return None - ''' else: raise Exception("unknown event '%s'" % node.event) - def get_local_parents(self, node, start, end): - """Returns the parents of 'node' with ancestral sequence within (start, end)""" + """ + Return the parents of 'node' with ancestral sequence within + (start, end) + """ if node.event == "recomb": parents = [] if node.pos > start: @@ -407,18 +407,16 @@ def get_local_parents(self, node, start, end): parents = node.parents return parents - def get_local_children(self, node, pos): """ Returns the local children of 'node' for position 'pos' - 
+ NOTE: the local children are not necessarily in the local tree because the children may be unreachable from the leaves """ return [child for child in node.children if self.get_local_parent(child, pos) == node] - def get_local_dist(self, node, pos): """Returns the local parent of 'node' for position 'pos'""" @@ -428,9 +426,10 @@ def get_local_dist(self, node, pos): else: return 0.0 - def set_root(self): - + """ + Set the root node of the ARG. + """ for node in self: if not node.parents: self.root = node @@ -441,9 +440,8 @@ def set_root(self): def set_recomb_pos(self, start=None, end=None, descrete=False): """ - Randomly aample all recombination positions in the ARG + Randomly aample all recombination positions in the ARG. """ - if start is not None: self.start = start if end is not None: @@ -458,14 +456,12 @@ def set_recomb_pos(self, start=None, end=None, descrete=False): else: node.pos = random.random() * length + self.start - def set_ancestral(self): """ - Set all ancestral regions for the nodes of the ARG + Set all ancestral regions for the nodes of the ARG. NOTE: recombination positions must be set first (set_recomb_pos) """ - def root_path(ptr, pos): "walk up the root path from a node" while ptr.parents: @@ -475,9 +471,10 @@ def root_path(ptr, pos): for node in self: node.data["ancestral"] = [] - for block, tree in iter_tree_tracks(self): + for block, tree in iter_local_trees(self): pos = (block[0] + block[1]) / 2.0 - for node in chain(tree, root_path(self.nodes[tree.root.name], pos)): + for node in chain(tree, root_path( + self.nodes[tree.root.name], pos)): if node.name in self.nodes: ancestral = self[node.name].data["ancestral"] if len(ancestral) > 0 and ancestral[-1][1] == block[0]: @@ -489,27 +486,25 @@ def root_path(ptr, pos): # cap node? pass - def get_ancestral(self, node, side=None, parent=None): """ - Get the ancestral sequence from an edge above a node - + Get the ancestral sequence from an edge above a node. + node -- node to get ancestral sequence from side -- 0 for left parent edge, 1 for right parental edge parent -- if given, determine side from parent node """ - # set side from parent if parent: side = node.parents.index(parent) if node.event == "recomb": if (parent and len(node.parents) == 2 and - node.parents[0] == node.parents[1]): + node.parents[0] == node.parents[1]): # special case where both children of a coal node are the same # recomb node. return node.data["ancestral"] - + regions = [] for reg in node.data["ancestral"]: if side == 0: @@ -529,22 +524,20 @@ def get_ancestral(self, node, side=None, parent=None): else: raise Exception("side not specified") return regions - + elif node.event == "gene" or node.event == "coal": return node.data["ancestral"] else: - raise Exception("unknown event '%s'" % node.event) - + raise Exception("unknown event '%s'" % node.event) def prune(self, remove_single=True): """ - Prune ARG to only those nodes with ancestral sequence + Prune ARG to only those nodes with ancestral sequence. 
""" - # NOTE: be careful when removing nodes that you call get_ancestral # before changing parent/child orders - + # find pruned edges prune_edges = [] for node in list(self): @@ -556,42 +549,38 @@ def prune(self, remove_single=True): for node, parent in prune_edges: parent.children.remove(node) node.parents.remove(parent) - + # remove pruned nodes for node in list(self): if len(node.data["ancestral"]) == 0: self.remove(node) - for node in self: assert not node.is_leaf() or node.age == 0.0 # remove single children if remove_single: remove_single_lineages(self) - + # set root # TODO: may need to actually use self.roots for node in list(self): if len(node.parents) == 0: - dellist = [] while len(node.children) == 1: delnode = node node = node.children[0] self.remove(delnode) self.root = node - #=========================== # marginal tree methods - + def get_marginal_tree(self, pos, nodes=None): """ - Returns the marginal tree of the ARG containing position 'pos' + Returns the marginal tree of the ARG containing position 'pos'. if nodes is given, marginal tree can be determined quicker """ - # make new ARG to contain marginal tree tree = ARG(self.start, self.end) tree.nextname = self.nextname @@ -599,7 +588,7 @@ def get_marginal_tree(self, pos, nodes=None): # populate tree with marginal nodes for node in self.postorder_marginal_tree(pos, nodes=nodes): tree.add(node.copy()) - + # set parent and children roots = [] for node2 in tree: @@ -617,7 +606,7 @@ def get_marginal_tree(self, pos, nodes=None): tree.root = roots[0] elif len(roots) > 1: # make cap node since marginal tree does not fully coallesce - tree.root = tree.new_node(event="coal", + tree.root = tree.new_node(event="coal", name=self.new_name(), age=max(x.age for x in roots)+1) tree.nextname = self.nextname @@ -626,13 +615,12 @@ def get_marginal_tree(self, pos, nodes=None): node.parents.append(tree.root) assert tree.root is not None, (tree.nodes, pos) - + return tree - - + def get_tree(self, pos=None): """ - Returns a treelib.Tree() object representing the ARG if it is a tree + Return a treelib.Tree() object representing the ARG if it is a tree. if 'pos' is given, return a treelib.Tree() for the marginal tree at position 'pos'. @@ -664,20 +652,24 @@ def get_tree(self, pos=None): tree.root = tree[self.root.name] return tree - #======================= # input/output def read(self, filename=sys.stdin): + """ + Read ARG from filename or stream. + """ read_arg(filename, arg=self) - def write(self, filename=sys.stdout): + """ + Write ARG to filename or stream. + """ write_arg(filename, self) - #============================================================================= +# Asserts def assert_arg(arg): """Asserts that the arg data structure is consistent""" @@ -700,11 +692,12 @@ def assert_arg(arg): #============================================================================= -# coalescence with recombination +# Coalescence with recombination + def sample_coal_recomb(k, n, r): """ - Returns a sample time for either coal or recombination + Returns a sample time for either coal or recombination. k -- chromosomes n -- effective population size (haploid) @@ -722,13 +715,13 @@ def sample_coal_recomb(k, n, r): rate = coal_rate + recomb_rate event = ("coal", "recomb")[int(random.random() < (recomb_rate / rate))] - + return event, random.expovariate(rate) def sample_coal_recomb_times(k, n, r, t=0): """ - Returns a sample time for either coal or recombination + Returns a sample time for either coal or recombination. 
k -- chromosomes n -- effective population size (haploid) @@ -758,10 +751,10 @@ def sample_coal_recomb_times(k, n, r, t=0): return times, events -def sample_arg(k, n, rho, start=0.0, end=1.0, t=0, names=None, +def sample_arg(k, n, rho, start=0.0, end=1.0, t=0, names=None, make_names=True): """ - Returns an ARG sampled from the coalescent with recombination (pruned) + Returns an ARG sampled from the coalescent with recombination (pruned). k -- chromosomes n -- effective population size (haploid) @@ -817,7 +810,6 @@ def __init__(self, node, regions, seqlen): event = ("coal", "recomb")[int(random.random() < (recomb_rate / rate))] t += t2 - # process event if event == "coal": node = arg.new_node(age=t, event=event) @@ -840,9 +832,9 @@ def __init__(self, node, regions, seqlen): lineage_regions = [] nblocks = len(block_starts) i = 0 - + for start, end, count in count_region_overlaps( - a.regions, b.regions): + a.regions, b.regions): assert start != end, count in (0, 1, 2) #assert end == arg.end or end in block_starts i = block_starts.index(start, i) @@ -854,7 +846,7 @@ def __init__(self, node, regions, seqlen): if count == 2: block_counts[start2] -= 1 if count >= 1: - regions.append((start2, end2)) # ancestral seq + regions.append((start2, end2)) # ancestral seq if block_counts[start2] > 1: # regions moves on, since not MRCA lineage_regions.append((start2, end2)) @@ -872,7 +864,6 @@ def __init__(self, node, regions, seqlen): lineages.add(Lineage(node, lineage_regions, seqlen)) total_seqlen += seqlen - elif event == "recomb": node = arg.new_node(age=t, event=event) @@ -907,7 +898,7 @@ def __init__(self, node, regions, seqlen): # create 2 new lineages regions1 = list(split_regions(node.pos, 0, lineage.regions)) regions2 = list(split_regions(node.pos, 1, lineage.regions)) - + regions1_len = regions1[-1][1] - regions1[0][0] regions2_len = regions2[-1][1] - regions2[0][0] total_seqlen += regions1_len + regions2_len - lineage.seqlen @@ -925,8 +916,10 @@ def __init__(self, node, regions, seqlen): for node, (a, b) in recomb_parent_lineages.iteritems(): an = lineage_parents[a] bn = lineage_parents[b] - for reg in a.regions: assert reg[1] <= node.pos - for reg in b.regions: assert reg[0] >= node.pos + for reg in a.regions: + assert reg[1] <= node.pos + for reg in b.regions: + assert reg[0] >= node.pos node.parents = [an, bn] # set root @@ -935,11 +928,10 @@ def __init__(self, node, regions, seqlen): return arg - def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, names=None, make_names=True): """ - Sample ARG using Sequentially Markovian Coalescent (SMC) + Sample ARG using Sequentially Markov Coalescent (SMC) k -- chromosomes n -- effective population size (haploid) @@ -950,7 +942,6 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, names -- names to use for leaves (default: None) make_names -- make names using strings (default: True) """ - # yield initial tree first if init_tree is None: init_tree = sample_arg(k, n, rho=0.0, start=start, end=end, @@ -959,7 +950,7 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, else: init_tree.end = end tree = init_tree.get_marginal_tree(start) - remove_single_lineages(tree) + remove_single_lineages(tree) yield init_tree # sample SPRs @@ -974,7 +965,7 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, # choose branch for recombination p = random.uniform(0.0, treelen) total = 0.0 - nodes = (x for x in tree if x.parents) # root can't have a recomb + nodes = (x for x in tree if x.parents) # root can't 
have a recomb for node in nodes: total += node.get_dist() if total > p: @@ -996,19 +987,17 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, i = 0 #print while i < len(all_nodes): - #print coal_time, recomb_node, lineages - #treelib.draw_tree_names(tree.get_tree(), scale=1e-3, minlen=5) next_node = all_nodes[i] - + if next_node.age > recomb_time: if coal_time < recomb_time: coal_time = recomb_time next_time = coal_time + random.expovariate( len(lineages) / float(n)) - + if next_time < next_node.age: coal_time = next_time - + # choose coal branch coal_node = random.sample(lineages, 1)[0] assert coal_node.age < coal_time < coal_node.parents[0].age @@ -1023,20 +1012,20 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, if child in lineages: lineages.remove(child) else: - assert child == recomb_node, (next_node, child, recomb_node) + assert child == recomb_node, ( + next_node, child, recomb_node) if next_node != recomb_node: lineages.add(next_node) else: # coal above tree coal_node = all_nodes[-1] coal_time = coal_node.age + random.expovariate(1.0 / float(n)) - + # yield SPR rleaves = list(tree.leaf_names(recomb_node)) cleaves = list(tree.leaf_names(coal_node)) yield pos, (rleaves, recomb_time), (cleaves, coal_time) - # apply SPR to local tree broken = recomb_node.parents[0] recoal = tree.new_node(age=coal_time, @@ -1051,7 +1040,6 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, coal_node.parents[0] = recoal else: coal_node.parents.append(recoal) - # remove broken node broken_child = broken.children[0] @@ -1063,29 +1051,27 @@ def sample_smc_sprs(k, n, rho, start=0.0, end=0.0, init_tree=None, del tree.nodes[broken.name] tree.set_root() - def sample_arg_smc(k, n, rho, start=0.0, end=0.0, init_tree=None, names=None, make_names=True): """ Returns an ARG sampled from the Sequentially Markovian Coalescent (SMC) - + k -- chromosomes n -- effective population size (haploid) rho -- recombination rate (recombinations / site / generation) start -- staring chromosome coordinate end -- ending chromsome coordinate - + names -- names to use for leaves (default: None) make_names -- make names using strings (default: True) """ - it = sample_smc_sprs(k, n, rho, start=start, end=end, init_tree=init_tree, names=names, make_names=make_names) tree = it.next() arg = make_arg_from_sprs(tree, it) - + return arg @@ -1095,28 +1081,26 @@ def sample_arg_smc(k, n, rho, start=0.0, end=0.0, init_tree=None, def lineages_over_time(k, events): """ - Computes number of lineage though time using coal/recomb events + Computes number of lineage though time using coal/recomb events. """ - for event in events: if event == "coal": k -= 1 elif event == "recomb": k += 1 else: - raise Exception("unknown event '%s'" % event) + raise Exception("unknown event '%s'" % event) yield k - + def make_arg_from_times(k, times, events, start=0, end=1, names=None, make_names=True): """ - Returns an ARG given 'k' samples and a list of 'times' and 'events' + Returns an ARG given 'k' samples and a list of 'times' and 'events'. 
times -- ordered times of coalescence or recombination events -- list of event types (either 'coal' or 'recomb') """ - arg = ARG(start, end) # make leaves @@ -1126,7 +1110,7 @@ def make_arg_from_times(k, times, events, start=0, end=1, lineages = set((arg.new_node(), 1) for i in xrange(k)) else: lineages = set((arg.new_node(name=names[i]), 1) for i in xrange(k)) - + # process events for t, event in izip(times, events): if event == "coal": @@ -1138,7 +1122,7 @@ def make_arg_from_times(k, times, events, start=0, end=1, a[0].parents.append(node) b[0].parents.append(node) lineages.add((node, 1)) - + elif event == "recomb": node = arg.add(ArgNode(arg.new_name(), age=t, event=event)) a = random.sample(lineages, 1)[0] @@ -1151,16 +1135,15 @@ def make_arg_from_times(k, times, events, start=0, end=1, else: raise Exception("unknown event '%s'" % event) - if len(lineages) == 1: - arg.root = lineages.pop()[0] + arg.root = lineages.pop()[0] return arg def make_arg_from_tree(tree, times=None): """ - Creates an ARG from a treelib.Tree 'tree' + Creates an ARG from a treelib.Tree 'tree'. """ arg = ARG() if times is None: @@ -1182,18 +1165,18 @@ def make_arg_from_tree(tree, times=None): arg.nextname = max(node.name for node in arg if isinstance(node.name, int)) + 1 - + return arg def get_recombs(arg, start=None, end=None, visible=False): """ - Returns a sorted list of an ARG's recombination positions + Returns a sorted list of an ARG's recombination positions. - visible -- if True only iterate recombination break points that are + visible -- if True only iterate recombination break points that are visible to extant sequences """ - + if visible: return list(iter_visible_recombs(arg, start, end)) else: @@ -1201,13 +1184,13 @@ def get_recombs(arg, start=None, end=None, visible=False): arg if node.event == "recomb"] rpos.sort() return rpos -get_recomb_pos = get_recombs + def iter_recombs(arg, start=None, end=None, visible=False): """ Iterates through an ARG's recombination positions - visible -- if True only iterate recombination break points that are + visible -- if True only iterate recombination break points that are visible to extant sequences """ @@ -1222,7 +1205,6 @@ def iter_recombs(arg, start=None, end=None, visible=False): def iter_visible_recombs(arg, start=None, end=None): """Iterates through visible recombinations in an ARG""" - pos = start if start is not None else 0 while True: recomb = find_next_recomb(arg, pos) @@ -1235,7 +1217,6 @@ def iter_visible_recombs(arg, start=None, end=None): def find_next_recomb(arg, pos, tree=False): """Returns the next recombination node in a local tree""" - recomb = None nextpos = util.INF @@ -1254,15 +1235,14 @@ def find_next_recomb(arg, pos, tree=False): def iter_recomb_blocks(arg, start=None, end=None, visible=False): """ - Iterates over the recombination blocks of an ARG + Iterates over the recombination blocks of an ARG. arg -- ARG to iterate over start -- starting position in chromosome to iterate over end -- ending position in chromosome to iterate over - visible -- if True only iterate recombination break points that are + visible -- if True only iterate recombination break points that are visible to extant sequences """ - # determine region to iterate over if start is None: start = arg.start @@ -1286,7 +1266,7 @@ def iter_recomb_blocks(arg, start=None, end=None, visible=False): def iter_marginal_trees(arg, start=None, end=None): """ - Iterate over the marginal trees of an ARG + Iterate over the marginal trees of an ARG. 
""" for block, tree in iter_local_trees(arg, start, end): yield tree @@ -1294,7 +1274,7 @@ def iter_marginal_trees(arg, start=None, end=None): def iter_local_trees(arg, start=None, end=None, convert=False): """ - Iterate over the local trees of an ARG + Iterate over the local trees of an ARG. Yeilds ((start, end), tree) for each marginal tree where (start, end) defines the block of the marginal tree @@ -1315,7 +1295,7 @@ def iter_local_trees(arg, start=None, end=None, convert=False): i += 1 tree = arg.get_marginal_tree((start+rpos[i]) / 2.0) - + # find block end end2 = arg.end for node in tree: @@ -1327,12 +1307,10 @@ def iter_local_trees(arg, start=None, end=None, convert=False): yield (start, min(end2, end)), tree start = end2 -iter_tree_tracks = iter_local_trees - def descendants(node, nodes=None): """ - Return all descendants of a node in an ARG + Return all descendants of a node in an ARG. """ if nodes is None: nodes = set() @@ -1345,14 +1323,14 @@ def descendants(node, nodes=None): def remove_single_lineages(arg): """ - Remove unnecessary nodes with single parent and single child + Remove unnecessary nodes with single parent and single child. """ queue = list(arg) for node in queue: if node.name not in arg: continue - + if len(node.children) == 1: if len(node.parents) == 1: child = node.children[0] @@ -1379,9 +1357,11 @@ def remove_single_lineages(arg): return arg - def postorder_subarg(arg, start, end): - """Iterates postorder over the nodes of the 'arg' that are ancestral to (start,end)""" + """ + Iterate postorder over the nodes of the 'arg' that are ancestral to + (start,end) + """ # initialize heap heap = [(node.age, node) for node in arg.leaves()] @@ -1404,7 +1384,9 @@ def postorder_subarg(arg, start, end): def subarg(arg, start, end): - """Returns a new ARG that only contains recombination within (start, end)""" + """ + Returns a new ARG that only contains recombination within (start, end). + """ arg2 = ARG(start, end) @@ -1428,9 +1410,8 @@ def subarg(arg, start, end): def subarg_by_leaves(arg, leaves, keep_single=False): """ - Removes any leaf from the arg that is not in leaves set + Removes any leaf from the arg that is not in leaves set. """ - stay = set(leaves) remove = [] @@ -1445,16 +1426,16 @@ def subarg_by_leaves(arg, leaves, keep_single=False): # remove nodes for node in remove: arg.remove(node) - + if not keep_single: remove_single_lineages(arg) - + return arg def apply_spr(tree, rnode, rtime, cnode, ctime, rpos): """ - Apply an Subtree Pruning Regrafting (SPR) operation on a tree + Apply an Subtree Pruning Regrafting (SPR) operation on a tree. """ if rnode == cnode: return None, None @@ -1485,24 +1466,23 @@ def remove_node(arg, node): arg.root = child del arg.nodes[node.name] - + coal = add_node(tree, cnode, ctime, rpos, "coal") - + broken_node = rnode.parents[0] - broken_node.children.remove(rnode) + broken_node.children.remove(rnode) remove_node(tree, broken_node) - + rnode.parents[0] = coal coal.children.append(rnode) return coal, broken_node - - -def iter_arg_sprs(arg, start=None, end=None, use_leaves=False, use_local=False): +def iter_arg_sprs(arg, start=None, end=None, + use_leaves=False, use_local=False): """ - Iterate through the SPR moves of an ARG + Iterate through the SPR moves of an ARG. 
Yields (recomb_pos, (rnode, rtime), (cnode, ctime)) @@ -1553,7 +1533,7 @@ def walk_down(node, local, pos): local = set(nodes) local_root = nodes[-1] pos = start - + while pos < end: # find next recombination node after 'pos' recomb_pos = end @@ -1573,7 +1553,6 @@ def walk_down(node, local, pos): ptr = recomb rnode = walk_down(recomb, local, mid).name - # find recoal node ptr = recomb local_root_path = [] @@ -1603,16 +1582,15 @@ def walk_down(node, local, pos): local_root_path = local_root_path[:i+1] break ctime = ptr.age - recoal = ptr + # NOTE: recoal = ptr # find recoal baring branch in local tree # walk down until next coalescent node in local tree if ptr in local: cnode = walk_down(ptr, local, mid).name else: - cnode = local_root.name + cnode = local_root.name - # find broken nodes # walk up left parent of recomb until coalescent node or coal path ptr = recomb @@ -1644,7 +1622,7 @@ def walk_down(node, local, pos): broken_path.append(ptr) ptr = children[0] local_root = ptr - + # yield SPR if use_leaves: rleaves = list(x.name for x in @@ -1658,13 +1636,12 @@ def walk_down(node, local, pos): else: recomb_point = (rnode, rtime) coal_point = (cnode, ctime) - + if use_local: yield (recomb_pos, recomb_point, coal_point, local) else: yield (recomb_pos, recomb_point, coal_point) - # update local nodes if cnode == local_root.name: # add root path @@ -1676,22 +1653,20 @@ def walk_down(node, local, pos): local.remove(node) for node in coal_path: local.add(node) - + # advance the current position pos = recomb_pos - def iter_arg_sprs_simple(arg, start=None, end=None, use_leaves=False): """ - Iterate through the SPR moves of an ARG + Iterate through the SPR moves of an ARG. Yields (recomb_pos, (rnode, rtime), (cnode, ctime)) """ - - trees = iter_tree_tracks(arg, start, end) + trees = iter_local_trees(arg, start, end) block, last_tree = trees.next() - + for block, tree in trees: # find recombination node @@ -1739,7 +1714,7 @@ def make_arg_from_sprs(init_tree, sprs, ignore_self=False, NOTE: sprs should indicate branches by their leaf set (use_leaves=True) """ - + def add_node(arg, node, time, pos, event): node2 = arg.new_node(event=event, age=time, children=[node], pos=pos) if event == "coal": @@ -1747,7 +1722,7 @@ def add_node(arg, node, time, pos, event): parent = arg.get_local_parent(node, pos) #if parent is None and node.event == "recomb": # parent = node.parents[0] - + if parent: node.parents[node.parents.index(parent)] = node2 parent.children[parent.children.index(node)] = node2 @@ -1755,16 +1730,15 @@ def add_node(arg, node, time, pos, event): else: node.parents.append(node2) arg.root = node2 - + return node2 - - + def walk_up(arg, node, time, pos, local): parent = arg.get_local_parent(node, pos) - + while parent and parent.age <= time: if parent in local: - break + break node = parent parent = arg.get_local_parent(node, pos) @@ -1777,7 +1751,6 @@ def walk_up(arg, node, time, pos, local): assert False return node - arg = init_tree tree = None @@ -1814,7 +1787,7 @@ def walk_up(arg, node, time, pos, local): #rnode2 = arg_lca(arg, rleaves, rpos, time=rtime) #cnode2 = arg_lca(arg, cleaves, rpos, time=ctime) #assert (rnode == rnode2) and (cnode == cnode2) - + # add edge to ARG recomb = add_node(arg, rnode, rtime, rpos, "recomb") if rnode == cnode: @@ -1834,14 +1807,13 @@ def walk_up(arg, node, time, pos, local): del mapping[broken_node.name] mapping[coal2.name] = coal local.add(coal) - return arg def make_arg_from_sprs_simple(init_tree, sprs, ignore_self=False): """ - Make an ARG from an initial tree 
'init_tree' and a list of SPRs 'sprs' + Make an ARG from an initial tree 'init_tree' and a list of SPRs 'sprs'. NOTE: sprs should indicate branches by their leaf set (use_leaves=True) """ @@ -1868,7 +1840,7 @@ def add_node(arg, node, time, pos, event): # check whether self cycles are wanted if ignore_self and node1 == node2: continue - + recomb = add_node(arg, node1, rtime, rpos, "recomb") if node1 == node2: node2 = recomb @@ -1882,9 +1854,8 @@ def add_node(arg, node, time, pos, event): def smcify_arg(arg, start=None, end=None, ignore_self=True): """ - Rebuild an ARG so that is follows the SMC assumptions + Rebuild an ARG so that is follows the SMC assumptions. """ - if start is None: start = arg.start @@ -1898,23 +1869,21 @@ def smcify_arg(arg, start=None, end=None, ignore_self=True): arg2.start = start if end is not None: arg2.end = end - + return arg2 - def subarg_by_leaf_names(arg, leaf_names, keep_single=False): """ - Removes any leaf from the arg that is not in leaf name set + Removes any leaf from the arg that is not in leaf name set. """ - return subarg_by_leaves(arg, [arg[x] for x in leaf_names], keep_single=keep_single) def arg_lca(arg, leaves, pos, time=None, local=None): """ - Find the Least Common Ancestor (LCA) of a set of leaves in the ARG + Find the Least Common Ancestor (LCA) of a set of leaves in the ARG. arg -- an ARG leaves -- a list of nodes in arg @@ -1931,12 +1900,11 @@ def is_local_coal(arg, node, pos, local): arg.get_local_parent(node.children[1], pos) == node and node.children[0] != node.children[1]) - order = dict((node, i) for i, node in enumerate( arg.postorder_marginal_tree(pos))) if local is None: local = order - + queue = [(order[arg[x]], arg[x]) for x in leaves] seen = set(x[1] for x in queue) heapq.heapify(queue) @@ -1950,11 +1918,10 @@ def is_local_coal(arg, node, pos, local): node = queue[0][1] parent = arg.get_local_parent(node, pos) - if time is not None: while parent and parent.age <= time: if is_local_coal(arg, parent, pos, local): - break + break node = parent parent = arg.get_local_parent(node, pos) @@ -1969,12 +1936,10 @@ def is_local_coal(arg, node, pos, local): return node - def arglen(arg, start=None, end=None): """Calculate the total branch length of an ARG""" - treelen = 0.0 - for (start, end), tree in iter_tree_tracks(arg, start=start, end=end): + for (start, end), tree in iter_local_trees(arg, start=start, end=end): treelen += sum(x.get_dist() for x in tree) * (end - start) return treelen @@ -1986,9 +1951,9 @@ def arglen(arg, start=None, end=None): def split_regions(pos, side, regions): """ - Iterates through the regions on the left (side=0) or right (side=1) of 'pos' + Iterate through the regions on the left (side=0) or right (side=1) of + position 'pos'. 
""" - for reg in regions: if side == 0: if reg[1] <= pos: @@ -2011,10 +1976,9 @@ def split_regions(pos, side, regions): def count_region_overlaps(*region_sets): """ Count how many regions overlap each interval (start, end) - + Iterates through (start, end, count) sorted """ - # build endpoints list end_points = [] for regions in region_sets: @@ -2024,8 +1988,6 @@ def count_region_overlaps(*region_sets): end_points.sort() count = 0 - start = None - end = None last = None for pos, kind in end_points: if last is not None and pos != last: @@ -2038,20 +2000,18 @@ def count_region_overlaps(*region_sets): if last is not None and pos != last: yield last, pos, count - - + def groupby_overlaps(regions, bygroup=True): """ Group ranges into overlapping groups Ranges must be sorted by start positions """ - start = -util.INF end = -util.INF group = None groupnum = -1 - for reg in regions: + for reg in regions: if reg[0] > end: # start new group start, end = reg @@ -2081,6 +2041,7 @@ def groupby_overlaps(regions, bygroup=True): #============================================================================= # mutations and splits + def sample_arg_mutations(arg, mu, minlen=0): """ mu -- mutation rate (mutations/site/gen) @@ -2088,7 +2049,7 @@ def sample_arg_mutations(arg, mu, minlen=0): mutations = [] - for (start, end), tree in iter_tree_tracks(arg): + for (start, end), tree in iter_local_trees(arg): remove_single_lineages(tree) for node in tree: if not node.parents: @@ -2116,7 +2077,7 @@ def get_mutation_split(arg, mutation): def split_to_tree_branch(tree, split): """Place a split on a tree branch""" - + node = treelib.lca([tree[name] for name in split]) if sorted(split) != sorted(node.leaf_names()): @@ -2147,7 +2108,7 @@ def iter_tree_splits(tree): split = tuple(sorted(tree.leaf_names(node))) if len(split) > 1: yield split - + def is_split_compatible(split1, split2): """Returns True if two splits are compatible""" @@ -2188,6 +2149,7 @@ def is_split_compatible_unpolar2(split1, split2, leaves): return not (x00 and x01 and x10 and x11) + def is_split_compatible_unpolar(split1, split2, leaves): if is_split_compatible(split1, split2): return True @@ -2225,7 +2187,7 @@ def split_relation(split1, split2): else: return "conflict" - + return intersect == 0 or intersect == min(len(split1), len(split2)) @@ -2240,8 +2202,6 @@ def iter_mutation_splits(arg, mutations): yield pos, split - - #============================================================================= # alignments @@ -2257,8 +2217,6 @@ def make_alignment(arg, mutations, infinite_sites=True, # make align matrix mat = [] - - pos = arg.start muti = 0 for i in xrange(alnlen): if muti >= len(mutations) or i < int(mutations[muti][2]): @@ -2270,14 +2228,13 @@ def make_alignment(arg, mutations, infinite_sites=True, #while muti < len(mutations) and i == int(mutations[muti][2]): # mut_group.append(mutations[muti]) # muti += 1 - + node, parent, mpos, t = mutations[muti] - row = [] split = set(x.name for x in get_marginal_leaves(arg, node, mpos)) mat.append("".join((derived if leaf in split else ancestral) for leaf in leaves)) muti += 1 - + # make fasta for i, leaf in enumerate(leaves): aln[leaf] = "".join(x[i] for x in mat) @@ -2298,9 +2255,9 @@ def iter_align_splits(aln, warn=False): if warn and len(chars) != 2: print >>sys.stderr, "warning: not bi-allelic (site=%d)" % j - part1 = tuple(sorted(names[i] for i, c in enumerate(col) + part1 = tuple(sorted(names[i] for i, c in enumerate(col) if c == chars[0])) - part2 = tuple(sorted(names[i] for i, c in 
enumerate(col) + part2 = tuple(sorted(names[i] for i, c in enumerate(col) if c != chars[0])) if len(part1) > len(part2): part1, part2 = part2, part1 @@ -2309,7 +2266,6 @@ def iter_align_splits(aln, warn=False): yield j, split - #============================================================================= # input/output @@ -2318,7 +2274,7 @@ def write_arg(filename, arg): """ Write ARG to file """ - + out = util.open_stream(filename, "w") # write ARG key values @@ -2335,7 +2291,7 @@ def write_arg(filename, arg): ",".join(str(x.name) for x in node.parents), ",".join(str(x.name) for x in node.children), out=out) - + if isinstance(filename, basestring): out.close() @@ -2353,13 +2309,13 @@ def parse_node_name(text): else: return text + def parse_key_value(field): try: i = field.index("=") return field[:i], field[i+1:] except: - raise Exception("improper key-value field '%s'" % text) - + raise Exception("improper key-value field '%s'" % field) def read_arg(filename, arg=None): @@ -2379,7 +2335,7 @@ def read_arg(filename, arg=None): arg.start = int(val) elif key == "end": arg.end = int(val) - + # read header row = infile.next() assert row == ["name", "event", "age", "pos", "parents", "children"] @@ -2388,8 +2344,8 @@ def read_arg(filename, arg=None): clinks = {} plinks = {} for row in infile: - node = arg.new_node(name=parse_node_name(row[0]), event=row[1], - age=float(row[2]), + node = arg.new_node(name=parse_node_name(row[0]), event=row[1], + age=float(row[2]), pos=parse_number(row[3])) if len(row) > 4 and len(row[4]) > 0: plinks[node.name] = map(parse_node_name, row[4].split(",")) @@ -2403,7 +2359,7 @@ def read_arg(filename, arg=None): if parent: node.parents.append(parent) else: - raise Exception("node '%s' has unknown parent '%s'" % + raise Exception("node '%s' has unknown parent '%s'" % (node.name, parent_name)) # detect root @@ -2418,12 +2374,11 @@ def read_arg(filename, arg=None): node.children.append(child) assert node in child.parents, \ "node '%s' doesn't have parent '%s' (%s)" % ( - child.name, node.name, str(child.parents)) + child.name, node.name, str(child.parents)) else: - raise Exception("node '%s' has unknown child '%s'" % + raise Exception("node '%s' has unknown child '%s'" % (node.name, child_name)) - # set nextname for name in arg.nodes: if isinstance(name, int): @@ -2434,7 +2389,7 @@ def read_arg(filename, arg=None): def write_tree_tracks(filename, arg, start=None, end=None, verbose=False): out = util.open_stream(filename, "w") - for block, tree in iter_tree_tracks(arg, start, end): + for block, tree in iter_local_trees(arg, start, end): if verbose: print >>sys.stderr, "writing block", block remove_single_lineages(tree) @@ -2451,7 +2406,6 @@ def read_tree_tracks(filename): yield (int(row[0]), int(row[1])), treelib.parse_newick(row[2]) - def write_mutations(filename, arg, mutations): out = util.open_stream(filename, "w") @@ -2491,7 +2445,6 @@ def read_ancestral(filename, arg): # OLD CODE - def sample_mutations(arg, u): """ u -- mutation rate (mutations/locus/gen) @@ -2512,9 +2465,8 @@ def sample_mutations(arg, u): break else: continue - + frac = (region[1] - region[0]) / locsize - dist = parent.age - node.age t = parent.age while True: t -= random.expovariate(u * frac) @@ -2524,640 +2476,3 @@ def sample_mutations(arg, u): mutations.append((node, parent, pos, t)) return mutations - - -''' - -def has_self_cycles(arg): - """ - Return True if there are lineages that coalesce with themselves - - Requires ancestral sequences set. 
- """ - - # Such a cycle does not contain 'local coalescent nodes' on the sides - # but it might have non-local coalescent and recombination nodes. - # The relative order of theses nodes from the left and right side - # does not matter - # - # | - # coal - # / \ - # | | - # \ / - # recomb - # | - - # get overall postorder - # assumes stable sort - nodes = list(arg.postorder()) - nodes.sort(key=lambda x: x.age) - order = dict((x, i) for i, x in enumerate(nodes)) - - # find cycles by their recombination nodes - recombs = [x.name for x in arg if x.event == "recomb"] - - # find smallest separation - recomb_pos = [arg[x].pos for x in recombs] - recomb_pos.sort() - eps = .5 - for i in xrange(1, len(recomb_pos)): - sep = recomb_pos[i] - recomb_pos[i-1] - if sep > 0 and sep/2.0 < eps: - eps = sep / 2.0 - - for recomb_name in recombs: - if recomb_name not in arg: - continue - if is_self_cycle(arg, arg[recomb_name], order=order, eps=eps): - print recomb_name, arg[recomb_name].pos - return True - - return False - - -def iter_self_cycles(arg): - """ - Return True if there are lineages that coalesce with themselves - - Requires ancestral sequences set. - """ - - # Such a cycle does not contain 'local coalescent nodes' on the sides - # but it might have non-local coalescent and recombination nodes. - # The relative order of theses nodes from the left and right side - # does not matter - # - # | - # coal - # / \ - # | | - # \ / - # recomb - # | - - # get overall postorder - # assumes stable sort - nodes = list(arg.postorder()) - nodes.sort(key=lambda x: x.age) - order = dict((x, i) for i, x in enumerate(nodes)) - - # find cycles by their recombination nodes - recombs = [x.name for x in arg if x.event == "recomb"] - - # find smallest separation - recomb_pos = [arg[x].pos for x in recombs] - recomb_pos.sort() - eps = .5 - for i in xrange(1, len(recomb_pos)): - sep = recomb_pos[i] - recomb_pos[i-1] - if sep > 0 and sep/2.0 < eps: - eps = sep / 2.0 - - for recomb_name in recombs: - if recomb_name not in arg: - continue - if is_self_cycle(arg, arg[recomb_name], order=order, eps=eps): - yield arg[recomb_name] - - - - -def is_self_cycle(arg, recomb, order=None, eps=1e-4): - - def is_local_coal(node, child, pos): - if node.event != "coal": - return False - - i = node.children.index(child) - other_child = node.children[1 - i] - - for start, end in other_child.data["ancestral"]: - if start < pos < end: - return True - - return False - - if order is None: - # get overall postorder - # assumes stable sort - nodes = list(arg.postorder()) - nodes.sort(key=lambda x: x.age) - order = dict((x, i) for i, x in enumerate(nodes)) - - # find cycle - # also check for local coal nodes along the way - rpos = recomb.pos - path1 = [] - path2 = [] - ptr1 = arg.get_local_parent(recomb, rpos-eps) - ptr2 = arg.get_local_parent(recomb, rpos+eps) - while ptr1 and ptr2: - order1 = order[ptr1] - order2 = order[ptr2] - - if order1 < order2: - if is_local_coal(ptr1,path1[-1] if path1 else recomb,rpos-eps): - break - path1.append(ptr1) - ptr1 = arg.get_local_parent(ptr1, rpos-eps) - - elif order1 > order2: - if is_local_coal(ptr2,path2[-1] if path2 else recomb,rpos+eps): - break - path2.append(ptr2) - ptr2 = arg.get_local_parent(ptr2, rpos+eps) - - else: - # we have reached coal node - assert ptr1 == ptr2 - coal = ptr1 - return True - - return False - - -def remove_self_cycles(arg): - """ - Removes cycles that represent a lineage coalescing with itself - - Requires ancestral sequences set. 
- """ - - # Such a cycle does not contain 'local coalescent nodes' on the sides - # but it might have non-local coalescent and recombination nodes. - # The relative order of theses nodes from the left and right side - # does not matter - # - # | - # coal - # / \ - # | | - # \ / - # recomb - # | - - def is_local_coal(node, child, pos): - if node.event != "coal": - return False - - i = node.children.index(child) - other_child = node.children[1 - i] - - for start, end in other_child.data["ancestral"]: - if start < pos < end: - return True - - return False - - # get overall postorder - # assumes stable sort - nodes = list(arg.postorder()) - nodes.sort(key=lambda x: x.age) - order = dict((x, i) for i, x in enumerate(nodes)) - - # find cycles by their recombination nodes - recombs = [x.name for x in arg if x.event == "recomb"] - - # find smallest separation - recomb_pos = [arg[x].pos for x in recombs] - recomb_pos.sort() - eps = .5 - for i in xrange(1, len(recomb_pos)): - sep = recomb_pos[i] - recomb_pos[i-1] - if sep > 0 and sep/2.0 < eps: - eps = sep / 2.0 - - - for recomb_name in recombs: - if recomb_name not in arg: - continue - recomb = arg[recomb_name] - rpos = recomb.pos - - # find cycle - # also check for local coal nodes along the way - is_cycle = False - path1 = [] - path2 = [] - ptr1 = arg.get_local_parent(recomb, rpos-eps) - ptr2 = arg.get_local_parent(recomb, rpos+eps) - while ptr1 and ptr2: - order1 = order[ptr1] - order2 = order[ptr2] - - if order1 < order2: - if is_local_coal(ptr1,path1[-1] if path1 else recomb,rpos-eps): - break - path1.append(ptr1) - ptr1 = arg.get_local_parent(ptr1, rpos-eps) - - elif order1 > order2: - if is_local_coal(ptr2,path2[-1] if path2 else recomb,rpos+eps): - break - path2.append(ptr2) - ptr2 = arg.get_local_parent(ptr2, rpos+eps) - - else: - # we have reached coal node - assert ptr1 == ptr2 - coal = ptr1 - is_cycle = True - break - - if not is_cycle: - # this recombination node is not a cycle - # either because it contains a local coal node or never recoals - # which can happen in SMC ARGs - continue - - if path1: - assert coal in path1[-1].parents - else: - assert coal in recomb.parents - if path2: - assert coal in path2[-1].parents - else: - assert coal in recomb.parents - - if len(set(path1) & set(path2)) != 0: - print [(order[x], x) for x in path1] - print [(order[x], x) for x in path2] - assert False - - # remove coal node - top = coal.parents[0] if coal.parents else None - if top: - util.replace(top.children, coal, None) - - # remove recomb node - bottom = recomb.children[0] - util.replace(bottom.parents, recomb, None) - - # unlink nodes in left path - last = recomb - for node in path1: - util.replace(last.parents, node, None) - util.replace(node.children, last, None) - last = node - util.replace(last.parents, coal, None) - - # unlink nodes in right path - last = recomb - #print "--" - for node in path2: - #print last.parents, node - util.replace(last.parents, node, None) - util.replace(node.children, last, None) - last = node - util.replace(last.parents, coal, None) - - - # merge paths - combine = path1 + path2 - combine.sort(key=lambda x: (x.age, order[x])) - last = bottom - - for n in combine: - util.replace(last.parents, None, n) - util.replace(n.children, None, last) - last = n - - if top: - util.replace(last.parents, None, top) - util.replace(top.children, None, last) - else: - # no top node - if last.event == "coal" or last.event == "gene": - last.parents = [] - elif last.event == "recomb": - # remove last recomb node since it is a 
single lineage - c = last.children[0] - p = last.parents[1 - last.parents.index(None)] - util.replace(c.parents, last, p) - util.replace(p.children, last, c) - del arg.nodes[last.name] - else: - raise Exception("unknown event '%s'" % last.event) - - del arg.nodes[recomb.name] - del arg.nodes[coal.name] - - - - - -def remove_self_cycles2(arg): - """ - Removes cycles that represent a lineage coalescing with itself - - Requires ancestral sequences set. - """ - - # Such a cycle does not contain 'local coalescent nodes' on the sides - # but it might have non-local coalescent and recombination nodes. - # The relative order of theses nodes from the left and right side - # does not matter - # - # | - # coal - # / \ - # | | - # \ / - # recomb - # | - - # get overall postorder - # assumes stable sort - nodes = list(arg.postorder()) - nodes.sort(key=lambda x: x.age) - order = dict((x, i) for i, x in enumerate(nodes)) - - # find cycles by their recombination nodes - recombs = [x.name for x in arg if x.event == "recomb"] - - # find smallest separation - recomb_pos = [arg[x].pos for x in recombs] - recomb_pos.sort() - eps = .5 - for i in xrange(1, len(recomb_pos)): - sep = recomb_pos[i] - recomb_pos[i-1] - if sep > 0 and sep/2.0 < eps: - eps = sep / 2.0 - - - for recomb_name in recombs: - if recomb_name not in arg: - continue - recomb = arg[recomb_name] - rpos = recomb.pos - - # find cycle - # also check for local coal nodes along the way - is_cycle = False - path1 = [] - path2 = [] - ptr1 = arg.get_local_parent(recomb, rpos-eps) - ptr2 = arg.get_local_parent(recomb, rpos+eps) - while ptr1 and ptr2: - order1 = order[ptr1] - order2 = order[ptr2] - - if order1 < order2: - if ptr1.event == "coal": - break - path1.append(ptr1) - ptr1 = arg.get_local_parent(ptr1, rpos-eps) - - elif order1 > order2: - if ptr2.event == "coal": - break - path2.append(ptr2) - ptr2 = arg.get_local_parent(ptr2, rpos+eps) - - else: - # we have reached coal node - assert ptr1 == ptr2 - coal = ptr1 - is_cycle = True - break - - if not is_cycle: - # this recombination node is not a cycle - # either because it contains a local coal node or never recoals - # which can happen in SMC ARGs - continue - - # remove coal node - top = coal.parents[0] if coal.parents else None - if top: - util.replace(top.children, coal, None) - - # remove recomb node - bottom = recomb.children[0] - util.replace(bottom.parents, recomb, None) - - # unlink nodes in left path - last = recomb - for node in path1: - util.replace(last.parents, node, None) - util.replace(node.children, last, None) - last = node - util.replace(last.parents, coal, None) - - # unlink nodes in right path - last = recomb - for node in path2: - util.replace(last.parents, node, None) - util.replace(node.children, last, None) - last = node - util.replace(last.parents, coal, None) - - - # merge paths - combine = path1 + path2 - combine.sort(key=lambda x: (x.age, order[x])) - last = bottom - - for n in combine: - util.replace(last.parents, None, n) - util.replace(n.children, None, last) - last = n - - if top: - util.replace(last.parents, None, top) - util.replace(top.children, None, last) - else: - # no top node - if last.event == "coal": - last.parents = [] - elif last.event == "recomb": - # remove last recomb node since it is a single lineage - c = last.children[0] - p = last.parents[1 - last.parents.index(None)] - util.replace(c.parents, last, p) - util.replace(p.children, last, c) - del arg.nodes[last.name] - else: - raise Exception("unknown event '%s'" % node.event) - - del 
arg.nodes[recomb.name] - del arg.nodes[coal.name] - - -''' - - -''' -SLOW remove cycles - -def remove_self_cycles(arg, eps=.5): - """ - Removes cycles that represent a lineage coalescing with itself - - Requires ancestral sequences set. - """ - - # Such a cycle does not contain 'local coalescent nodes' on the sides - # but it might have non-local coalescent and recombination nodes. - # The relative order of theses nodes from the left and right side - # does not matter - # - # | - # coal - # / \ - # | | - # \ / - # recomb - # | - - def is_local_coal(node, child, pos): - if node.event != "coal": - return False - return True - - i = node.children.index(child) - other_child = node.children[1 - i] - - for start, end in other_child.data["ancestral"]: - #print node, other_child, (start, end), pos - if start < pos < end: - return True - - return False - - # get overall postorder - order = dict((x, i) for i, x in enumerate(arg.postorder())) - - - # find cycles by their recombination nodes - recombs = [x.name for x in arg if x.event == "recomb"] - for recomb_name in recombs: - if recomb_name not in arg: - continue - recomb = arg[recomb_name] - rpos = recomb.pos - - # get left path - path1 = [] - ptr = arg.get_local_parent(recomb, rpos-eps) - while ptr: - path1.append(ptr) - ptr = arg.get_local_parent(ptr, rpos-eps) - - # get right path - path2 = [] - ptr = arg.get_local_parent(recomb, rpos+eps) - while ptr: - path2.append(ptr) - ptr = arg.get_local_parent(ptr, rpos+eps) - - - # find recoal node - i = -1 - length = min(len(path1), len(path2)) - while -i <= length and path1[i] == path2[i]: - i -= 1 - if i == -1: - # this happens with SMC ARGs - continue - a = len(path1) + (i + 1) - b = len(path2) + (i + 1) - coal = path1[a] - - - # are there any coal nodes in left and right paths? 
- is_cycle = True - for i in range(a): - if is_local_coal( - path1[i], path1[i-1] if i > 0 else recomb, rpos-eps): - is_cycle = False - break - - for i in range(b): - if is_local_coal( - path2[i], path2[i-1] if i > 0 else recomb, rpos+eps): - is_cycle = False - break - - if not is_cycle: - # this recombination node is not a cycle - print "recomb", recomb, "pos=", rpos, "is not a cycle" - continue - - - print path1, path2, recomb, coal, i, a, b - print path1[:a], path2[:b] - - - # remove coal node - top = coal.parents[0] if coal.parents else None - if top: - util.replace(top.children, coal, None) - - # remove recomb node - bottom = recomb.children[0] - util.replace(bottom.parents, recomb, None) - - # unlink nodes in left path - last = recomb - for node in path1[:a]: - util.replace(last.parents, node, None) - util.replace(node.children, last, None) - last = node - util.replace(last.parents, coal, None) - - # unlink nodes in right path - last = recomb - for node in path2[:b]: - util.replace(last.parents, node, None) - util.replace(node.children, last, None) - last = node - util.replace(last.parents, coal, None) - - - # merge paths - combine = path1[:a] + path2[:b] - #print "path1", [order[x] for x in path1] - #print "path2", [order[x] for x in path2] - combine.sort(key=lambda x: (x.age, order[x])) - #print [(x.name, x.age, order[x]) for x in combine] - - - last = bottom - for n in combine: - util.replace(last.parents, None, n) - util.replace(n.children, None, last) - last = n - if top: - util.replace(last.parents, None, top) - util.replace(top.children, None, last) - else: - print "no top for", recomb, rpos - - # no top node - if last.event == "coal": - last.parents = [] - elif last.event == "recomb": - # remove last recomb node since it is a single lineage - c = last.children[0] - p = last.parents[1 - last.parents.index(None)] - util.replace(c.parents, last, p) - util.replace(p.children, last, c) - del arg.nodes[last.name] - print "remove last", last.name - else: - raise Exception("unknown event '%s'" % node.event) - - del arg.nodes[recomb.name] - del arg.nodes[coal.name] - print "remove", recomb.name, coal.name - print " ", order[bottom], order[recomb], order[coal], order[top] - - assert_arg(arg) - #print "good" - - - -''' - diff --git a/argweaver/deps/compbio/coal.py b/argweaver/deps/compbio/coal.py index abe25a43..2b6ab957 100644 --- a/argweaver/deps/compbio/coal.py +++ b/argweaver/deps/compbio/coal.py @@ -15,17 +15,17 @@ from __future__ import division # python imports -import itertools from itertools import chain, izip -from math import * +from math import exp, log, sqrt import random -from collections import defaultdict # rasmus imports from rasmus import treelib, stats, util, linked_list try: - from rasmus.symbolic import * + from rasmus.symbolic import assign_vars + from rasmus.symbolic import derivate + from rasmus.symbolic import simplify except ImportError: # only experimental functions need symbolic pass @@ -36,13 +36,12 @@ # import root finder try: from scipy.optimize import brentq + brentq except ImportError: def brentq(f, a, b, disp=False): return stats.bisect_root(f, a, b) - - #============================================================================= # single coalescent PDFs, CDFs, and sampling functions @@ -52,7 +51,6 @@ def prob_coal(t, k, n): Returns the probability density of observing the first coalesce of 'k' individuals in a population size of 'n' at generation 't' """ - # k choose 2 k2 = k * (k-1) / 2 k2n = k2 / n @@ -63,7 +61,6 @@ def sample_coal(k, n): """ 
Returns a sample coalescent time for 'k' individuals in a population 'n' """ - # k choose 2 k2 = k * (k-1) / 2 k2n = k2 / n @@ -86,6 +83,8 @@ def prob_coal_counts(a, b, t, n): The probabiluty of going from 'a' lineages to 'b' lineages in time 't' with population size 'n' """ + if b <= 0: + return 0.0 C = stats.prod((b+y)*(a-y)/(a+y) for y in xrange(b)) s = exp(-b*(b-1)*t/2.0/n) * C @@ -104,7 +103,6 @@ def prob_coal_counts_slow(a, b, t, n): Implemented more directly, but slower. Good for testing against. """ - s = 0.0 for k in xrange(b, a+1): i = exp(-k*(k-1)*t/2.0/n) * \ @@ -121,7 +119,6 @@ def prob_coal_cond_counts(x, a, b, t, n): between 'a' lineages conditioned on there being 'b' lineages at time 't'. The population size is 'n'. """ - lama = -a*(a-1)/2.0/n C = stats.prod((b+y)*(a-1-y)/(a-1+y) for y in xrange(b)) s = exp(-b*(b-1)/2.0/n*(t-x) + lama*x) * C @@ -140,7 +137,6 @@ def prob_coal_cond_counts_simple(x, a, b, t, n): between 'a' lineages conditioned on there being 'b' lineages at time 't'. The population size is 'n'. """ - return (prob_coal_counts(a-1, b, t-x, n) * prob_coal(x, a, n) / prob_coal_counts(a, b, t, n)) @@ -151,7 +147,6 @@ def cdf_coal_cond_counts(x, a, b, t, n): between 'a' lineages conditioned on there being 'b' lineages at time 't'. The population size is 'n'. """ - lama = -a*(a-1)/2.0/n C = stats.prod((b+y)*(a-1-y)/(a-1+y) for y in xrange(b)) c = -b*(b-1)/2.0/n @@ -160,8 +155,8 @@ def cdf_coal_cond_counts(x, a, b, t, n): k1 = k - 1 lam = -k*k1/2.0/n C = (b+k1)*(a-1-k1)/(a-1+k1)/(b-k) * C - s += exp(lam*t) * (exp((lama-lam)*x) - 1.0) / (lama - lam) \ - * (2*k-1) / (k1+b) * C + s += (exp(lam*t) * (exp((lama-lam)*x) - 1.0) / (lama - lam) + * (2*k-1) / (k1+b) * C) return s / stats.factorial(b) * (-lama) / prob_coal_counts(a, b, t, n) @@ -184,7 +179,6 @@ def sample_coal_cond_counts(a, b, t, n): c = -b*(b-1)/2.0/n d = 1.0/stats.factorial(b) * (-lama) / prob_coal_counts(a, b, t, n) - # CDF(t) - p def f(x): if x <= 0: @@ -198,15 +192,14 @@ def f(x): k1 = k - 1 lam = -k*k1/2.0/n C = (b+k1)*(a-1-k1)/(a-1+k1)/(b-k) * C - s += exp(lam*t) * (exp((lama-lam)*x) - 1.0) / (lama - lam) \ - * (2*k-1) / (k1+b) * C + s += (exp(lam*t) * (exp((lama-lam)*x) - 1.0) / (lama - lam) + * (2*k-1) / (k1+b) * C) return s * d - p return brentq(f, 0.0, t, disp=False) - def prob_mrca(t, k, n): """ Probability density function of the age 't' of the most recent @@ -255,7 +248,6 @@ def mrca_const(i, a, b): return prod - def prob_bounded_coal(t, k, n, T): """ Probability density function of seeing a coalescence at 't' from @@ -267,8 +259,8 @@ def prob_bounded_coal(t, k, n, T): if k == 2: prob_coal(t, k, n) - return prob_coal(t, k, n) * cdf_mrca(T-t, k-1, n) / \ - cdf_mrca(T, k, n) + return (prob_coal(t, k, n) * cdf_mrca(T-t, k-1, n) / + cdf_mrca(T, k, n)) def cdf_bounded_coal(t, k, n, T): @@ -280,7 +272,7 @@ def cdf_bounded_coal(t, k, n, T): lam_i = (i+1)*i/2.0 / n C = [mrca_const(j, 1, i-1) for j in xrange(1, i)] - A = lam_i / n / cdf_mrca(T, k, n) + #A = lam_i / n / cdf_mrca(T, k, n) B = sum(C) / lam_i F = [C[j-1] * exp(-(j+1)*j/2.0/n * T) / ((j+1)*j/2.0/n - lam_i) for j in xrange(1, i)] @@ -323,9 +315,9 @@ def f(t): if t >= T: return 1.0 - p + (t - T) - return (A * (B * (1-exp(-lam_i * t)) - - sum(F[j-1] * (exp(((j+1)*j/2.0/n - lam_i)*t)-1) - for j in xrange(1, i)))) - p + return ((A * (B * (1-exp(-lam_i * t)) + - sum(F[j-1] * (exp(((j+1)*j/2.0/n - lam_i)*t)-1) + for j in xrange(1, i)))) - p) return brentq(f, 0.0, T, disp=False) @@ -388,9 +380,9 @@ def count_lineages_per_branch(tree, recon, stree): 
for node in tree.postorder(): snode = recon[node] if node.is_leaf(): - lineages[snode][0] += 1 # leaf lineage + lineages[snode][0] += 1 # leaf lineage else: - lineages[snode][1] -= 1 # coal + lineages[snode][1] -= 1 # coal for snode in stree.postorder(): if not snode.is_leaf(): @@ -423,7 +415,6 @@ def get_topology_stats(tree, recon, stree): return nodes_per_species, descend_nodes - def prob_multicoal_recon_topology(tree, recon, stree, n, lineages=None, top_stats=None): """ @@ -431,7 +422,6 @@ def prob_multicoal_recon_topology(tree, recon, stree, n, from the coalescent model given a species tree 'stree' and population sizes 'n' """ - popsizes = init_popsizes(stree, n) if lineages is None: lineages = count_lineages_per_branch(tree, recon, stree) @@ -439,7 +429,7 @@ def prob_multicoal_recon_topology(tree, recon, stree, n, top_stats = get_topology_stats(tree, recon, stree) # iterate through species tree branches - lnp = 0.0 # log probability + lnp = 0.0 # log probability for snode in stree.postorder(): if snode.parent: # non root branch @@ -451,7 +441,7 @@ def prob_multicoal_recon_topology(tree, recon, stree, n, + stats.logfactorial(top_stats[0].get(snode, 0)) - log(num_labeled_histories(a, b))) except: - print (a, b, snode.dist, popsizes[snode.name], + print (a, b, snode.name, snode.dist, popsizes[snode.name], prob_coal_counts(a, b, snode.dist, popsizes[snode.name]), ) @@ -474,7 +464,6 @@ def prob_multicoal_recon_topology(tree, recon, stree, n, return lnp - def cdf_mrca_bounded_multicoal(gene_counts, T, stree, n, sroot=None, sleaves=None, stimes=None, tree=None, recon=None): @@ -518,11 +507,12 @@ def cdf_mrca_bounded_multicoal(gene_counts, T, stree, n, def calc_prob_counts_table(gene_counts, T, stree, popsizes, - sroot, sleaves, stimes): + sroot, sleaves, stimes): # use dynamic programming to calc prob of lineage counts # format: prob_counts[node] = [a, b] prob_counts = {} + def walk(node): if node in sleaves: # leaf case @@ -539,7 +529,7 @@ def walk(node): c2 = node.children[1] M1 = walk(c1) M2 = walk(c2) - M = M1 + M2 # max lineage counts in this snode + M = M1 + M2 # max lineage counts in this snode end1 = prob_counts[c1][1] end2 = prob_counts[c2][1] @@ -555,7 +545,7 @@ def walk(node): c1 = node.children[0] M1 = walk(c1) - M = M1 # max lineage counts in this snode + M = M1 # max lineage counts in this snode end1 = prob_counts[c1][1] # populate starting lineage counts with child's ending counts @@ -567,8 +557,6 @@ def walk(node): # unhandled case raise Exception("not implemented") - - # populate ending lineage counts n = popsizes[node.name] ptime = stimes[node.parent] if node.parent else T @@ -590,7 +578,7 @@ def walk(node): assert abs(sum(start) - 1.0) < .001, (start, node.children) return M - M = walk(sroot) + walk(sroot) return prob_counts @@ -611,6 +599,7 @@ def prob_coal_bmc(t, u, utime, ucount, gene_counts, T, stree, n, # find relevent leaves of stree (u should be treated as a leaf) if sleaves is None: sleaves = set() + def walk(node): if node.is_leaf() or node == u: sleaves.add(node) @@ -619,11 +608,10 @@ def walk(node): walk(child) walk(sroot) - # find timestamps of stree nodes if stimes is None: # modify timestamp of u to be that of the previous coal (utime) - stimes = {u : utime} + stimes = {u: utime} stimes = treelib.get_tree_timestamps(stree, sroot, sleaves, stimes) # init gene counts @@ -640,7 +628,6 @@ def walk(node): popsizes = init_popsizes(stree, n) - p = cdf_mrca_bounded_multicoal(gene_counts, T, stree, popsizes, sroot=sroot, sleaves=sleaves, stimes=stimes, tree=tree, 
recon=recon) @@ -681,6 +668,7 @@ def prob_no_coal_bmc(u, utime, ucount, gene_counts, T, stree, n, # find relevent leaves of stree (u should be treated as a leaf) if sleaves is None: sleaves = set() + def walk(node): if node.is_leaf() or node == u: sleaves.add(node) @@ -689,11 +677,10 @@ def walk(node): walk(child) walk(sroot) - # find timestamps of stree nodes if stimes is None: # modify timestamp of u to be that of the previous coal (utime) - stimes = {u : utime} + stimes = {u: utime} stimes = treelib.get_tree_timestamps(stree, sroot, sleaves, stimes) # init gene counts @@ -710,7 +697,6 @@ def walk(node): popsizes = init_popsizes(stree, n) - p = cdf_mrca_bounded_multicoal(gene_counts, T, stree, popsizes, sroot=sroot, sleaves=sleaves, stimes=stimes, tree=tree, recon=recon) @@ -728,13 +714,13 @@ def walk(node): return p2 - p + p3 - def num_labeled_histories(nleaves, nroots): n = 1.0 for i in xrange(nroots + 1, nleaves + 1): n *= i * (i - 1) / 2.0 return n + def log_num_labeled_histories(nleaves, nroots): n = 0.0 for i in xrange(nroots + 1, nleaves + 1): @@ -761,15 +747,14 @@ def prob_bounded_multicoal_recon_topology(tree, recon, stree, n, T, if stimes is None: stimes = treelib.get_tree_timestamps(stree) - p = prob_multicoal_recon_topology(tree, recon, stree, popsizes, lineages=lineages, top_stats=top_stats) k_root = lineages[stree.root][0] T_root = T - stimes[stree.root] - return log(cdf_mrca(T_root, k_root, popsizes[recon[tree.root].name])) + p \ - - cdf_mrca_bounded_multicoal(None, T, stree, popsizes, - tree=tree, recon=recon, stimes=stimes) - + return (log(cdf_mrca(T_root, k_root, popsizes[recon[tree.root].name])) + p + - cdf_mrca_bounded_multicoal( + None, T, stree, popsizes, + tree=tree, recon=recon, stimes=stimes)) #============================================================================= @@ -821,7 +806,7 @@ def sample_bounded_coal_tree_reject(k, n, T, capped=False): times = [0] for j in xrange(k, 1, -1): times.append(times[-1] + sample_coal(j, n)) - if times[-1] < t: + if times[-1] < T: break return make_tree_from_times(times, t=T, capped=capped)[0] @@ -868,7 +853,6 @@ def sample_coal_cond_counts_tree(a, b, t, n, capped=False): return make_tree_from_times(times, a, t, capped=capped) - def init_popsizes(stree, n): """ Uses 'n' to initialize a population size dict for species tree 'stree' @@ -912,6 +896,7 @@ def sample_multicoal_tree(stree, n, leaf_counts=None, # initialize function for generating new gene names if namefunc is None: spcounts = dict((l, 1) for l in stree.leaf_names()) + def namefunc(sp): name = sp + "_" + str(spcounts[sp]) spcounts[sp] += 1 @@ -933,7 +918,7 @@ def namefunc(sp): queue = MultiPushQueue(sleaves) # loop through species tree - for snode in queue: #stree.postorder(): + for snode in queue: # simulate population for one branch k = counts[snode.name] @@ -952,7 +937,6 @@ def namefunc(sp): for node in subtree: recon[node] = snode - # stitch subtrees together tree = treelib.Tree() @@ -963,7 +947,7 @@ def namefunc(sp): del recon[subtree.root] for snode in subtrees: - if snode not in sleaves: # not snode.is_leaf(): + if snode not in sleaves: subtree, lineages = subtrees[snode] # get lineages from child subtrees @@ -978,7 +962,6 @@ def namefunc(sp): for leaf, lineage in izip(leaves, lineages2): tree.add_child(leaf, lineage) - # set root tree.root = subtrees[sroot][0].root tree.add(tree.root) @@ -989,11 +972,6 @@ def namefunc(sp): if recon[node].is_leaf(): tree.rename(node.name, namefunc(recon[node].name)) - #print "HERE" - #treelib.draw_tree_names(tree, maxlen=8) 
- #print "recon", [(x[0].name, x[1].name) for x in recon.items()] - - return tree, recon @@ -1029,6 +1007,7 @@ def sample_bounded_multicoal_tree(stree, n, T, leaf_counts=None, namefunc=None, # initialize function for generating new gene names if namefunc is None: spcounts = dict((l.name, 1) for l in sleaves) + def namefunc(sp): name = sp + "_" + str(spcounts[sp]) spcounts[sp] += 1 @@ -1045,7 +1024,6 @@ def namefunc(sp): if stimes is None: stimes = treelib.get_tree_timestamps(stree) - # calc table prob_counts = calc_prob_counts_table(gene_counts, T, stree, popsizes, sroot, sleaves, stimes) @@ -1055,7 +1033,6 @@ def namefunc(sp): for node in sleaves: lineages[node] = [gene_counts[node.name], None] - # sample lineage counts sample_lineage_counts(sroot, sleaves, popsizes, stimes, T, lineages, prob_counts) @@ -1159,7 +1136,7 @@ def sample_lineage_counts(node, leaves, else: # unhandled case - raise Excepiton("not implemented") + raise NotImplementedError def coal_cond_lineage_counts(lineages, sroot, sleaves, popsizes, stimes, T, @@ -1242,7 +1219,6 @@ def join_subtrees(subtrees, recon, caps, sroot): for leaf, lineage in izip(leaves, lineages2): tree.add_child(leaf, lineage) - # set root tree.root = subtrees[sroot][0].root if tree.root in caps and len(tree.root.children) == 1: @@ -1251,8 +1227,6 @@ def join_subtrees(subtrees, recon, caps, sroot): return tree - - def sample_bounded_multicoal_tree_reject(stree, n, T, leaf_counts=None, namefunc=None, sleaves=None, sroot=None): @@ -1285,6 +1259,7 @@ def sample_bounded_multicoal_tree_reject(stree, n, T, leaf_counts=None, # initialize function for generating new gene names if namefunc is None: spcounts = dict((l.name, 1) for l in sleaves) + def namefunc(sp): name = sp + "_" + str(spcounts[sp]) spcounts[sp] += 1 @@ -1352,21 +1327,17 @@ def namefunc(sp): for leaf, lineage in izip(leaves, lineages2): tree.add_child(leaf, lineage) - # set root tree.root = subtrees[sroot][0].root tree.add(tree.root) recon[tree.root] = sroot - # reject tree if basal branch goes past deadline times = treelib.get_tree_timestamps(tree) if times[tree.root] < T: break else: reject += 1 - #print "reject", reject, times[tree.root], T - # name leaves for leaf in tree.leaves(): @@ -1375,8 +1346,6 @@ def namefunc(sp): return tree, recon - - def make_tree_from_times(times, k=None, t=None, leaves=None, capped=False): """ Returns a Tree from a list of divergence times. 
@@ -1417,7 +1386,6 @@ def make_tree_from_times(times, k=None, t=None, leaves=None, capped=False): children.remove(b) children.add(parent) - # set branch lengths for node in tree: if not node.parent: @@ -1442,7 +1410,6 @@ def make_tree_from_times(times, k=None, t=None, leaves=None, capped=False): return tree, children - #============================================================================= # popsize inference @@ -1454,7 +1421,7 @@ def mle_popsize_coal_times(k, times): s += i*(i-1) * (t - last) i -= 1 last = t - return s / float(2* k - 2) + return s / float(2 * k - 2) def mle_popsize_many_coal_times(k, times): @@ -1511,9 +1478,6 @@ def next(self): return self._lst.pop_front() - - - #============================================================================= # allele frequency @@ -1553,11 +1517,11 @@ def freq_CDF(p, N, t, T, k=50): T is the upper limit of the CDF (int from 0 to T) k is approximation for the upper limit in the (supposed to be) infinite sum """ - return freq_CDF-_legs_ends(legendre(1.0-2*p), legendre(1.0-2*T), - N, t, k=k) + return freq_CDF_legs_ends(legendre(1.0-2*p), legendre(1.0-2*T), + N, t, k=k) -def freq_CDF_legs_noends(leg_r,leg_T,N,t,k=50): +def freq_CDF_legs_noends(leg_r, leg_T, N, t, k=50): """ Evaluates the CDF derived from Kimura using two Legendre polynomials. This does not include the probabilities at 0 and 1 (partial CDF). @@ -1569,7 +1533,7 @@ def freq_CDF_legs_noends(leg_r,leg_T,N,t,k=50): """ s = 0.0 expconst = float(t) / 4.0 / N - for i in xrange(1,k+1): + for i in xrange(1, k+1): newterm = .5 * (leg_r(i-1) - leg_r(i+1)) newterm *= exp(- i * (i+1) * expconst) newterm *= 1 - leg_T(i) @@ -1577,7 +1541,7 @@ def freq_CDF_legs_noends(leg_r,leg_T,N,t,k=50): return s -def freq_CDF_legs_ends(leg_r,leg_T,N,t,k=50): +def freq_CDF_legs_ends(leg_r, leg_T, N, t, k=50): """ Evaluates the CDF derived from Kimura using two Legendre polynomials. This includes the probabilities at 0 and 1 (full CDF). 
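As a small sketch of how the allele-frequency CDF machinery above can be called (again assuming the module is importable as compbio.coal; p, N, and t are illustrative values and the freq_CDF argument order follows the docstring above):

    from compbio import coal

    p, N, t = 0.3, 10000, 500            # starting frequency, population size, generations
    # P(allele frequency <= 0.5 after t generations of drift)
    prob = coal.freq_CDF(p, N, t, 0.5)
    # draw a frequency from the same CDF; 0.0 and 1.0 mark extinction and fixation
    x = coal.sample_freq_CDF(p, N, t)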
@@ -1587,23 +1551,26 @@ def freq_CDF_legs_ends(leg_r,leg_T,N,t,k=50): t is the time elapsed k is the upper limit to approximate the infinite sum """ - s = prob_fix(1.0-leg_r(True),N,t) # leg_r(True) currently returns p, so this is probability of extinction + # leg_r(True) currently returns p, so this is probability of extinction + s = prob_fix(1.0-leg_r(True), N, t) expconst = float(t) / 4.0 / N - for i in xrange(1,k+1): + for i in xrange(1, k+1): newterm = .5 * (leg_r(i-1) - leg_r(i+1)) newterm *= exp(- i * (i+1) * expconst) newterm *= 1 - leg_T(i) s += newterm - return s if leg_T(True) < 1.0 else s + prob_fix(leg_r(True),N,t) # add fixation probability if T==1 + # add fixation probability if T==1 + return s if leg_T(True) < 1.0 else s + prob_fix(leg_r(True), N, t) def freq_prob_range(p, N, t, T1, T2, k=50): leg_r = legendre(1.0-2*p) leg_T1 = legendre(1.0-2*T1) leg_T2 = legendre(1.0-2*T2) - return (freq_CDF_legs_noends(leg_r, leg_T2, N, t, k=k) - \ - freq_CDF_legs_noends(leg_r, leg_T1, N, t, k=k)) - # uses noends because probabilities at 0 and 1 may be determined using other methods + return (freq_CDF_legs_noends(leg_r, leg_T2, N, t, k=k) - + freq_CDF_legs_noends(leg_r, leg_T1, N, t, k=k)) + # uses noends because probabilities at 0 and 1 may be + # determined using other methods def sample_freq_CDF(p, N, t): @@ -1624,16 +1591,17 @@ def sample_freq_CDF(p, N, t): y = random.random() leg_r = legendre(1.0-2*p) - extinction = prob_fix(1.0-p, N, t) # probability of allele extinction + extinction = prob_fix(1.0-p, N, t) # probability of allele extinction if y < extinction: - return 0.0 # sample an extinction event - elif y > 1.0 - prob_fix_leg(leg_r, N, t): #prob_fix(p, N, t): - return 1.0 # sample a fixation event + return 0.0 # sample an extinction event + elif y > 1.0 - prob_fix_leg(leg_r, N, t): + return 1.0 # sample a fixation event else: def f(T): - return freq_CDF_legs_noends(leg_r, legendre(1.0-2*T), N, t) \ - - y + extinction # trims extinction probability, assures brentq works + # trims extinction probability, assures brentq works + return (freq_CDF_legs_noends(leg_r, legendre(1.0-2*T), N, t) + - y + extinction) try: return brentq(f, 0.0, 1.0, disp=False) @@ -1642,7 +1610,6 @@ def f(T): raise - # new function for determining Legendre polynomial evaluations def legendre(r): """ @@ -1655,32 +1622,35 @@ def legendre(r): This function can run with n as high as one million in a fraction of a second (using isolated calls, so no caching to build higher values of n). 
""" - def cacheleg(i,d): + def cacheleg(i, d): if type(i) == bool: - return (1.0-d[1])/2.0 if i else d[1] # utility function; may need to be removed - assert (type(i) == int and i >= 0) # if i is not type bool + # utility function; may need to be removed + return (1.0-d[1])/2.0 if i else d[1] + assert (type(i) == int and i >= 0) # if i is not type bool m = d['max'] if i <= m: return d[i] x = d[1] - for n in xrange(m+1,i+1): - d[n] = 1.0 * ( (2*n-1)*x*d[n-1] - (n-1)*d[n-2] ) / n + for n in xrange(m+1, i+1): + d[n] = 1.0 * ((2 * n - 1) * x * d[n-1] - (n-1) * d[n-2]) / n d['max'] = i return d[i] - d = {0:1.0, 1:r, 'max':1} - assert -1.0 <= r and r <= 1.0 # ensure r in reasonable range - return lambda n: cacheleg(n,d) + d = {0: 1.0, 1: r, 'max': 1} + assert -1.0 <= r and r <= 1.0 # ensure r in reasonable range + return lambda n: cacheleg(n, d) def gegenbauer(i, r): return ((i * (i+1)) / 2.0 * hypergeo(i+2, 1 - i, 2, (1 - r) / 2.0)) + # this should be the fastest gegenbauer method now (21 July 2010) def gegenbauer2(i, r): leg = legendre(r) return ((i * (i+1)) / float((2*i+1)*(1-r*r)) * (leg(i-1) - leg(i+1))) + def gegenbauer3(n, a, z): tot = 0 @@ -1699,7 +1669,7 @@ def prob_fix(p, n, t, k=50, esp=0.000001): prob = p for i in xrange(1, k+1): term = (.5 * (-1)**i * (leg(i-1) - leg(i+1)) * - exp(-t * i * (i+1) / (4 * n))) + exp(-t * i * (i+1) / (4 * n))) if term != 0.0 and abs(term) < esp: return prob + term prob += term @@ -1712,10 +1682,10 @@ def prob_fix(p, n, t, k=50, esp=0.000001): def prob_fix_leg(leg_r, n, t, k=50, esp=0.000001): """Probability of fixation""" leg = leg_r - prob = leg(True) # gets p + prob = leg(True) # gets p for i in xrange(1, k+1): term = (.5 * (-1)**i * (leg(i-1) - leg(i+1)) * - exp(-t * i * (i+1) / (4 * n))) + exp(-t * i * (i+1) / (4 * n))) if term != 0.0 and abs(term) < esp: return prob + term prob += term @@ -1761,9 +1731,9 @@ def loghypergeo(a, b, c, z, k=100): def hypergeo_mult(i, z1, z2, k=100): - h1 = hypergeo(1-i, i+2, 2, z1, k) - h2 = hypergeo(1-i, i+2, 2, z2, k) - return h1 * h2 + h1 = hypergeo(1-i, i+2, 2, z1, k) + h2 = hypergeo(1-i, i+2, 2, z2, k) + return h1 * h2 def freq_pdf(x, p, n, t, k=8): @@ -1782,8 +1752,8 @@ def freq_pdf(x, p, n, t, k=8): # exp(-t * i * (i+1) / (4*n))) lcoff = log(p * q * i * (i+1) * (2*i+1)) - s1, h1 = loghypergeo(1-i,i+2,2,p, i+2) - s2, h2 = loghypergeo(1-i,i+2,2,x, i+2) + s1, h1 = loghypergeo(1-i, i+2, 2, p, i+2) + s2, h2 = loghypergeo(1-i, i+2, 2, x, i+2) sgn2 = s1 * s2 term = (lcoff + h1 + h2 - (i * (i+1) * t4n)) @@ -1792,30 +1762,16 @@ def freq_pdf(x, p, n, t, k=8): return sgn * exp(prob) - - #============================================================================= if __name__ == "__main__": from rasmus.common import plotfunc - if 0: - for i in range(5): - print "P_%d(x) = " % i, legendre_poly(i) - print - - #======================== # hypergeo speed a, b, c, z, k = 30, 20, 12, .3, 40 - util.tic("hypergeo_fast") - for i in range(100): - hypergeo_fast(a, b, c, z, k) - util.toc() - - util.tic("hypergeo") for i in range(100): hypergeo(a, b, c, z, k) @@ -1826,10 +1782,9 @@ def freq_pdf(x, p, n, t, k=8): loghypergeo(a, b, c, z, k) util.toc() - if 0: p0 = .5 - k=30 + k = 30 p = plotfunc(lambda x: freq_pdf(x, p0, 1000, 100, k=k), .01, .99, .01, style="lines") @@ -1849,8 +1804,6 @@ def freq_pdf(x, p, n, t, k=8): #p.plotfunc(lambda x: normalPdf(x, (.5, .1135)), # .01, .99, .01, style="lines") - - if 0: p0 = .1 @@ -1872,10 +1825,9 @@ def freq_pdf(x, p, n, t, k=8): #p.plotfunc(lambda x: freq_pdf3(x, .5, 1000, 1000/10, k=40), # .01, 
.99, .01, style="lines") - if 0: p0 = .5 - k=30 + k = 30 p = plotfunc(lambda x: freq_pdf(x, p0, 1000, 30, k=k), .01, .99, .01, style="lines") @@ -1883,8 +1835,6 @@ def freq_pdf(x, p, n, t, k=8): p.replot() - - #============================================================================= # old versions @@ -1933,15 +1883,14 @@ def innersum(i, T, j=0, s=0.0, c=1.0): # if p == 1.0: # all have the allele # return 1.0 if T == 1.0 else 0.0 s = 0.0 - for i in xrange(1,k+1): + for i in xrange(1, k+1): newterm = leg(i-1) - leg(i+1) newterm *= exp(- i * (i+1) / 4.0 * t / N) - newterm *= .5 - .5 * innersum(i,T) + newterm *= .5 - .5 * innersum(i, T) s += newterm return s - def hypergeo_old(a, b, c, z, k=100): """Hypergeometric function""" terms = [1.0] @@ -1972,8 +1921,8 @@ def freq_pdf_old(x, p, n, t, k=8): # exp(-t * i * (i+1) / (4*n))) lcoff = log(p * q * i * (i+1) * (2*i+1)) - h1 = hypergeo(1-i,i+2,2,p, i+2) - h2 = hypergeo(1-i,i+2,2,x, i+2) + h1 = hypergeo(1-i, i+2, 2, p, i+2) + h2 = hypergeo(1-i, i+2, 2, x, i+2) sgn2 = util.sign(h1) * util.sign(h2) if sgn2 != 0: @@ -1984,7 +1933,6 @@ def freq_pdf_old(x, p, n, t, k=8): return sgn * exp(prob) - def freq_pdf2(x, p, n, t, k=8): r = 1 - 2*p z = 1 - 2*x @@ -2005,7 +1953,7 @@ def freq_pdf3(x, p, n, t, k=8): prob = 0.0 for i in xrange(1, k+1): term = (p * q * i * (i+1) * (2*i+1) * - hypergeo(1-i,i+2,2,p,40) * hypergeo(1-i,i+2,2,x,40) * + hypergeo(1-i, i+2, 2, p, 40) * hypergeo(1-i, i+2, 2, x, 40) * exp(-t * i * (i+1) / (4*n))) prob += term @@ -2040,10 +1988,9 @@ def cdf_mrca2(t, k, n): return s - def prob_multicoal_recon_topology_old(tree, recon, stree, n, - root=None, leaves=None, - lineages=None, top_stats=None): + root=None, leaves=None, + lineages=None, top_stats=None): """ Returns the log probability of a reconciled gene tree ('tree', 'recon') from the coalescent model given a species tree 'stree' and @@ -2059,7 +2006,7 @@ def prob_multicoal_recon_topology_old(tree, recon, stree, n, top_stats = get_topology_stats(tree, recon, stree) # iterate through species tree branches - lnp = 0.0 # log probability + lnp = 0.0 # log probability for snode in stree.postorder(): if snode.parent: # non root branch @@ -2072,7 +2019,6 @@ def prob_multicoal_recon_topology_old(tree, recon, stree, n, a = lineages[snode][0] lnp -= log(num_labeled_histories(a, 1)) - # correct for topologies H(T) # find connected subtrees that are in the same species branch subtrees = [] @@ -2109,10 +2055,9 @@ def walk(node, subtree, leaves): def calc_prob_counts_table_old(gene_counts, T, stree, popsizes, sroot, sleaves, stimes): - root_time = T - stimes[sroot] - # use dynamic programming to calc prob of lineage counts prob_counts = {} + def walk(node): if node in sleaves: # leaf case @@ -2127,11 +2072,11 @@ def walk(node): c1 = node.children[0] c2 = node.children[1] ptime = stimes[node] - t1 = ptime - stimes[c1] # c1.dist - t2 = ptime - stimes[c2] # c2.dist + t1 = ptime - stimes[c1] # c1.dist + t2 = ptime - stimes[c2] # c2.dist M1 = walk(c1) M2 = walk(c2) - M = M1 + M2 # max lineage counts in this snode + M = M1 + M2 # max lineage counts in this snode n1 = popsizes[c1.name] n2 = popsizes[c2.name] @@ -2148,12 +2093,11 @@ def walk(node): assert abs(sum(prob_counts[node]) - 1.0) < .001 return M - M = walk(sroot) + walk(sroot) return prob_counts - def count_lineages_per_branch_old(tree, recon, stree, rev_recon=None): """ Returns the count of gene lineages present at each node in the species @@ -2214,8 +2158,8 @@ def get_topology_stats_old(tree, recon, stree, rev_recon=None): The function 
computes terms necessary for many topology calculations """ - nodes_per_species = {} # How many gene nodes per species - descend_nodes = {} # How many descendent nodes recon to the same species + nodes_per_species = {} # How many gene nodes per species + descend_nodes = {} # How many descendent nodes recon to the same species nodes_per_species = dict.fromkeys(stree, 0) @@ -2246,7 +2190,7 @@ def prob_fix_old(p, n, t, k=8, esp=0.001): prob = p for i in xrange(1, k+1): term = (.5 * (-1)**i * (legendre_old(i-1, r) - legendre_old(i+1, r)) * - exp(-t * i * (i+1) / (4 * n))) + exp(-t * i * (i+1) / (4 * n))) if term != 0.0 and abs(term) < esp: return prob + term prob += term diff --git a/argweaver/deps/compbio/phylo.py b/argweaver/deps/compbio/phylo.py index 04734aeb..2c6d18ec 100644 --- a/argweaver/deps/compbio/phylo.py +++ b/argweaver/deps/compbio/phylo.py @@ -1,7 +1,7 @@ # # Phylogeny functions # Matt Rasmussen 2006-2012 -# +# # python imports @@ -35,7 +35,7 @@ def make_gene2species(maps): maps -- a list of tuples [(gene_pattern, species_name), ... ] """ - + # find exact matches and expressions exacts = {} exps = [] @@ -44,7 +44,7 @@ def make_gene2species(maps): exacts[mapping[0]] = mapping[1] else: exps.append(mapping) - + # create mapping function def gene2species(gene): # eval expressions first in order of appearance @@ -55,10 +55,10 @@ def gene2species(gene): elif exp[0] == "*": if gene.endswith(exp[1:]): return species - + if gene in exacts: return exacts[gene] - + raise Exception("Cannot map gene '%s' to any species" % gene) return gene2species @@ -69,7 +69,7 @@ def read_gene2species(* filenames): Returns a function that will map gene names to species names. """ - + for filename in filenames: maps = [] for filename in filenames: @@ -81,15 +81,15 @@ def read_gene2species(* filenames): #============================================================================= # Reconciliation functions # - + def reconcile(gtree, stree, gene2species=gene2species): """ Returns a reconciliation dict for a gene tree 'gtree' and species tree 'stree' """ - + recon = {} - + # determine the preorder traversal of the stree order = {} def walk(node): @@ -97,38 +97,38 @@ def walk(node): node.recurse(walk) walk(stree.root) - + # label gene leaves with their species for node in gtree.leaves(): recon[node] = stree.nodes[gene2species(node.name)] - + # recurse through gene tree def walk(node): node.recurse(walk) - + if not node.is_leaf(): - # this node's species is lca of children species - recon[node] = reconcile_lca(stree, order, + # this node's species is lca of children species + recon[node] = reconcile_lca(stree, order, util.mget(recon, node.children)) walk(gtree.root) - + return recon def reconcile_lca(stree, order, nodes): """Helper function for reconcile""" - + # handle simple and complex cases if len(nodes) == 1: - return nodes[0] + return nodes[0] if len(nodes) > 2: return treelib.lca(nodes) - + # 2 node case node1, node2 = nodes index1 = order[node1] index2 = order[node2] - + while index1 != index2: if index1 > index2: node1 = node1.parent @@ -137,23 +137,46 @@ def reconcile_lca(stree, order, nodes): node2 = node2.parent index2 = order[node2] return node1 - + def reconcile_node(node, stree, recon): """Reconcile a single gene node to a species node""" return treelib.lca([recon[x] for x in node.children]) +def assert_recon(tree, stree, recon): + """Assert that a reconciliation is valid""" + + def below(node1, node2): + """Return True if node1 is below node2""" + while node1: + if node1 == node2: + return True 
+ node1 = node1.parent + return False + + for node in tree: + # Every node in gene tree should be in reconciliation + assert node in recon + + # Every node should map to a species node equal to or + # below their parent's mapping + if node.parent: + snode = recon[node] + parent_snode = recon[node.parent] + assert below(snode, parent_snode) + + def label_events(gtree, recon): - """Returns a dict with gene node keys and values indicating + """Returns a dict with gene node keys and values indicating 'gene', 'spec', or 'dup'""" events = {} - + def walk(node): events[node] = label_events_node(node, recon) node.recurse(walk) walk(gtree.root) - + return events @@ -170,34 +193,34 @@ def label_events_node(node, recon): def find_loss_node(node, recon): """Finds the loss events for a branch in a reconciled gene tree""" loss = [] - + # if not parent, then no losses if not node.parent: return loss - + # determine starting and ending species sstart = recon[node] send = recon[node.parent] - + # determine species path of this gene branch (node, node.parent) ptr = sstart spath = [] while ptr != send: spath.append(ptr) ptr = ptr.parent - + # determine whether node.parent is a dup # if so, send (species end) is part of species path if label_events_node(node.parent, recon) == "dup": spath.append(send) - + # go up species path (skip starting species) # every node on the list is at least one loss for i, snode in enumerate(spath[1:]): for schild in snode.children: if schild != spath[i]: loss.append([node, schild]) - + return loss @@ -243,7 +266,7 @@ def walk(node): def count_dup(gtree, events, node=None): """Returns the number of duplications in a gene tree""" var = {"dups": 0} - + def walk(node): if events[node] == "dup": var["dups"] += len(node.children) - 1 @@ -252,7 +275,7 @@ def walk(node): walk(node) else: walk(gtree.root) - + return var["dups"] @@ -260,7 +283,7 @@ def count_dup_loss(gtree, stree, recon, events=None): """Returns the number of duplications + losses in a gene tree""" if events is None: events = label_events(gtree, recon) - + nloss = len(find_loss(gtree, stree, recon)) ndups = count_dup(gtree, events) return nloss + ndups @@ -269,7 +292,7 @@ def count_dup_loss(gtree, stree, recon, events=None): def find_species_roots(tree, stree, recon): """Find speciation nodes in the gene tree that reconcile to the species tree root""" - + roots = [] def walk(node): found = False @@ -280,7 +303,7 @@ def walk(node): found = True return found walk(tree.root) - return roots + return roots def find_orthologs(gtree, stree, recon, events=None, counts=True): @@ -289,13 +312,13 @@ def find_orthologs(gtree, stree, recon, events=None, counts=True): if events is None: events = label_events(gtree, recon) orths = [] - + for node, event in events.items(): if event == "spec": leavesmat = [x.leaves() for x in node.children] sp_counts = [util.hist_dict(util.mget(recon, row)) for row in leavesmat] - + for i in range(len(leavesmat)): for j in range(i+1, len(leavesmat)): for gene1 in leavesmat[i]: @@ -306,14 +329,14 @@ def find_orthologs(gtree, stree, recon, events=None, counts=True): else: g1, g2 = gene1, gene2 a, b = i, j - + if not counts: orths.append((g1.name, g2.name)) else: orths.append((g1.name, g2.name, sp_counts[a][recon[g1]], sp_counts[b][recon[g2]])) - + return orths @@ -330,7 +353,7 @@ def subset_recon(tree, recon, events=None): for node in list(events): if node not in nodes: del events[node] - + #============================================================================= @@ -343,7 +366,7 @@ def 
write_recon(filename, recon): def read_recon(filename, tree1, tree2): - """Read a reconciliation from a file""" + """Read a reconciliation from a file""" recon = {} for a, b in util.read_delim(filename): if a.isdigit(): a = int(a) @@ -368,24 +391,24 @@ def read_events(filename, tree): def write_recon_events(filename, recon, events=None, noevent=""): """Write a reconciliation and events to a file""" - + if events is None: events = dict.fromkeys(recon.keys(), noevent) - + util.write_delim(filename, [(str(a.name), str(b.name), events[a]) for a,b in recon.items()]) def read_recon_events(filename, tree1, tree2): """Read a reconciliation and events data structure from file""" - + recon = {} events = {} for a, b, event in util.read_delim(filename): if a.isdigit(): a = int(a) if b.isdigit(): b = int(b) node1 = tree1.nodes[a] - recon[node1] = tree2.nodes[b] + recon[node1] = tree2.nodes[b] events[node1] = event return recon, events @@ -413,15 +436,15 @@ def count_dup_loss_tree(tree, stree, gene2species, recon=None): recon = reconcile(tree, stree, gene2species) events = label_events(tree, recon) losses = find_loss(tree, stree, recon) - + dup = 0 loss = 0 appear = 0 - - # count appearance + + # count appearance recon[tree.root].data["appear"] += 1 appear += 1 - + # count dups for node, event in events.iteritems(): if event == "dup": @@ -434,7 +457,7 @@ def count_dup_loss_tree(tree, stree, gene2species, recon=None): for gnode, snode in losses: snode.data['loss'] += 1 loss += 1 - + return dup, loss, appear @@ -445,9 +468,9 @@ def walk(node): counts = [] for child in node.children: walk(child) - counts.append(child.data['genes'] + counts.append(child.data['genes'] - child.data['appear'] - - child.data['dup'] + - child.data['dup'] + child.data['loss']) assert util.equal(* counts), str(counts) node.data['genes'] = counts[0] @@ -458,7 +481,7 @@ def count_dup_loss_trees(trees, stree, gene2species): """ Returns new species tree with dup,loss,appear,genes counts in node's data """ - + stree = stree.copy() init_dup_loss_tree(stree) @@ -491,7 +514,7 @@ def dup_consistency(tree, recon, events): if len(tree.leaves()) == 1: return {} - + spset = {} def walk(node): for child in node.children: @@ -506,7 +529,7 @@ def walk(node): else: raise Exception("too many children (%d)" % len(node.children)) walk(tree.root) - + conf = {} for node in tree: if events[node] == "dup": @@ -521,29 +544,29 @@ def walk(node): # tree rooting -def recon_root(gtree, stree, gene2species = gene2species, +def recon_root(gtree, stree, gene2species = gene2species, rootby = "duploss", newCopy=True): """Reroot a tree by minimizing the number of duplications/losses/both""" # make a consistent unrooted copy of gene tree if newCopy: gtree = gtree.copy() - + if len(gtree.leaves()) == 2: return - + treelib.unroot(gtree, newCopy=False) - treelib.reroot(gtree, - gtree.nodes[sorted(gtree.leaf_names())[0]].parent.name, + treelib.reroot(gtree, + gtree.nodes[sorted(gtree.leaf_names())[0]].parent.name, onBranch=False, newCopy=False) - - + + # make recon root consistent for rerooting tree of the same names # TODO: there is the possibility of ties, they are currently broken - # arbitrarily. In order to make comparison of reconRooted trees with + # arbitrarily. In order to make comparison of reconRooted trees with # same gene names accurate, hashOrdering must be done, for now. 
hash_order_tree(gtree, gene2species) - + # get list of edges to root on edges = [] def walk(node): @@ -554,16 +577,16 @@ def walk(node): for child in gtree.root.children: walk(child) - - # try initial root and recon + + # try initial root and recon treelib.reroot(gtree, edges[0][0].name, newCopy=False) recon = reconcile(gtree, stree, gene2species) events = label_events(gtree, recon) - + # find reconciliation that minimizes loss minroot = edges[0] rootedge = sorted(edges[0]) - if rootby == "dup": + if rootby == "dup": cost = count_dup(gtree, events) elif rootby == "loss": cost = len(find_loss(gtree, stree, recon)) @@ -572,19 +595,19 @@ def walk(node): else: raise "unknown rootby value '%s'" % rootby mincost = cost - - + + # try rooting on everything for edge in edges[1:]: if sorted(edge) == rootedge: continue rootedge = sorted(edge) - + node1, node2 = edge if node1.parent != node2: node1, node2 = node2, node1 assert node1.parent == node2, "%s %s" % (node1.name, node2.name) - + # uncount cost if rootby in ["dup", "duploss"]: if events[gtree.root] == "dup": @@ -594,15 +617,15 @@ def walk(node): if rootby in ["loss", "duploss"]: cost -= len(find_loss_under_node(gtree.root, recon)) cost -= len(find_loss_under_node(node2, recon)) - + # new root and recon - treelib.reroot(gtree, node1.name, newCopy=False) - + treelib.reroot(gtree, node1.name, newCopy=False) + recon[node2] = reconcile_node(node2, stree, recon) recon[gtree.root] = reconcile_node(gtree.root, stree, recon) events[node2] = label_events_node(node2, recon) events[gtree.root] = label_events_node(gtree.root, recon) - + if rootby in ["dup", "duploss"]: if events[node2] == "dup": cost += 1 @@ -611,13 +634,13 @@ def walk(node): if rootby in ["loss", "duploss"]: cost += len(find_loss_under_node(gtree.root, recon)) cost += len(find_loss_under_node(node2, recon)) - + # keep track of min cost if cost < mincost: mincost = cost minroot = edge - + # root tree by minroot if edge != minroot: node1, node2 = minroot @@ -625,7 +648,7 @@ def walk(node): node1, node2 = node2, node1 assert node1.parent == node2 treelib.reroot(gtree, node1.name, newCopy=False) - + return gtree @@ -635,7 +658,7 @@ def midroot_recon(tree, stree, recon, events, params, generate): specs1 = [] specs2 = [] - + # find nearest specs/genes def walk(node, specs): if events[node] == "dup": @@ -647,7 +670,7 @@ def walk(node, specs): #walk(node2, specs2) specs1 = node1.leaves() specs2 = node2.leaves() - + def getDists(start, end): exp_dist = 0 obs_dist = 0 @@ -663,14 +686,14 @@ def getDists(start, end): start = start.parent return exp_dist, obs_dist / generate - + diffs1 = [] for spec in specs1: if events[tree.root] == "spec": exp_dist1, obs_dist1 = getDists(spec, tree.root) else: exp_dist1, obs_dist1 = getDists(spec, node1) - diffs1.append(obs_dist1 - exp_dist1) + diffs1.append(obs_dist1 - exp_dist1) diffs2 = [] for spec in specs2: @@ -679,17 +702,17 @@ def getDists(start, end): else: exp_dist2, obs_dist2 = getDists(spec, node2) diffs2.append(obs_dist2 - exp_dist2) - + totdist = (node1.dist + node2.dist) / generate left = node1.dist - stats.mean(diffs1) right = totdist - node2.dist + stats.mean(diffs2) - - #print diffs1, diffs2 + + #print diffs1, diffs2 #print stats.mean(diffs1), stats.mean(diffs2) - + mid = util.clamp((left + right) / 2.0, 0, totdist) - + node1.dist = mid * generate node2.dist = (totdist - mid) * generate @@ -697,9 +720,9 @@ def getDists(start, end): def stree2gtree(stree, genes, gene2species): """Create a gene tree with the same topology as the species tree""" - + 
tree = stree.copy() - + for gene in genes: tree.rename(gene2species(gene), gene) return tree @@ -729,16 +752,16 @@ def get_gene_losses(tree, stree, recon): """Returns losses as gene name, species name tuples""" return set((loss[0].name, loss[1].name) for loss in find_loss(tree, stree, recon)) - + def get_orthologs(tree, events): """Returns orthologs as gene name pairs""" - + specs = [sorted([sorted(child.leaf_names()) for child in node.children]) for node in events if events[node] == "spec"] - + return set(tuple(sorted((a, b))) for x in specs for a in x[0] @@ -761,7 +784,7 @@ def walk(node): child_hashes = map(walk, node.children) child_hashes.sort() return compose(child_hashes, node) - + if isinstance(tree, treelib.Tree) or hasattr(tree, "root"): return walk(tree.root) elif isinstance(tree, treelib.TreeNode): @@ -815,7 +838,7 @@ def walk(node): def brecon2recon_events(brecon): """ - Returns 'recon' and 'events' data structures from a branch reconciliation + Returns 'recon' and 'events' data structures from a branch reconciliation """ recon = {} events = {} @@ -831,7 +854,7 @@ def recon_events2brecon(recon, events): """ Returns a branch reconciliation from 'recon' and 'events' data structures """ - + brecon = {} for node, snode in recon.iteritems(): branch = [] @@ -846,7 +869,7 @@ def recon_events2brecon(recon, events): while ptr != sparent: losses.append((ptr, "specloss")) ptr = ptr.parent - + branch.extend(reversed(losses)) branch.append((snode, events[node])) @@ -863,7 +886,7 @@ def subtree_brecon_by_leaves(tree, brecon, leaves): brecon -- branch reconciliation leaves -- leaf nodes to keep in tree """ - + # record orignal parent pointers parents = dict((node, node.parent) for node in tree) @@ -917,14 +940,14 @@ def subtree_brecon_by_leaves(tree, brecon, leaves): branch_path[i][0] == branch_path[i+1][0])] for i in remove: del branch_path[i] - + brecon[node] = branch_path # remove unused nodes from brecon for node in brecon.keys(): if node.name not in tree: del brecon[node] - + return doomed @@ -940,7 +963,7 @@ def add_implied_spec_nodes_brecon(tree, brecon): parent = node.parent children = parent.children node2 = tree.new_node() - + node2.parent = parent children[children.index(node)] = node2 @@ -954,7 +977,7 @@ def add_implied_spec_nodes_brecon(tree, brecon): parent = node.parent children = parent.children node2 = tree.new_node() - + node2.parent = parent children[children.index(node)] = node2 @@ -965,7 +988,7 @@ def add_implied_spec_nodes_brecon(tree, brecon): brecon[node] = events[-1:] - + def write_brecon(out, brecon): @@ -1097,19 +1120,19 @@ def add_spec_node(node, snode, tree, recon, events): new node reconciles to species node 'snode'. 
Modifies recon and events accordingly """ - + newnode = treelib.TreeNode(tree.new_name()) parent = node.parent - + # find index of node in parent's children nodei = parent.children.index(node) - + # insert new node into tree tree.add_child(parent, newnode) parent.children[nodei] = newnode parent.children.pop() tree.add_child(newnode, node) - + # add recon and events info recon[newnode] = snode events[newnode] = "spec" @@ -1122,7 +1145,7 @@ def add_implied_spec_nodes(tree, stree, recon, events): adds speciation nodes to tree that are implied but are not present because of gene losses """ - + added_nodes = [] for node in list(tree): @@ -1132,8 +1155,8 @@ def add_implied_spec_nodes(tree, stree, recon, events): if node.parent is None: # ensure root of gene tree properly reconciles to # root of species tree - if recon[node] == stree.root: - continue + if recon[node] == stree.root: + continue tree.root = treelib.TreeNode(tree.new_name()) tree.add_child(tree.root, node) recon[tree.root] = stree.root @@ -1157,8 +1180,8 @@ def add_implied_spec_nodes(tree, stree, recon, events): added_nodes.append(add_spec_node(node, snode, tree, recon, events)) node = node.parent snode = snode.parent - - + + # determine whether node.parent is a dup # if so, send (a.k.a. species end) is part of species path if events[parent] == "dup": @@ -1174,7 +1197,7 @@ def change_recon_up(recon, node, events=None): """ Move the mapping of a node up one branch """ - + if events is not None and events[node] == "spec": # promote speciation to duplication # R'(v) = e(R(u)) @@ -1207,7 +1230,7 @@ def can_change_recon_up(recon, node, events=None): prnode = rnode.parent # rearrangement is valid if - return (not node.is_leaf() and + return (not node.is_leaf() and prnode is not None and # 1. there is parent sp. branch (node.parent is None or # 2. no parent to restrict move rnode != recon[node.parent] # 3. not already matching parent @@ -1221,7 +1244,7 @@ def enum_recon(tree, stree, depth=None, """ Enumerate reconciliations between a gene tree and a species tree """ - + if recon is None: recon = reconcile(tree, stree, gene2species) events = label_events(tree, recon) @@ -1238,14 +1261,14 @@ def enum_recon(tree, stree, depth=None, if can_change_recon_up(recon, node, events): schild = recon[node] change_recon_up(recon, node, events) - + # recurse depth2 = depth - 1 if depth is not None else None for r, e in enum_recon(tree, stree, depth2, i, preorder, recon, events): yield r, e - + change_recon_down(recon, node, schild, events) @@ -1257,8 +1280,8 @@ def enum_recon(tree, stree, depth=None, def perform_nni(tree, node1, node2, change=0, rooted=True): """Proposes a new tree using Nearest Neighbor Interchange - - Branch for NNI is specified by giving its two incident nodes (node1 and + + Branch for NNI is specified by giving its two incident nodes (node1 and node2). Change specifies which subtree of node1 will be swapped with the uncle. See figure below. 
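As a minimal sketch of calling perform_nni directly (assuming rasmus.treelib and compbio.phylo are importable as in the upstream compbio layout; the newick string and choice of nodes are illustrative):

    from rasmus import treelib
    from compbio import phylo

    tree = treelib.parse_newick("(((a,b),(c,d)),e);")
    node1 = tree.nodes["a"].parent        # the (a,b) clade
    node2 = node1.parent                  # the ((a,b),(c,d)) clade, node1's parent
    phylo.perform_nni(tree, node1, node2, change=0)   # swap child[0] of node1 with its uncle (c,d)
    treelib.draw_tree_names(tree, maxlen=8)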
@@ -1267,20 +1290,20 @@ def perform_nni(tree, node1, node2, change=0, rooted=True): uncle node1 / \ child[0] child[1] - + special case with rooted branch and rooted=False: - + node2 / \ node2' node1 / \ / \ uncle * child[0] child[1] - + """ - + if node1.parent != node2: - node1, node2 = node2, node1 - + node1, node2 = node2, node1 + # try to see if edge is one branch (not root edge) if not rooted and treelib.is_rooted(tree) and \ node2 == tree.root: @@ -1289,30 +1312,30 @@ def perform_nni(tree, node1, node2, change=0, rooted=True): node2 = node2.children[1] else: node2 = node2.children[0] - + # edge is not an internal edge, give up if len(node2.children) < 2: return - + if node1.parent == node2.parent == tree.root: uncle = 0 - + if len(node2.children[0].children) < 2 and \ len(node2.children[1].children) < 2: # can't do NNI on this branch return - else: + else: assert node1.parent == node2 - + # find uncle - uncle = 0 + uncle = 0 if node2.children[uncle] == node1: uncle = 1 - + # swap parent pointers node1.children[change].parent = node2 node2.children[uncle].parent = node1 - + # swap child pointers node2.children[uncle], node1.children[change] = \ node1.children[change], node2.children[uncle] @@ -1331,7 +1354,7 @@ def propose_random_nni(tree): node1 = random.sample(nodes, 1)[0] if not node1.is_leaf() and node1.parent is not None: break - + node2 = node1.parent #a = node1.children[random.randint(0, 1)] #b = node2.children[1] if node2.children[0] == node1 else node2.children[0] @@ -1345,7 +1368,7 @@ def propose_random_nni(tree): def perform_spr(tree, subtree, newpos): """ Proposes new topology using Subtree Pruning and Regrafting (SPR) - + a = subtree e = newpos @@ -1369,7 +1392,7 @@ def perform_spr(tree, subtree, newpos): Requirements: 1. a (subtree) is not root or children of root - 2. e (newpos) is not root, a, descendant of a, c (parent of a), or + 2. e (newpos) is not root, a, descendant of a, c (parent of a), or b (sibling of a) 3. tree is binary @@ -1400,13 +1423,13 @@ def propose_random_spr(tree): What if e == f (also equivalent to NNI) this is OK BEFORE - + d / \ e ... / \ - c ... - / \ + c ... + / \ a b ... ... @@ -1419,11 +1442,11 @@ def propose_random_spr(tree): ... / \ b ... ... - + What if d == f (also equivalent to NNI) this is OK - + BEFORE - + f / \ c e @@ -1432,17 +1455,17 @@ def propose_random_spr(tree): ... ... AFTER - + f / \ - b c - ... / \ + b c + ... / \ a e - ... ... + ... ... Requirements: 1. a (subtree) is not root or children of root - 2. e (newpos) is not root, a, descendant of a, c (parent of a), or + 2. e (newpos) is not root, a, descendant of a, c (parent of a), or b (sibling of a) 3. 
tree is binary """ @@ -1456,21 +1479,21 @@ def propose_random_spr(tree): if (a.parent is not None and a.parent.parent is not None): break subtree = a - + # find sibling (b) of a c = a.parent bi = 1 if c.children[0] == a else 0 b = c.children[bi] - + # choose newpos (e) e = None while True: e = random.sample(nodes, 1)[0] - + # test if e is a valid choice if e.parent is None or e == a or e == c or e == b: continue - + # also test if e is a descendent of a under_a = False ptr = e.parent @@ -1481,7 +1504,7 @@ def propose_random_spr(tree): ptr = ptr.parent if under_a: continue - + break newpos = e @@ -1561,7 +1584,7 @@ def propose(self): # choose SPR move self.node1, node3 = propose_random_spr(self.tree) - + # remember sibling of node1 p = self.node1.parent self.node2 = (p.children[1] if p.children[0] == self.node1 @@ -1578,7 +1601,7 @@ def revert(self): def reset(self): self.node1 = None - self.node2 = None + self.node2 = None class TreeSearchMix (TreeSearch): @@ -1650,7 +1673,7 @@ def propose(self): if top not in self.seen: #util.logger("tried", i, len(self.seen)) break - else: + else: #util.logger("maxtries", len(self.seen)) pass @@ -1658,7 +1681,7 @@ def propose(self): self.seen.add(top) self.tree = tree return self.tree - + def revert(self): self.tree = self.search.revert() @@ -1673,7 +1696,7 @@ def reset(self): def add_seen(self, tree): top = self._tree_hash(tree) self.seen.add(top) - + class TreeSearchPrescreen (TreeSearch): @@ -1718,7 +1741,7 @@ def propose(self): else: self.search.revert() - # propose one of the subproposals + # propose one of the subproposals choice = random.random() partsum = -util.INF @@ -1751,13 +1774,13 @@ def reset(self): def neighborjoin(distmat, genes, usertree=None): """Neighbor joining algorithm""" - + tree = treelib.Tree() leaves = {} dists = util.Dict(dim=2) restdists = {} - - + + # initialize distances for i in range(len(genes)): r = 0 @@ -1765,12 +1788,12 @@ def neighborjoin(distmat, genes, usertree=None): dists[genes[i]][genes[j]] = distmat[i][j] r += distmat[i][j] restdists[genes[i]] = r / (len(genes) - 2) - + # initialize leaves for gene in genes: tree.add(treelib.TreeNode(gene)) leaves[gene] = 1 - + # if usertree is given, determine merging order merges = [] newnames = {} @@ -1779,7 +1802,7 @@ def walk(node): if not node.is_leaf(): assert len(node.children) == 2, \ Exception("usertree is not binary") - + for child in node: walk(child) merges.append(node) @@ -1788,7 +1811,7 @@ def walk(node): newnames[node] = node.name walk(usertree.root) merges.reverse() - + # join loop while len(leaves) > 2: # search for closest genes @@ -1809,22 +1832,22 @@ def walk(node): node = merges.pop() lowpair = (newnames[node.children[0]], newnames[node.children[1]]) - + # join gene1 and gene2 gene1, gene2 = lowpair parent = treelib.TreeNode(tree.new_name()) tree.add_child(parent, tree.nodes[gene1]) tree.add_child(parent, tree.nodes[gene2]) - + # set distances - tree.nodes[gene1].dist = (dists[gene1][gene2] + restdists[gene1] - + tree.nodes[gene1].dist = (dists[gene1][gene2] + restdists[gene1] - restdists[gene2]) / 2.0 tree.nodes[gene2].dist = dists[gene1][gene2] - tree.nodes[gene1].dist - + # gene1 and gene2 are no longer leaves del leaves[gene1] del leaves[gene2] - + gene3 = parent.name r = 0 for gene in leaves: @@ -1833,10 +1856,10 @@ def walk(node): dists[gene][gene3] = dists[gene3][gene] r += dists[gene3][gene] leaves[gene3] = 1 - + if len(leaves) > 2: restdists[gene3] = r / (len(leaves) - 2) - + # join the last two genes into a tribranch gene1, gene2 = leaves.keys() if 
type(gene1) != int: @@ -1845,7 +1868,7 @@ def walk(node): tree.nodes[gene2].dist = dists[gene1][gene2] tree.root = tree.nodes[gene1] - # root tree according to usertree + # root tree according to usertree if usertree != None and treelib.is_rooted(usertree): roots = set([newnames[usertree.root.children[0]], newnames[usertree.root.children[1]]]) @@ -1853,11 +1876,11 @@ def walk(node): for child in tree.root.children: if child.name in roots: newroot = child - + assert newroot != None - + treelib.reroot(tree, newroot.name, newCopy=False) - + return tree @@ -1868,11 +1891,11 @@ def walk(node): def least_square_error(tree, distmat, genes, forcePos=True, weighting=False): """Least Squared Error algorithm for phylogenetic reconstruction""" - + # use SCIPY to perform LSE import scipy import scipy.linalg - + def makeVector(array): """convience function for handling different configurations of scipy""" if len(array.shape) == 2: @@ -1882,26 +1905,26 @@ def makeVector(array): return scipy.transpose(array)[0] else: return array - - + + if treelib.is_rooted(tree): rootedge = sorted([x.name for x in tree.root.children]) treelib.unroot(tree, newCopy=False) else: - rootedge = None - + rootedge = None + # create pairwise dist array dists = [] for i in xrange(len(genes)): for j in xrange(i+1, len(genes)): dists.append(distmat[i][j]) - + # create topology matrix topmat, edges = make_topology_matrix(tree, genes) - + # setup matrix and vector if weighting: - topmat2 = scipy.array([[util.safediv(x, math.sqrt(dists[i]), 0) + topmat2 = scipy.array([[util.safediv(x, math.sqrt(dists[i]), 0) for x in row] for i, row in enumerate(topmat)]) paths = scipy.array(map(math.sqrt, dists)) @@ -1909,28 +1932,28 @@ def makeVector(array): topmat2 = scipy.array(topmat) paths = scipy.array(dists) - + # solve LSE edgelens, resids, rank, singlars = scipy.linalg.lstsq(topmat2, paths) - + # force non-negative branch lengths if forcePos: edgelens = [max(float(x), 0) for x in makeVector(edgelens)] else: edgelens = [float(x) for x in makeVector(edgelens)] - + # calc path residuals (errors) paths2 = makeVector(scipy.dot(topmat2, edgelens)) resids = (paths2 - paths).tolist() paths = paths.tolist() - + # set branch lengths - set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, + set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, topmat=topmat, rootedge=rootedge) - - return util.Bundle(resids=resids, - paths=paths, - edges=edges, + + return util.Bundle(resids=resids, + paths=paths, + edges=edges, topmat=topmat) @@ -1942,10 +1965,10 @@ def make_topology_matrix(tree, genes): edges = splits.keys() # create topology matrix - n = len(genes) + n = len(genes) ndists = n*(n-1) / 2 topmat = util.make_matrix(ndists, len(edges)) - + vlookup = util.list2lookup(genes) n = len(genes) for e in xrange(len(edges)): @@ -1955,11 +1978,11 @@ def make_topology_matrix(tree, genes): i, j = util.sort([vlookup[gene1], vlookup[gene2]]) index = i*n-i*(i+1)/2+j-i-1 topmat[index][e] = 1.0 - + return topmat, edges -def set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, +def set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, topmat=None, rootedge=None): # recreate rooting branches if rootedge != None: @@ -1968,12 +1991,12 @@ def set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, treelib.reroot(tree, rootedge[0], newCopy=False) else: treelib.reroot(tree, rootedge[1], newCopy=False) - + # find root edge in edges for i in xrange(len(edges)): if sorted(edges[i]) == rootedge: break - + 
edges[i] = [rootedge[0], tree.root.name] edges.append([rootedge[1], tree.root.name]) edgelens[i] /= 2.0 @@ -1982,11 +2005,11 @@ def set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, resids.append(resids[i]) paths[i] /= 2.0 paths.append(paths[i]) - + if topmat != None: for row in topmat: row.append(row[i]) - + # set branch lengths for i in xrange(len(edges)): gene1, gene2 = edges[i] @@ -1999,14 +2022,14 @@ def set_branch_lengths_from_matrix(tree, edges, edgelens, paths, resids, def tree2distmat(tree, leaves): """Returns pair-wise distances between leaves of a tree""" - + # TODO: not implemented efficiently mat = [] for i in range(len(leaves)): mat.append([]) for j in range(len(leaves)): mat[-1].append(treelib.find_dist(tree, leaves[i], leaves[j])) - + return mat @@ -2021,7 +2044,7 @@ def find_splits(tree, rooted=False): If 'rooted' is True, then orient splits based on rooting """ - + all_leaves = set(tree.leaf_names()) nall_leaves = len(all_leaves) @@ -2062,9 +2085,9 @@ def walk(node): set2 = tuple(sorted(all_leaves - leaves)) if not rooted and min(set2) < min(set1): set1, set2 = set2, set1 - + splits.append((set1, set2)) - + return splits @@ -2100,7 +2123,7 @@ def split_bit_string(split, leaves=None, char1="*", char2=".", nochar=" "): chars.append(nochar) return "".join(chars) - + def robinson_foulds_error(tree1, tree2, rooted=False): """ @@ -2115,11 +2138,11 @@ def robinson_foulds_error(tree1, tree2, rooted=False): splits2 = find_splits(tree2, rooted=rooted) overlap = set(splits1) & set(splits2) - + #assert len(splits1) == len(splits2) denom = float(max(len(splits1), len(splits2))) - + if denom == 0.0: return 0.0 else: @@ -2154,14 +2177,14 @@ def consensus_majority_rule(trees, extended=True, rooted=False): tree.add_child(n, treelib.TreeNode(leaves[1])) tree.add_child(root, treelib.TreeNode(leaves[2])) return tree - + elif nleaves == 2: leaves = trees[0].leaf_names() root = tree.make_root() tree.add_child(root, treelib.TreeNode(leaves[0])) tree.add_child(root, treelib.TreeNode(leaves[1])) return tree - + # count all splits for tree in trees: @@ -2170,7 +2193,7 @@ def consensus_majority_rule(trees, extended=True, rooted=False): contree.nextname = max(tree.nextname for tree in trees) #util.print_dict(split_counts) - + # choose splits pick_splits = 0 rank_splits = split_counts.items() @@ -2180,7 +2203,7 @@ def consensus_majority_rule(trees, extended=True, rooted=False): for split, count in rank_splits: if not extended and count <= ntrees / 2.0: continue - + # choose split if it is compatiable if _add_split_to_tree(contree, split, count / float(ntrees), rooted): pick_splits += 1 @@ -2192,7 +2215,7 @@ def consensus_majority_rule(trees, extended=True, rooted=False): # add remaining leaves and remove clade data _post_process_split_tree(contree) - + return contree @@ -2205,7 +2228,7 @@ def splits2tree(splits, rooted=False): splits -- iterable of splits rooted -- if True treat splits as rooted/polarized """ - + tree = treelib.Tree() for split in splits: _add_split_to_tree(tree, split, 1.0, rooted) @@ -2239,7 +2262,7 @@ def _add_split_to_tree(tree, split, count, rooted=False): node = tree.add_child(root, treelib.TreeNode(list(split[1])[0])) node.data["leaves"] = split[1] node.data["boot"] = count - + return True def walk(node, clade): @@ -2253,13 +2276,13 @@ def walk(node, clade): child.data["leaves"] = clade child.data["boot"] = count return True - + # which children intersect this clade? 
intersects = [] for child in node: leaves = child.data["leaves"] intersect = clade & leaves - + if len(clade) == len(intersect): if len(intersect) < len(leaves): # subset, recurse @@ -2270,7 +2293,7 @@ def walk(node, clade): elif len(intersect) == 0: continue - + elif len(intersect) == len(leaves): # len(clade) > len(leaves) # superset @@ -2292,7 +2315,7 @@ def walk(node, clade): tree.add_child(new_node, child) return True - + # try to place split into tree if rooted: walk(tree.root, split[0]) @@ -2301,17 +2324,17 @@ def walk(node, clade): return True else: return walk(tree.root, split[1]) - + # split is in conflict return False - + def _post_process_split_tree(tree): """ Post-process a tree built from splits private method """ - + for node in list(tree): if len(node.data["leaves"]) > 1: for leaf_name in node.data["leaves"]: @@ -2327,8 +2350,8 @@ def _post_process_split_tree(tree): for node in tree: if "leaves" in node.data: del node.data["leaves"] - - + + def ensure_binary_tree(tree): """ @@ -2338,17 +2361,17 @@ def ensure_binary_tree(tree): # first tree just rerooting root branch if len(tree.root.children) > 2: treelib.reroot(tree, tree.root.children[0].name, newCopy=False) - + multibranches = [node for node in tree if len(node.children) > 2] for node in multibranches: children = list(node.children) - + # remove children for child in children: tree.remove(child) - + # add back in binary while len(children) > 2: left = children.pop() @@ -2358,7 +2381,7 @@ def ensure_binary_tree(tree): tree.add_child(newnode, left) tree.add_child(newnode, right) children.append(newnode) - + # add last two to original node tree.add_child(node, children.pop()) tree.add_child(node, children.pop()) @@ -2378,7 +2401,7 @@ def make_jc_matrix(t, a=1.): eat = math.exp(-4*a/3.*t) r = .25 * (1 + 3*eat) s = .25 * (1 - eat) - + return [[r, s, s, s], [s, r, s, s], [s, s, r, s], @@ -2395,7 +2418,7 @@ def make_hky_matrix(t, bgfreq=(.25,.25,.25,.25), kappa=1.0): bgfreq -- background base frequency kappa -- transition/transversion ratio """ - + # bases = "ACGT" # pi = bfreq @@ -2403,16 +2426,16 @@ def make_hky_matrix(t, bgfreq=(.25,.25,.25,.25), kappa=1.0): pi_y = bgfreq[1] + bgfreq[3] rho = pi_r / pi_y - - # convert the usual ratio definition (kappa) to Felsenstein's + + # convert the usual ratio definition (kappa) to Felsenstein's # definition (R) ratio = (bgfreq[3]*bgfreq[1] + bgfreq[0]*bgfreq[2]) * kappa / (pi_y*pi_r) # determine HKY parameters alpha_r, alpha_y, and beta b = 1.0 / (2.0 * pi_r * pi_y * (1.0+ratio)) - a_y = ((pi_r*pi_y*ratio - bgfreq[0]*bgfreq[2] - bgfreq[1]*bgfreq[3]) / - (2.0*(1+ratio)*(pi_y*bgfreq[0]*bgfreq[2]*rho + + a_y = ((pi_r*pi_y*ratio - bgfreq[0]*bgfreq[2] - bgfreq[1]*bgfreq[3]) / + (2.0*(1+ratio)*(pi_y*bgfreq[0]*bgfreq[2]*rho + pi_r*bgfreq[1]*bgfreq[3]))) a_r = rho * a_y @@ -2423,7 +2446,7 @@ def make_hky_matrix(t, bgfreq=(.25,.25,.25,.25), kappa=1.0): [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] - + for i in (0, 1, 2, 3): for j in (0, 1, 2, 3): # convenience variables @@ -2435,14 +2458,14 @@ def make_hky_matrix(t, bgfreq=(.25,.25,.25,.25), kappa=1.0): else: # prymidine a_i = a_y pi_ry = pi_y - delta_ij = int(i == j) + delta_ij = int(i == j) e_ij = int((i in (0, 2)) == (j in (0, 2))) ait = math.exp(-a_i*t) ebt = math.exp(-b*t) - mat[i][j] = (ait*ebt * delta_ij + - ebt * (1.0 - ait) * (bgfreq[j]*e_ij/pi_ry) + + mat[i][j] = (ait*ebt * delta_ij + + ebt * (1.0 - ait) * (bgfreq[j]*e_ij/pi_ry) + (1.0 - ebt) * bgfreq[j]) return mat @@ -2450,11 +2473,11 @@ def make_hky_matrix(t, bgfreq=(.25,.25,.25,.25), 
kappa=1.0): def sim_seq_branch(seq, time, matrix_func): """Simulate sequence evolving down one branch""" - + matrix = matrix_func(time) bases = "ACGT" lookup = {"A": 0, "C": 1, "G": 2, "T": 3} - + seq2 = [] for a in seq: seq2.append(bases[stats.sample(matrix[lookup[a]])]) @@ -2462,11 +2485,11 @@ def sim_seq_branch(seq, time, matrix_func): return "".join(seq2) -def sim_seq_tree(tree, seqlen, matrix_func=make_jc_matrix, +def sim_seq_tree(tree, seqlen, matrix_func=make_jc_matrix, bgfreq=[.25,.25,.25,.25], rootseq=None, keep_internal=False): """Simulate the evolution of a sequence down a tree""" - + bases = "ACGT" # make root sequence @@ -2488,11 +2511,11 @@ def walk(node, seq): # recurse for child in node.children: seq2 = sim_seq_branch(seq, child.dist, matrix_func) - walk(child, seq2) + walk(child, seq2) walk(tree.root, rootseq) return seqs - + #============================================================================= @@ -2505,7 +2528,7 @@ def sample_dlt_gene_tree(stree, duprate, lossrate, transrate, """Simulate a gene tree within a species tree with dup, loss, transfer""" # TODO: return brecon instead of (recon, events) - + stimes = treelib.get_tree_timestamps(stree) # initialize gene tree @@ -2517,7 +2540,7 @@ def sample_dlt_gene_tree(stree, duprate, lossrate, transrate, totalrate = duprate + lossrate + transrate - + def sim_branch(node, snode, dist): # sample next event @@ -2542,11 +2565,11 @@ def sim_branch(node, snode, dist): node.dist = time recon[node] = snode events[node] = "dup" - + # recurse sim_branch(node, snode, dist - time) sim_branch(node, snode, dist - time) - + elif pick <= (duprate + lossrate) / totalrate: # loss occurs node = tree.add_child(node, tree.new_node()) @@ -2573,12 +2596,12 @@ def sim_branch(node, snode, dist): assert len(others) > 0, (age, stimes) dest = random.sample(others, 1)[0] - + # recurse sim_branch(node, snode, dist - time) sim_branch(node, dest, age - stimes[dest]) - - + + def sim_spec(node, snode): if snode.is_leaf(): # leaf in species tree, terminal gene lineage @@ -2588,7 +2611,7 @@ def sim_spec(node, snode): # speciation in species tree, follow each branch for schild in snode.children: sim_branch(node, schild, schild.dist) - + sim_spec(tree.root, stree.root) @@ -2598,7 +2621,7 @@ def sim_spec(node, snode): brecon = recon_events2brecon(recon, events) - + return tree, brecon @@ -2607,7 +2630,7 @@ def sample_dltr_gene_tree(stree, duprate, lossrate, transrate, recombrate, genename=lambda sp, x: sp + "_" + str(x), removeloss=True): """Simulate a gene tree within a species tree with dup, loss, transfer""" - + stimes = treelib.get_tree_timestamps(stree) spec_times = sorted((x for x in stimes.values() if x > 0.0), reverse=True) spec_times.append(0.0) @@ -2625,7 +2648,7 @@ class Lineage (object): def __init__(self, node, snode): self.node = node self.snode = snode - + lineages = set() for schild in stree.root: lineages.add(Lineage(tree.root, schild)) @@ -2662,12 +2685,12 @@ def __init__(self, node, snode): for schild in l.snode.children: lineages.add(Lineage(child, schild)) continue - + # choose event type and lineage lineage = random.sample(lineages, 1)[0] node, snode = lineage.node, lineage.snode pick = stats.sample((duprate, lossrate, transrate, recombrate)) - + if pick == 0: # duplication child = tree.add_child(node, tree.new_node()) @@ -2740,11 +2763,11 @@ def __init__(self, node, snode): times[child2] = age brecon[child2] = [(gene.snode, "loss")] lineages.remove(gene) - + if removeloss: keep = [x for x in tree.leaves() if isinstance(x.name, str)] 
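# --- Editor's illustrative sketch (not part of the patch) --------------------
# make_jc_matrix() and sim_seq_branch() above work with a row-stochastic
# substitution matrix indexed by "ACGT"; a quick hypothetical sanity check and
# a single draw (assumes the module's stats import, as used in sim_seq_branch):
P = make_jc_matrix(0.1)                        # Jukes-Cantor, branch length 0.1
assert all(abs(sum(row) - 1.0) < 1e-9 for row in P)
new_base = "ACGT"[stats.sample(P[0])]          # mutate an 'A' along that branch
# -----------------------------------------------------------------------------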
subtree_brecon_by_leaves(tree, brecon, keep) - + return tree, brecon diff --git a/argweaver/deps/compbio/vis/argvis.py b/argweaver/deps/compbio/vis/argvis.py index c30fb9ae..59cf6d42 100644 --- a/argweaver/deps/compbio/vis/argvis.py +++ b/argweaver/deps/compbio/vis/argvis.py @@ -31,7 +31,7 @@ def minlog(x, default=10): return log(max(x, default)) - + def layout_arg_leaves(arg): """Layout the leaves of an ARG""" @@ -63,7 +63,7 @@ def layout_arg_leaves(arg): # assign layout based on basetree layout # layout leaves return dict((arg[name], i) for i, name in enumerate(basetree.leaf_names())) - + def layout_arg(arg, leaves=None, yfunc=lambda x: x): """Layout the nodes of an ARG""" @@ -75,7 +75,7 @@ def layout_arg(arg, leaves=None, yfunc=lambda x: x): leafx = layout_arg_leaves(arg) else: leafx = util.list2lookup(leaves) - + for node in arg.postorder(): if node.is_leaf(): layout[node] = [leafx[node], yfunc(node.age)] @@ -93,9 +93,9 @@ def map_layout(layout, xfunc=lambda x: x, yfunc=lambda x: x): layout[node] = [xfunc(x), yfunc(y)] return layout - - - + + + def get_branch_layout(layout, node, parent, side=0, recomb_width=.4): @@ -103,7 +103,7 @@ def get_branch_layout(layout, node, parent, side=0, recomb_width=.4): nx, ny = layout[node] px, py = layout[parent] - + if node.event == "recomb": if len(node.parents) == 2 and node.parents[0] == node.parents[1]: step = recomb_width * [-1, 1][side] @@ -123,7 +123,7 @@ def show_arg(arg, layout=None, leaves=None, mut=None, recomb_width=.4, win = summon.Window() else: win.clear_groups() - + # ensure layout if layout is None: layout = layout_arg(arg, leaves) @@ -133,7 +133,7 @@ def branch_click(node, parent): print node.name, parent.name # draw ARG - win.add_group(draw_arg(arg, layout, recomb_width=recomb_width, + win.add_group(draw_arg(arg, layout, recomb_width=recomb_width, branch_click=branch_click)) # draw mutations @@ -144,10 +144,10 @@ def branch_click(node, parent): g.append(group(draw_mark(x1, t, col=(0,0,1)), color(1,1,1))) win.add_group(g) return win - + def draw_arg(arg, layout, recomb_width=.4, branch_click=None): - + def branch_hotspot(node, parent, x, y, y2): def func(): branch_click(node, parent) @@ -167,11 +167,11 @@ def func(): g.append(branch_hotspot(child, node, x1, y1, y2)) # draw recomb - for node in layout: + for node in layout: if node.event == "recomb": x, y = layout[node] g.append(draw_mark(x, y, col=(1, 0, 0))) - + return g @@ -195,7 +195,7 @@ def trans_camera(win, x, y): for tree, block in izip(arglib.iter_marginal_trees(arg), blocks): pos = block[0] print pos - + leaves = sorted((x for x in tree.leaves()), key=lambda x: x.name) layout = layout_arg(tree, leaves) win.add_group( @@ -203,9 +203,9 @@ def trans_camera(win, x, y): draw_tree(tree, layout), text_clip( "%d-%d" % (block[0], block[1]), - treewidth*.05, 0, + treewidth*.05, 0, treewidth*.95, -max(l[1] for l in layout.values()), - 4, 20, + 4, 20, "center", "top"))) # mark responsible recomb node @@ -225,11 +225,11 @@ def trans_camera(win, x, y): nx, ny = layout[tree[node.name]] py = layout[tree[node.name].parents[0]][1] start = arg[node.name].data["ancestral"][0][0] - win.add_group(lines(color(0,1,0), + win.add_group(lines(color(0,1,0), x+nx, ny, x+nx, py, color(1,1,1))) - - + + x += treewidth win.set_visible(* win.get_root().get_bounding() + ("exact",)) @@ -256,7 +256,7 @@ def func(): def print_branch(node, parent): print "node", node.name - + tree_track = iter(tree_track) if mut: @@ -280,19 +280,19 @@ def trans_camera(win, x, y): for block, tree in chain([(block, tree)], 
tree_track): pos = block[0] print pos - + layout = treelib.layout_tree(tree, xscale=1, yscale=1) treelib.layout_tree_vertical(layout, leaves=0) g = win.add_group( translate(treex, 0, color(1,1,1), - sumtree.draw_tree(tree, layout, + sumtree.draw_tree(tree, layout, vertical=True), (draw_labels(tree, layout) if show_labels else group()), text_clip( "%d-%d" % (block[0], block[1]), - treewidth*.05, 0, + treewidth*.05, 0, treewidth*.95, -max(l[1] for l in layout.values()), - 4, 20, + 4, 20, "center", "top"))) @@ -308,7 +308,7 @@ def trans_camera(win, x, y): clicking.append(branch_hotspot(node, node.parent, x, y, y2)) #win.add_group(clicking) - + # draw mut if mut: for mpos, age, chroms in mut: @@ -322,8 +322,8 @@ def trans_camera(win, x, y): elif mpos > block[1]: mut.push((mpos, age, chroms)) break - - + + treex += treewidth #win.set_visible(* win.get_root().get_bounding() + ("exact",)) @@ -334,9 +334,9 @@ def trans_camera(win, x, y): def show_coal_track(tree_track): - + win = summon.Window() - + bgcolor = (1, 1, 1, .1) cmap = util.rainbow_color_map(low=0.0, high=1.0) @@ -365,7 +365,7 @@ def func(): print "pos=%s age=%f" % (util.int2pretty(int(x)), y) win.add_group(hotspot("click", 0, 0, end, maxage, func)) - + win.home("exact") @@ -405,7 +405,7 @@ def on_scroll_window(win): def on_resize_window(win): region = win.get_visible() print region - + branch_color = (1, 1, 1) @@ -419,43 +419,43 @@ def on_resize_window(win): win.set_binding(input_key("["), lambda : trans_camera(win, -treewidth, 0)) win.add_view_change_listener(lambda : on_scroll_window(win)) #win.remove_resize_listener(lambda : on_resize_window(win)) - + treex = 0 step = 2 - + names = [] seq_range = [0, 0] treewidth = 10 tree = None layout = None - + for item in smc: if item["tag"] == "NAMES": names = item["names"] if not use_names: names = map(str, range(len(names))) - + treewidth = len(names) - + elif item["tag"] == "RANGE": seq_range = [item["start"], item["end"]] - + elif item["tag"] == "TREE": tree = item["tree"] - + layout = treelib.layout_tree(tree, xscale=1, yscale=1) treelib.layout_tree_vertical(layout, leaves=0) #map_layout(layout, yfunc=minlog) region_text = text_clip("%d-%d" % (item["start"], item["end"]), - treewidth*.05, 0, + treewidth*.05, 0, treewidth*.95, -max(l[1] for l in layout.values()), - 4, 20, + 4, 20, "center", "top") - + g = win.add_group( translate(treex, 0, color(1,1,1), - sumtree.draw_tree(tree, layout, + sumtree.draw_tree(tree, layout, vertical=True), (draw_labels(tree, layout) if show_labels else group()), @@ -463,7 +463,7 @@ def on_resize_window(win): axis=(treewidth, 0), miny=1.0, maxy=1.0) )) - + clicking = group() g.append(clicking) @@ -480,10 +480,10 @@ def on_resize_window(win): mark_tree(tree, layout, item["recomb_node"], time=item["recomb_time"], col=recomb_color))) - + treex += treewidth + step - - + + ''' tree_track = iter(tree_track) if mut: @@ -507,19 +507,19 @@ def trans_camera(win, x, y): for block, tree in chain([(block, tree)], tree_track): pos = block[0] print pos - + layout = treelib.layout_tree(tree, xscale=1, yscale=1) treelib.layout_tree_vertical(layout, leaves=0) g = win.add_group( translate(treex, 0, color(1,1,1), - sumtree.draw_tree(tree, layout, + sumtree.draw_tree(tree, layout, vertical=True), (draw_labels(tree, layout) if show_labels else group()), text_clip( "%d-%d" % (block[0], block[1]), - treewidth*.05, 0, + treewidth*.05, 0, treewidth*.95, -max(l[1] for l in layout.values()), - 4, 20, + 4, 20, "center", "top"))) @@ -535,7 +535,7 @@ def trans_camera(win, x, y): 
clicking.append(branch_hotspot(node, node.parent, x, y, y2)) #win.add_group(clicking) - + # draw mut if mut: for mpos, age, chroms in mut: @@ -549,23 +549,23 @@ def trans_camera(win, x, y): elif mpos > block[1]: mut.push((mpos, age, chroms)) break - - + + treex += treewidth ''' win.home("exact") - + return win def show_coal_track3(tree_track): - + win = summon.Window() - - + + bgcolor = (1, 1, 1, .1) cmap = util.rainbow_color_map(low=0.5, high=1.0) @@ -576,7 +576,7 @@ def show_coal_track3(tree_track): times = treelib.get_tree_timestamps(tree) nleaves = len(tree.leaves()) maxage2 = 0 - for node in tree: + for node in tree: if len(node.children) > 1: age = times[node] sizes = [len(x.leaves()) for x in node.children] @@ -594,7 +594,7 @@ def func(): print "pos=%s age=%f" % (util.int2pretty(int(x)), y) win.add_group(hotspot("click", 0, 0, end, maxage, func)) - + win.home("exact") @@ -602,10 +602,10 @@ def func(): def show_coal_track2(tree_track): - + win = summon.Window() - - + + bgcolor = (1, 1, 1, .1) cmap = util.rainbow_color_map(low=0.0, high=1.0) tracks = {} @@ -638,7 +638,7 @@ def func(): print "pos=%s age=%f" % (util.int2pretty(int(x)), y) win.add_group(hotspot("click", 0, 0, end, maxage, func)) - + win.home("exact") @@ -646,10 +646,10 @@ def func(): def show_coal_track2(tree_track): - + win = summon.Window() - - + + bgcolor = (1, 1, 1, .1) cmap = util.rainbow_color_map(low=0.0, high=1.0) @@ -660,7 +660,7 @@ def show_coal_track2(tree_track): times = treelib.get_tree_timestamps(tree) nleaves = len(tree.leaves()) maxage2 = 0 - for node in tree: + for node in tree: if len(node.children) > 1: age = times[node] sizes = [len(x.leaves()) for x in node.children] @@ -681,7 +681,7 @@ def func(): print "pos=%s age=%f" % (util.int2pretty(int(x)), y) win.add_group(hotspot("click", 0, 0, end, maxage, func)) - + win.home("exact") @@ -690,7 +690,7 @@ def func(): def draw_tree(tree, layout, orient="vertical"): - + vis = group() bends = {} @@ -704,7 +704,7 @@ def draw_tree(tree, layout, orient="vertical"): bends[node] = (nx, py) else: bends[node] = (px, ny) - + # draw branch vis.append(lines(nx, ny, bends[node][0], bends[node][1])) @@ -719,12 +719,12 @@ def draw_tree(tree, layout, orient="vertical"): def draw_mark(x, y, col=(1,0,0), size=.5, func=None): """Draw a mark at (x, y)""" - + if func: h = hotspot("click", x-size, y-size, x+size, y+size, func) else: h = group() - + return zoom_clamp( color(*col), box(x-size, y-size, x+size, y+size, fill=True), @@ -746,7 +746,7 @@ def mark_tree(tree, layout, name, y=None, time=None, return draw_mark(nx, yfunc(y), col=col, size=size) -def draw_branch_mark(arg, layout, node=None, parent=None, pos=None, +def draw_branch_mark(arg, layout, node=None, parent=None, pos=None, chroms=None, age=None, col=(0,0,1)): """Draw a mark on a branch of an ARG""" @@ -755,7 +755,7 @@ def draw_branch_mark(arg, layout, node=None, parent=None, pos=None, if parent is None: assert pos is not None parent = arg.get_local_parent(node, pos) - + if node and parent: if age is None: t = random.uniform(layout[node][1], layout[parent][1]) @@ -768,7 +768,7 @@ def draw_branch_mark(arg, layout, node=None, parent=None, pos=None, -def draw_branch(arg, layout, node=None, parent=None, chroms=None, +def draw_branch(arg, layout, node=None, parent=None, chroms=None, pos=None, col=None): """Draw a mark on a branch of an ARG""" @@ -778,7 +778,7 @@ def draw_branch(arg, layout, node=None, parent=None, chroms=None, if parent is None: assert pos is not None parent = arg.get_local_parent(node, pos) - + if node and 
parent: x1, y1, x2, y2 = get_branch_layout(layout, node, parent) if col is None: @@ -797,7 +797,7 @@ def draw_branch(arg, layout, node=None, parent=None, chroms=None, def inorder_tree(tree): queue = [("queue", tree.root)] - + while queue: cmd, node = queue.pop() @@ -811,7 +811,7 @@ def inorder_tree(tree): [("queue", node.children[1]), ("visit", node), ("queue", node.children[0])]) - + def layout_tree_leaves_even(tree): layout = {} @@ -822,7 +822,7 @@ def layout_tree_leaves_even(tree): layout[node.name] = y else: y += 1 - + return layout @@ -834,7 +834,7 @@ def layout_tree_leaves(tree): if node.is_leaf(): layout[node.name] = y else: - #y += 1 + #y += 1 y += (node.age / 1e3) + 1 #y += exp(node.age / 5e2) + 1 #y += log(node.age + 1) ** 3 @@ -843,7 +843,7 @@ def layout_tree_leaves(tree): mid = (max(vals) + min(vals)) / 2.0 for k, v in layout.items(): layout[k] = (v - mid) - + return layout @@ -862,7 +862,7 @@ def layout_chroms(arg, start=None, end=None): layout_func = layout_tree_leaves #layout_func = layout_tree_leaves_even - + for spr in arglib.iter_arg_sprs(arg, start=start, end=end, use_leaves=True): print "layout", spr[0] blocks.append([last_pos, spr[0]]) @@ -883,7 +883,7 @@ def layout_chroms(arg, start=None, end=None): rindex = rnode.parents[0].children.index(rnode) if left and rindex != 0: rnode.parents[0].children.reverse() - + last_pos = spr[0] blocks.append([last_pos, end]) @@ -916,6 +916,7 @@ def walk(node): def mouse_click(win): print win.get_mouse_pos("world") + def chrom_click(win, chrom, block): def func(): if win: @@ -933,7 +934,7 @@ def draw_arg_threads(arg, blocks, layout, sites=None, draw_group=None, win=None): - leaf_names = list(arg.leaf_names()) + leaf_names = set(arg.leaf_names()) # TEST: rnodes = dict((r.pos, r) for r in arg if r.event == "recomb") @@ -956,7 +957,7 @@ def draw_arg_threads(arg, blocks, layout, sites=None, trims = [] - + for k, (x1, x2) in enumerate(blocks): # calc trims length = x2 - x1 @@ -964,7 +965,7 @@ def draw_arg_threads(arg, blocks, layout, sites=None, spr_trim2 = min(spr_trim, (length - minlen) / 2.0) trims.append((x1 + spr_trim2, x2 - spr_trim2)) trim = trims[-1] - + # horizontal lines l = [] for name in leaf_names: @@ -976,18 +977,18 @@ def draw_arg_threads(arg, blocks, layout, sites=None, # SPRs if k > 0: l = [] - + # TEST: #rnode = rnodes.get(x1, None) #young = (rnode is not None and rnode.age < 500) - + for name in leaf_names: #c = [1,0,0] if young else spr_colors[name] c = spr_colors[name] y1 = layout[k-1][name] y2 = layout[k][name] l.extend([color(*c), trims[k-1][1], y1, trims[k][0], y2]) - + draw_group.append(lines(*l)) # hotspots @@ -1003,7 +1004,9 @@ def draw_arg_threads(arg, blocks, layout, sites=None, if sites: l = [] for pos, col in sites.iter_region(x1, x2): - split = sites.get_minor(pos) + split = set(sites.get_minor(pos)) & leaf_names + if len(split) == 0: + continue if compat: if tree is None: tree = arg.get_marginal_tree((x1+x2)/2.0) @@ -1018,7 +1021,7 @@ def draw_arg_threads(arg, blocks, layout, sites=None, else: c = color(*snp_colors["compat"]) derived = split - + for d in derived: if d in layout[k]: y = layout[k][d] @@ -1026,4 +1029,3 @@ def draw_arg_threads(arg, blocks, layout, sites=None, draw_group.append(lines(*l)) return draw_group - diff --git a/argweaver/deps/rasmus/stats.py b/argweaver/deps/rasmus/stats.py index 9a73bfee..f59c76e5 100644 --- a/argweaver/deps/rasmus/stats.py +++ b/argweaver/deps/rasmus/stats.py @@ -4,21 +4,22 @@ """ - # python libs -from math import * +from math import ceil +from math import exp +from 
math import floor +from math import log +from math import pi +from math import sqrt from itertools import izip import cmath import random -import os # rasmus libs from rasmus import util from rasmus import tablelib - - def logprod(lst): """Computes the product of a list of numbers""" return sum(log(i) for i in lst) @@ -41,16 +42,18 @@ def mean(vals): n += 1 return s / float(n) + def median(vals): """Computes the median of a list of numbers""" lenvals = len(vals) sortvals = sorted(vals) - + if lenvals % 2 == 0: return (sortvals[lenvals / 2] + sortvals[lenvals / 2 - 1]) / 2.0 else: return sortvals[lenvals / 2] + def mode(vals): """Computes the mode of a list of numbers""" top = 0 @@ -64,44 +67,46 @@ def mode(vals): def msqerr(vals1, vals2): """Mean squared error""" - + assert len(vals1) == len(vals2), "lists are not the same length" - - - return mean([(vals1[i] - vals2[i]) ** 2 + + return mean([(vals1[i] - vals2[i]) ** 2 for i in xrange(len(vals1))]) - - + def variance(vals): """Variance""" u = mean(vals) return sum((x - u)**2 for x in vals) / float(len(vals)-1) + def sdev(vals): """Standard deviation""" return sqrt(variance(vals)) + def serror(vals): """Stanadrd error""" return sdev(vals) / sqrt(len(vals)) + def covariance(lst1, lst2): """Covariance""" m1 = mean(lst1) m2 = mean(lst2) tot = 0.0 for i in xrange(len(lst1)): - tot += (lst1[i] - m1) * (lst2[i] - m2) + tot += (lst1[i] - m1) * (lst2[i] - m2) return tot / (len(lst1)-1) def covmatrix(mat): """Covariance Matrix""" size = len(mat) - return [[cov(mat[i], mat[j]) for j in range(size)] + return [[covariance(mat[i], mat[j]) for j in range(size)] for i in range(size)] + def corrmatrix(mat): """Correlation Matrix""" size = len(mat) @@ -121,21 +126,21 @@ def corr(lst1, lst2): def corr_pvalue(r, n): """Returns the signficance of correlation > r with n samples""" - + import rpy.r t = r / sqrt((1 - r*r) / float(n - 2)) return rpy.r.pt(-t, n-2) - + def qqnorm(data, plot=None): """Quantile-quantile plot""" - + from rasmus import gnuplot data2 = sorted(data) norm = [random.normalvariate(0, 1) for x in range(len(data2))] norm.sort() - - if plot == None: + + if plot is None: return gnuplot.plot(data2, norm) else: plot.plot(data2, norm) @@ -147,20 +152,24 @@ def entropy(probs, base=2): return - sum(p * log(p, base) for p in probs if p > 0.0) + def cross_entropy(p, q, base=2): try: - return - sum(i * log(j, base) for i,j in izip(p, q) if i > 0.0) + return - sum(i * log(j, base) for i, j in izip(p, q) if i > 0.0) except OverflowError: return util.INF + def kl_div(p, q): """Compute the KL divergence for two discrete distributions""" return cross_entropy(p, q) - entropy(p) + def akaike_ic(lnl, k): """Akaike information criterion""" return 2 * k - 2 * lnl + def akaike_icc(lnl, n, k): """Akaike information criterion with second order correction Good for small sample sizes @@ -178,13 +187,12 @@ def bayesian_ic(lnl, n, k): return -2 * lnl + k * log(n) - def fitLine(xlist, ylist): """2D regression""" - + xysum = 0 xxsum = 0 - n = len(xlist) + n = len(xlist) for i in range(n): xysum += xlist[i] * ylist[i] xxsum += xlist[i] * xlist[i] @@ -205,7 +213,7 @@ def fitLineError(xlist, ylist, slope, inter): """Returns the Mean Square Error of the data fit""" error = 0 n = len(xlist) - + for i in range(n): error += ((xlist[i]*slope + inter) - ylist[i]) ** 2 return error / n @@ -213,18 +221,18 @@ def fitLineError(xlist, ylist, slope, inter): def pearsonsRegression(observed, expected): """Pearson's coefficient of regression""" - + # error sum of squares ess = sum((a - 
b)**2 for a, b in izip(observed, expected)) - + # total sum of squares u = mean(observed) tss = sum((a - u)**2 for a in observed) - + r2 = 1 - ess / tss return r2 - + def pearsonsRegressionLine(x, y, m, b): observed = y expected = [m*i + b for i in x] @@ -266,7 +274,7 @@ def percentile(vals, perc, rounding=-1, sort=True): rounding -- round down if -1 or round up for 1 sort -- if True, sort vals first """ - + if sort: vals2 = sorted(vals) else: @@ -286,7 +294,7 @@ def dither(vals, radius): def logadd(lna, lnb): """Adding numbers in log-space""" - + diff = lna - lnb if diff < 500: return log(exp(diff) + 1.0) + lnb @@ -309,7 +317,7 @@ def logsum(vals): for i in xrange(len(vals)): if i != maxi and vals[i] - maxval > SUM_LOG_THRESHOLD: expsum += exp(vals[i] - maxval) - + return maxval + log(expsum) @@ -329,7 +337,7 @@ def logsub(lna, lnb): return log(diff2) + lnb else: return lna - + def logadd_sign(sa, lna, sb, lnb): """Adding numbers in log-space""" @@ -369,21 +377,19 @@ def logadd_sign(sa, lna, sb, lnb): def smooth(vals, radius): """ return an averaging of vals using a radius - + Note: not implemented as fast as possible runtime: O(len(vals) * radius) """ - + vals2 = [] vlen = len(vals) - + for i in xrange(vlen): radius2 = min(i, vlen - i - 1, radius) vals2.append(mean(vals[i-radius2:i+radius2+1])) - - return vals2 - + return vals2 def iter_window_index(x, xdist, esp=None): @@ -391,61 +397,59 @@ def iter_window_index(x, xdist, esp=None): iterates a sliding window over x with width 'xdist' returns an iterator over list of indices in x that represent windows - + x must be sorted least to greatest """ - vlen = len(x) #if esp is None: # esp = min(x[i+1] - x[i] for i in range(vlen-1) # if x[i+1] - x[i] > 0) / 2.0 - + # simple case if vlen == 0: return - + start = x[0] - end = x[-1] - window = [0] - + low = start high = start + xdist - lowi = 0 # inclusive - highi = 0 # inclusive + lowi = 0 # inclusive + highi = 0 # inclusive # move up high boundary while highi+1 < vlen and x[highi+1] < high: highi += 1 yield (lowi, highi, low, high) - + while highi+1 < vlen: - low_step = x[lowi] - low # dist until expell - high_step = x[highi+1] - high # dist until include + low_step = x[lowi] - low # dist until expell + high_step = x[highi+1] - high # dist until include # advance though duplicates if low_step == 0: lowi += 1 continue - + if high_step == 0: highi += 1 continue # determine new low high boundary if low_step <= high_step: - low = x[lowi] #+ min(esp, (high_step - low_step) / 2.0) - high = low + xdist + low = x[lowi] # + min(esp, (high_step - low_step) / 2.0) + high = low + xdist lowi += 1 - + if high_step <= low_step: highi += 1 - if highi >= vlen: break - high = x[highi] #+ min(esp, (low_step - high_step) / 2.0) + if highi >= vlen: + break + high = x[highi] # + min(esp, (low_step - high_step) / 2.0) low = high - xdist assert abs((high - low) - xdist) < .001, (low, high) - + yield (lowi, highi, low, high) @@ -461,7 +465,7 @@ def iter_window_index_step(x, size, step, minsize=0): lowi = 0 highi = 0 - + # move up high boundary while highi+1 < vlen and x[highi+1] < high: highi += 1 @@ -480,16 +484,14 @@ def iter_window_index_step(x, size, step, minsize=0): # move up high boundary while highi+1 < vlen and x[highi+1] < high: highi += 1 - - + def iter_window(x, xdist, func=lambda win: win, minsize=0, key=lambda x: x): """ iterates a sliding window over x with radius xradius - + x must be sorted least to greatest """ - for lowi, highi, low, high in iter_window_index(map(key, x), xdist): if highi - lowi >= 
minsize: yield (high + low)/2.0, func(x[lowi:highi]) @@ -498,53 +500,48 @@ def iter_window(x, xdist, func=lambda win: win, minsize=0, key=lambda x: x): def iter_window_step(x, width, step, func=lambda win: win, minsize=0): """ iterates a sliding window over x with width 'width' - + x must be sorted least to greatest return an iterator with (midx, func(x[lowi:highi])) """ - - for lowi, highi, low, high in iter_window_index_step(x, width, step, minsize): + for lowi, highi, low, high in iter_window_index_step( + x, width, step, minsize): yield (high + low) / 2.0, func(x[lowi:highi]) - - - - def smooth2(x, y, xradius, minsize=0, sort=False): """ return an averaging of x and y using xradius - + x must be sorted least to greatest """ - vlen = len(x) assert vlen == len(y) - + # simple case if vlen == 0: return [], [] - + if sort: x, y = util.sort_many(x, y) - + x2 = [] y2 = [] - + start = min(x) end = max(x) xtot = x[0] ytot = y[0] - + low = 0 high = 0 - + for i in xrange(vlen): xi = x[i] - + xradius2 = min(xi - start, end - xi, xradius) - + # move window while x[low] < xi - xradius2: xtot -= x[low] @@ -554,18 +551,18 @@ def smooth2(x, y, xradius, minsize=0, sort=False): high += 1 xtot += x[high] ytot += y[high] - + denom = float(high - low + 1) if denom >= minsize: x2.append(xtot / denom) y2.append(ytot / denom) - + return x2, y2 def factorial(x, k=1): """Simple implementation of factorial""" - + n = 1 for i in xrange(int(k)+1, int(x)+1): n *= i @@ -574,7 +571,7 @@ def factorial(x, k=1): def logfactorial(x, k=1): """returns the log(factorial(x) / factorial(k)""" - + n = 0 for i in xrange(int(k)+1, int(x)+1): n += log(i) @@ -584,14 +581,14 @@ def logfactorial(x, k=1): def choose(n, k): if n == 0 and k == 0: return 1 - + if n < 0 or k < 0 or k > n: return 0 - + # optimization for speed if k > n/2: k = n - k - + t = 1.0 n2 = n + 1.0 for i in xrange(1, k+1): @@ -599,34 +596,36 @@ def choose(n, k): return int(t + 0.5) #return factorial(n, n - k) / factorial(k) + def fchoose(n, k): if n == 0 and k == 0: return 1 - + if n < 0 or k < 0 or k > n: return 0 - + # optimization for speed if k > n/2: k = n - k - + t = 1.0 n2 = n + 1.0 for i in xrange(1, k+1): t *= (n2 - i) / i return t + def logchoose(n, k): if n == 0 and k == 0: return 0.0 - + if n < 0 or k < 0 or k > n: return -util.INF - + # optimization for speed if k > n/2: k = n - k - + t = 0.0 n2 = n + 1.0 for i in xrange(1, k+1): @@ -656,10 +655,9 @@ def sample(weights): """ Randomly choose an int between 0 and len(probs)-1 using the weights stored in list probs. - + item i will be chosen with probability weights[i]/sum(weights) """ - total = sum(weights) pick = random.random() * total x = 0 @@ -668,7 +666,7 @@ def sample(weights): if x >= pick: return i return len(weights) - 1 - + def rhyper(m, n, M, N, report=0): ''' @@ -676,22 +674,18 @@ def rhyper(m, n, M, N, report=0): hypergeometric distribution over/under/both (report = 0/1/2) (uses R through RPy) - + N = total balls in urn M = total white balls in urn n = drawn balls from urn m = drawn white balls from urn - - ''' + ''' from rpy import r - - assert( (type(m) == type(n) == type(M) == type(N) == int) + assert ((type(m) == type(n) == type(M) == type(N) == int) and m <= n and m <= M and n <= N) - - - + if report == 0: #p-val for over-repr. 
return r.phyper(m-1, M, N-M, n, lower_tail=False) @@ -700,26 +694,27 @@ def rhyper(m, n, M, N, report=0): return r.phyper(m, M, N-M, n) elif report == 2: #tuple (over, under) - return r.phyper(m-1, M, N-M, n, lower_tail=False), r.phyper(m, M, N-M, n) + return (r.phyper(m-1, M, N-M, n, lower_tail=False), + r.phyper(m, M, N-M, n)) else: raise "unknown option" def cdf(vals): """Computes the CDF of a list of values""" - + vals = sorted(vals) tot = float(len(vals)) x = [] y = [] - + for i, x2 in enumerate(vals): x.append(x2) y.append(i / tot) - + return x, y - - + + def enrichItems(in_items, out_items, M=None, N=None, useq=True, extra=False): """Calculates enrichment for items within an in-set vs and out-set. Returns a sorted table. @@ -727,22 +722,21 @@ def enrichItems(in_items, out_items, M=None, N=None, useq=True, extra=False): # DEPRECATED # TODO: remove this function - # count items counts = util.Dict(default=[0, 0]) for item in in_items: counts[item][0] += 1 for item in out_items: counts[item][1] += 1 - + if N is None: N = len(in_items) + len(out_items) if M is None: M = len(in_items) - - tab = tablelib.Table(headers=["item", "in_count", "out_count", + + tab = tablelib.Table(headers=["item", "in_count", "out_count", "pval", "pval_under"]) - + # do hypergeometric for item, (a, b) in counts.iteritems(): tab.add(item=item, @@ -750,15 +744,15 @@ def enrichItems(in_items, out_items, M=None, N=None, useq=True, extra=False): out_count=b, pval=rhyper(a, a+b, M, N), pval_under=rhyper(a, a+b, M, N, 1)) - + # add qvalues if useq: qval = qvalues(tab.cget("pval")) qval_under = qvalues(tab.cget("pval_under")) - + tab.add_col("qval", data=qval) tab.add_col("qval_under", data=qval_under) - + if extra: tab.add_col("in_size", data=[M]*len(tab)) tab.add_col("out_size", data=[N-M]*len(tab)) @@ -768,8 +762,8 @@ def enrichItems(in_items, out_items, M=None, N=None, useq=True, extra=False): tab.add_col("size_ratio", data=[ M / float(N) for row in tab]) tab.add_col("fold", data=[row["item_ratio"] / row["size_ratio"] - for row in tab]) - + for row in tab]) + tab.sort(col='pval') return tab @@ -780,7 +774,6 @@ def qvalues(pvals): return ret - #============================================================================= # Distributions # @@ -797,40 +790,45 @@ def binomialPdf(k, params): p, n = params return choose(n, k) * (p ** k) * ((1.0-p) ** (n - k)) + def gaussianPdf(x, params): return 1/sqrt(2*pi) * exp(- x**2 / 2.0) + def normalPdf(x, params): mu, sigma = params # sqrt(2*pi) = 2.5066282746310002 return exp(- (x - mu)**2 / (2.0 * sigma**2)) / (sigma * 2.5066282746310002) + def normalCdf(x, params): mu, sigma = params return (1 + erf((x - mu)/(sigma * sqrt(2)))) / 2.0 + def logNormalPdf(x, params): - """mu and sigma are the mean and standard deviation of the + """mu and sigma are the mean and standard deviation of the variable's logarithm""" - + mu, sigma = params - return 1/(x * sigma * sqrt(2*pi)) * \ - exp(- (log(x) - mu)**2 / (2.0 * sigma**2)) + return (1/(x * sigma * sqrt(2*pi)) * + exp(- (log(x) - mu)**2 / (2.0 * sigma**2))) + def logNormalCdf(x, params): - """mu and sigma are the mean and standard deviation of the + """mu and sigma are the mean and standard deviation of the variable's logarithm""" - + mu, sigma = params return (1 + erf((log(x) - mu)/(sigma * sqrt(2)))) / 2.0 def poissonPdf(x, params): lambd = params[0] - + if x < 0 or lambd <= 0: return 0.0 - + a = 0 for i in xrange(1, int(x)+1): a += log(lambd / float(i)) @@ -841,12 +839,12 @@ def poissonCdf(x, params): """Cumulative distribution 
function of the Poisson distribution""" # NOTE: not implemented accurately for large x or lambd lambd = params[0] - + if x < 0: return 0 else: - return (gamma(floor(x+1)) - gammainc(floor(x + 1), lambd)) / \ - factorial(floor(x)) + return ((gamma(floor(x+1)) - gammainc(floor(x + 1), lambd)) / + factorial(floor(x))) def poissonvariate(lambd): @@ -864,7 +862,7 @@ def poissonvariate(lambd): def exponentialPdf(x, params): lambd = params[0] - + if x < 0 or lambd < 0: return 0.0 else: @@ -873,7 +871,7 @@ def exponentialPdf(x, params): def exponentialCdf(x, params): lambd = params[0] - + if x < 0 or lambd < 0: return 0.0 else: @@ -883,13 +881,14 @@ def exponentialCdf(x, params): def exponentialvariate(lambd): return -log(random.random()) / lambd + def gammaPdf(x, params): alpha, beta = params if x <= 0 or alpha <= 0 or beta <= 0: return 0.0 else: - return (exp(-x * beta) * (x ** (alpha - 1)) * (beta ** alpha)) / \ - gamma(alpha) + return ((exp(-x * beta) * (x ** (alpha - 1)) * (beta ** alpha)) / + gamma(alpha)) def loggammaPdf(x, params): @@ -899,6 +898,7 @@ def loggammaPdf(x, params): else: return -x*beta + (alpha - 1)*log(x) + alpha*log(beta) - gammaln(alpha) + def gammaPdf2(x, params): alpha, beta = params if x <= 0 or alpha <= 0 or beta <= 0: @@ -914,45 +914,45 @@ def gammaCdf(x, params): else: return gammainc(alpha, x * beta) / gamma(alpha) + def invgammaPdf(x, params): a, b = params - if x <=0 or a <= 0 or b <= 0: + if x <= 0 or a <= 0 or b <= 0: return 0.0 else: return (b**a) / gamma(a) * (1.0/x)**(a + 1) * exp(-b/x) + def loginvgammaPdf(x, params): a, b = params if x < 0 or a < 0 or b < 0: return -util.INF else: - return a*log(b) - gammaln(a) + (a+1)*log(1.0/x) -b/x - - + return a*log(b) - gammaln(a) + (a+1)*log(1.0/x) - b/x def betaPdf2(x, params): """A simpler implementation of beta distribution but will overflow for values of alpha and beta near 100 """ - alpha, beta = params if 0 < x < 1 and alpha > 0 and beta > 0: - return gamma(alpha + beta) / (gamma(alpha)*gamma(beta)) * \ - x ** (alpha-1) * (1-x)**(beta-1) + return (gamma(alpha + beta) / (gamma(alpha)*gamma(beta)) * + x ** (alpha-1) * (1-x)**(beta-1)) else: return 0.0 + def betaPdf(x, params): alpha, beta = params - + if 0 < x < 1 and alpha > 0 and beta > 0: - return exp(gammaln(alpha + beta) - (gammaln(alpha) + gammaln(beta)) + \ - (alpha-1) * log(x) + (beta-1) * log(1-x)) + return (exp(gammaln(alpha + beta) - + (gammaln(alpha) + gammaln(beta)) + + (alpha-1) * log(x) + (beta-1) * log(1-x))) else: return 0.0 - def betaPdf3(x, params): @@ -960,11 +960,11 @@ def betaPdf3(x, params): if 0 < x < 1 and alpha > 0 and beta > 0: n = min(alpha-1, beta-1) m = max(alpha-1, beta-1) - + prod1 = 1 - for i in range(1,n+1): + for i in range(1, n+1): prod1 *= ((n+i)*x*(1-x))/i - + prod2 = 1 if alpha > beta: for i in range(n+1, m+1): @@ -972,7 +972,7 @@ def betaPdf3(x, params): else: for i in range(n+1, m+1): prod2 *= ((n+i)*(1-x))/i - + return prod1 * prod2 * (alpha + beta - 1) else: return 0.0 @@ -983,24 +983,21 @@ def negbinomPdf(k, r, p): r*log(p) + k * log(1-p)) - def gamma(x): """ - Lanczos approximation to the gamma function. - - found on http://www.rskey.org/gamma.htm + Lanczos approximation to the gamma function. 
+ + found on http://www.rskey.org/gamma.htm """ - - ret = 1.000000000190015 + \ - 76.18009172947146 / (x + 1) + \ - -86.50532032941677 / (x + 2) + \ - 24.01409824083091 / (x + 3) + \ - -1.231739572450155 / (x + 4) + \ - 1.208650973866179e-3 / (x + 5) + \ - -5.395239384953e-6 / (x + 6) - - return ret * sqrt(2*pi)/x * (x + 5.5)**(x+.5) * exp(-x-5.5) + ret = (1.000000000190015 + + 76.18009172947146 / (x + 1) + + -86.50532032941677 / (x + 2) + + 24.01409824083091 / (x + 3) + + -1.231739572450155 / (x + 4) + + 1.208650973866179e-3 / (x + 5) + + -5.395239384953e-6 / (x + 6)) + return ret * sqrt(2*pi)/x * (x + 5.5)**(x+.5) * exp(-x-5.5) def gammaln(xx): @@ -1010,8 +1007,8 @@ def gammaln(xx): float gammln(float xx) Returns the value ln[(xx)] for xx > 0. { - Internal arithmetic will be done in double precision, a nicety that you can omit if five-figure - accuracy is good enough. + Internal arithmetic will be done in double precision, a nicety + that you can omit if five-figure accuracy is good enough. double x,y,tmp,ser; static double cof[6]={76.18009172947146,-86.50532032941677, 24.01409824083091,-1.231739572450155, @@ -1026,29 +1023,27 @@ def gammaln(xx): } """ - cof = [76.18009172947146,-86.50532032941677, - 24.01409824083091,-1.231739572450155, - 0.1208650973866179e-2,-0.5395239384953e-5] - + cof = [76.18009172947146, -86.50532032941677, + 24.01409824083091, -1.231739572450155, + 0.1208650973866179e-2, -0.5395239384953e-5] + y = x = xx tmp = x + 5.5 tmp -= (x + 0.5) * log(tmp) ser = 1.000000000190015 - + for j in range(6): y += 1 ser += cof[j] / y - - return - tmp + log(2.5066282746310005 * ser / x) - + return - tmp + log(2.5066282746310005 * ser / x) -GAMMA_INCOMP_ACCURACY = 1000 def gammainc(a, x): """Lower incomplete gamma function""" # found on http://www.rskey.org/gamma.htm - + GAMMA_INCOMP_ACCURACY = 1000 + ret = 0 term = 1.0/x for n in xrange(GAMMA_INCOMP_ACCURACY): @@ -1061,38 +1056,39 @@ def gammainc(a, x): def erf(x): # http://www.theorie.physik.uni-muenchen.de/~serge/erf-approx.pdf - + a = 8/(3*pi) * (pi - 3)/(4 - pi) axx = a * x * x - + if x >= 0: return sqrt(1 - exp(-x*x * (4.0/pi + axx)/(1 + axx))) else: return - sqrt(1 - exp(-x*x * (4.0/pi + axx)/(1 + axx))) - def chiSquare(rows, expected=None, nparams=0): # ex: rows = [[1,2,3],[1,4,5]] assert util.equal(map(len, rows)) - if 0 in map(sum,rows): return 0,1.0 + if 0 in map(sum, rows): + return 0, 1.0 cols = zip(* rows) - if 0 in map(sum,cols): return 0,1.0 + if 0 in map(sum, cols): + return 0, 1.0 if not expected: expected = make_expected(rows) chisq = 0 - for obss,exps in zip(rows,expected): + for obss, exps in zip(rows, expected): for obs, exp in zip(obss, exps): chisq += ((obs-exp)**2)/exp df = max(len(rows)-1, 1)*max(len(rows[0])-1, 1) - nparams - p = chi_square_lookup(chisq,df) + p = chi_square_lookup(chisq, df) - return chisq,p + return chisq, p def make_expected(rows): @@ -1101,7 +1097,7 @@ def make_expected(rows): grandtotal = float(sum(rowtotals)) expected = [] - for row,rowtotal in zip(rows,rowtotals): + for row, rowtotal in zip(rows, rowtotals): expected_row = [] for obs, coltotal in zip(row, coltotals): exp = rowtotal * coltotal / grandtotal @@ -1112,21 +1108,21 @@ def make_expected(rows): def chiSquareFit(xbins, ybins, func, nsamples, nparams, minsamples=5): sizes = [xbins[i+1] - xbins[i] for i in xrange(len(xbins)-1)] - sizes.append(sizes[-1]) # NOTE: assumes bins are of equal size - + sizes.append(sizes[-1]) # NOTE: assumes bins are of equal size + # only focus on bins that are large enough counts = [ybins[i] * 
sizes[i] * nsamples for i in xrange(len(xbins)-1)] - + expected = [] for i in xrange(len(xbins)-1): - expected.append((func(xbins[i]) + func(xbins[i+1]))/2.0 * - sizes[i] * nsamples) - + expected.append((func(xbins[i]) + func(xbins[i+1]))/2.0 * + sizes[i] * nsamples) + # ensure we have enough expected samples in each bin ind = util.find(util.gefunc(minsamples), expected) counts = util.mget(counts, ind) expected = util.mget(expected, ind) - + if len(counts) == 0: return [0, 1], counts, expected else: @@ -1168,94 +1164,95 @@ def chiSquareFit(xbins, ybins, func, nsamples, nparams, minsamples=5): def chi_square_lookup(value, df): - + ps = [0.20, 0.10, 0.05, 0.025, 0.01, 0.001] - + if df <= 0: - return 1.0 - + return 1.0 + row = chi_square_table[min(df, 30)] - for i in range(0,len(row)): + for i in range(0, len(row)): if row[i] >= value: i = i-1 break - - if i == -1: return 1 - else: return ps[i] + if i == -1: + return 1 + else: + return ps[i] def spearman(vec1, vec2): """Spearman's rank test""" - + assert len(vec1) == len(vec2), "vec1 and vec2 are not the same length" - + n = len(vec1) rank1 = util.sortranks(vec1) rank2 = util.sortranks(vec2) - - R = sum((vec1[i] - vec2[i])**2 for i in xrange(n)) - + + R = sum((rank1[i] - rank2[i])**2 for i in xrange(n)) + Z = (6*R - n*(n*n - 1)) / (n*(n + 1) * sqrt(n - 1)) - + return Z - -# input: -# xdata, ydata - data to fit -# func - a function of the form f(x, params) -# -def fitCurve(xdata, ydata, func, paramsInit): +def fitCurve(xdata, ydata, func, paramsInit): + """ + Fit a function to data points. + + Args: + xdata, ydata - data to fit + func - a function of the form f(x, params) + """ import scipy import scipy.optimize y = scipy.array(ydata) p0 = scipy.array(paramsInit) - + def error(params): y2 = scipy.array(map(lambda x: func(x, params), xdata)) return y - y2 params, msg = scipy.optimize.leastsq(error, p0) - + resid = error(params) - + return list(params), sum(resid*resid) - + def fitDistrib(func, paramsInit, data, start, end, step, perc=1.0): xdata, ydata = util.distrib(data, low=start, width=step) ydata = [i / perc for i in ydata] xdata = util.histbins(xdata) params, resid = fitCurve(xdata, ydata, func, paramsInit) return params, resid - -def plotfuncFit(func, paramsInit, xdata, ydata, start, end, step, plot = None, + +def plotfuncFit(func, paramsInit, xdata, ydata, start, end, step, plot=None, **options): from rasmus import gnuplot if not plot: plot = gnuplot.Gnuplot() - + options.setdefault('style', 'boxes') - + params, resid = fitCurve(xdata, ydata, func, paramsInit) plot.plot(util.histbins(xdata), ydata, **options) plot.plotfunc(lambda x: func(x, params), start, end, step) - + return plot, params, resid - -def plotdistribFit(func, paramsInit, data, start, end, step, plot = None, + +def plotdistribFit(func, paramsInit, data, start, end, step, plot=None, **options): xdata, ydata = util.distrib(data, low=start, width=step) - return plotfuncFit(func, paramsInit, xdata, ydata, start, end, step/10, plot, - **options) - - + return plotfuncFit( + func, paramsInit, xdata, ydata, start, end, step/10, plot, **options) def chi_square_fit(cdf, params, data, ndivs=20, minsamples=5, plot=False, @@ -1276,7 +1273,7 @@ def chi_square_fit(cdf, params, data, ndivs=20, minsamples=5, plot=False, obs = scipy.array(map(len, bins)) ind = util.find(lambda x: x[-1] >= start and x[0] <= end, bins) obs = util.mget(obs, ind) - + x = [bin[0] for bin in bins] expected = [len(data) * cdf(x[1], params)] expected.extend([len(data) * @@ -1284,13 +1281,13 @@ def 
chi_square_fit(cdf, params, data, ndivs=20, minsamples=5, plot=False, for i in range(1, len(x)-1)]) expected.append(len(data) * (1.0 - cdf(x[-1], params))) expected = scipy.array(util.mget(expected, ind)) - + chi2, pval = scipy.stats.chisquare(obs, expected) - if plot: + if plot: p = gnuplot.plot(util.mget(x, ind), obs) p.plot(util.mget(x, ind), expected) - + return chi2, pval @@ -1312,7 +1309,7 @@ def fit_distrib(cdf, params_init, data, ndivs=20, minsamples=5, obs = scipy.array(map(len, bins)) ind = util.find(lambda x: x[-1] >= start and x[0] <= end, bins) obs = util.mget(obs, ind) - + def optfunc(params): x = [bin[0] for bin in bins] expected = [len(data) * cdf(x[1], params)] @@ -1321,7 +1318,7 @@ def optfunc(params): for i in range(1, len(x)-1)]) expected.append(len(data) * (1.0 - cdf(x[-1], params))) expected = scipy.array(util.mget(expected, ind)) - + chi2, pval = scipy.stats.chisquare(obs, expected) return chi2 @@ -1330,178 +1327,62 @@ def optfunc(params): return list(params), pval - - - def solveCubic(a, b, c, real=True): """solves x^3 + ax^2 + bx + c = 0 for x""" - + p = b - a*a / 3.0 q = c + (2*a*a*a - 9*a*b) / 27.0 - + # special case: avoids division by zero later on if p == q == 0: return [- a / 3.0] - - # + + # # u = (q/2 +- sqrt(q^2/4 + p^3/27))^(1/3) # - + # complex math is used to find complex roots sqrteqn = cmath.sqrt(q*q/4.0 + p*p*p/27.0) - + # find fist cube root u1 = (q/2.0 + sqrteqn)**(1/3.0) - + # special case: avoids division by zero later on if u1 == 0: u1 = (q/2.0 - sqrteqn)**(1/3.0) - + # find other two cube roots u2 = u1 * complex(-.5, -sqrt(3)/2) u3 = u1 * complex(-.5, sqrt(3)/2) - + # finds roots of cubic polynomial root1 = p / (3*u1) - u1 - a / 3.0 root2 = p / (3*u2) - u2 - a / 3.0 root3 = p / (3*u3) - u3 - a / 3.0 - + if real: - return [x.real + return [x.real for x in [root1, root2, root3] if abs(x.imag) < 1e-10] else: return [root1, root2, root3] -def _solveCubic_test(n=100): - - def test(a, b, c): - xs = solveCubic(a, b, c) - - for x in xs: - y = x**3 + a*x*x + b*x + c - assert abs(y) < 1e-4, y - - test(0, 0, 0) - test(0, 1, 1) - test(0, 0, 1) - - for i in xrange(n): - - a = random.normalvariate(10, 5) - b = random.normalvariate(10, 5) - c = random.normalvariate(10, 5) - - test(a, b, c) - - def bisect_root(f, x0, x1, err=1e-7): """Find a root of a function func(x) using the bisection method""" f0 = f(x0) - f1 = f(x1) - + #f1 = f(x1) + while (x1 - x0) / 2.0 > err: x2 = (x0 + x1) / 2.0 f2 = f(x2) - + if f0 * f2 > 0: x0 = x2 f0 = f2 else: x1 = x2 - f1 = f2 + #f1 = f2 return (x0 + x1) / 2.0 - - - - - -#============================================================================= -# testing - -if __name__ == "__main__": - - - # iter_window - from rasmus import util - from rasmus import gnuplot - - vals = sorted([random.random() * 20 for x in range(600)]) - - vals += sorted([40 + random.random() * 20 for x in range(600)]) - - ''' - win = filter(lambda x: len(x) > 0, - list(iter_window_index(vals, 5))) - - p = util.plot(util.cget(win, 2))#, style="lines") - p.enableOutput(False) - p.plot(util.cget(win, 3)) #, style="lines") - - for i, y in enumerate(vals): - p.plot([i, len(vals)], [y, y], style="lines") - p.enableOutput(True) - p.replot() - ''' - - def mean2(v): - if len(v) == 0: - return 0.0 - else: - return mean(v) - - x, y = zip(* iter_window_step(vals, 5, 1, len)) - gnuplot.plot(x, y) - - - - -#============================================================================= -# OLD CODE - -''' -def smooth_old(x, radius): - """ - return an averaging of vals using 
a radius - - Note: not implemented as fast as possible - runtime: O(len(vals) * radius) - """ - - vlen = len(x) - - # simple case - if vlen == 0: - return [] - - x2 = [] - - tot = x[0] - - low = 0 - high = 0 - - for i in range(vlen): - xi = x[i] - - xradius2 = min(i, vlen - i - 1, xradius) - - # move window - while x[low] < xi - xradius2: - xtot -= x[low] - ytot -= y[low] - low += 1 - while x[high] < xi + xradius2: - high += 1 - xtot += x[high] - ytot += y[high] - - denom = float(high - low + 1) - x2.append(xtot / denom) - y2.append(ytot / denom) - - return x2, y2 -''' diff --git a/argweaver/deps/rasmus/tablelib.py b/argweaver/deps/rasmus/tablelib.py index 67bdf208..0db0ab16 100644 --- a/argweaver/deps/rasmus/tablelib.py +++ b/argweaver/deps/rasmus/tablelib.py @@ -1,9 +1,7 @@ """ tablelib.py -Portable Tabular Format (PTF) - -Implements and standardizes Manolis style tab-delimited table file format. +Parse, format, and manipulate tabular data. --Example---------------------------------------------------- @@ -19,44 +17,42 @@ Directives are on a single line and begin with two hashes '##' No space after colon is allowed. - Table can also handle custom types. Custom types must do the following - 1. default value: - default = mytype() + 1. default value: + default = mytype() returns default value 2. convert from string - val = mytype(string) + val = mytype(string) converts from string to custom type 3. convert to string string = str(val) converts val of type 'mytype' to a string - TODO: I could change this interface... - I could just use mytype.__str__(val) 4. type inference (optional) type(val) returns instance of 'mytype' TODO: I could not require this (only map() really needs it and __init__()) - """ # python libs import copy -import StringIO -import sys +from itertools import chain, imap, izip import os -import itertools +from sqlite3 import dbapi2 as sqlite +from StringIO import StringIO +import sys # rasmus libs from rasmus import util # table directives -DIR_TYPES = 1 +DIR_TYPES = 1 + # a special unique null type (more 'null' than None) NULL = object() @@ -67,160 +63,169 @@ class TableException (Exception): def __init__(self, errmsg, filename=None, lineno=None): msg = "" add_space = False - add_semicolon = False - + add_colon = False + if filename: msg += "%s" % filename add_space = True - add_semicolon = True - + add_colon = True + if lineno: - add_semicolon = True + add_colon = True if add_space: msg += " " msg += "line %d" % lineno - - if add_semicolon: + + if add_colon: msg += ": " - + msg = msg + errmsg - + Exception.__init__(self, msg) #=========================================================================== # Types handling -# def guess_type(text): - """Guesses the type of a value encoded in a string""" - - if text.isdigit(): + """Guess the type of a value encoded in a string.""" + + try: + int(text) return int + except: + pass try: float(text) return float except ValueError: pass - + try: str2bool(text) return bool except ValueError: pass - + return str def str2bool(text=None): - """Will parse every way manolis stores a boolean as a string""" - + """Parse a boolean stored as a string.""" + if text is None: # default value return False - - text2 = text.lower() - - if text2 == "false": + + text = text.lower() + + if text == "false": return False - elif text2 == "true": + elif text == "true": return True else: raise ValueError("unknown string for bool '%s'" % text) - - - -#============================================================================= - -_type_definitions = 
[["string", str], - ["unknown", str], # backwards compatiable name - ["str", str], # backwards compatiable name - ["int", int], - ["float", float], - ["bool", bool]] +_type_definitions = [ + ["string", str], + ["unknown", str], # backwards compatiable name + ["str", str], # backwards compatiable name + ["string", unicode], + ["int", int], + ["int", long], + ["float", float], + ["bool", bool], +] # NOTE: ordering of name-type pairs is important # the first occurrence of a type gives the perferred name for writing - -def parse_type(text): - for name, t in _type_definitions: - if text == name: - return t - raise Exception("unknown type '%s'" % text) +def parse_type(type_name): + """Parse a type name into a type.""" + for name, type_object in _type_definitions: + if type_name == name: + return type_object + raise Exception("unknown type '%s'" % type_name) -def format_type(t): - for name, t2 in _type_definitions: - if t == t2: - return name - raise Exception("unknown type '%s'" % t) +def format_type(type_object): + """Format a type into a type name.""" + for name, type_object2 in _type_definitions: + if type_object == type_object2: + return name + raise Exception("unknown type '%s'" % type_object) #=========================================================================== # Table class -# class Table (list): - """Class implementing the Portable Table Format""" + """A table of data""" + + def __init__(self, rows=None, + headers=None, + types={}, + filename=None, + nheaders=1): - def __init__(self, rows=None, - headers=None, - types={}, - filename=None): - # set table info self.headers = copy.copy(headers) self.types = copy.copy(types) - self.filename = filename self.comments = [] self.delim = "\t" - self.nheaders = 1 - - + self.nheaders = nheaders + self.filename = filename + # set data - if rows is not None: - it = iter(rows) - try: - first_row = it.next() - - # data is a list of dicts - if isinstance(first_row, dict): - self.append(first_row) - for row in it: - self.append(dict(row)) - - if self.headers is None: - self.headers = sorted(self[0].keys()) - - # data is a list of lists - elif isinstance(first_row, (list, tuple)): - if self.headers is None: - self.headers = range(len(first_row)) - self.nheaders = 0 - for row in itertools.chain([first_row], it): - self.append(dict(zip(self.headers, row))) - - - # set table info - for key in self.headers: - # guess any types not specified - if key not in self.types: - self.types[key] = type(self[0][key]) - - except StopIteration: - pass - - - + if rows: + self._set_data(rows) + + def _set_data(self, rows=[]): + """Set the table data from an iterable.""" + try: + # use first row to guess data style + rows = iter(rows) + first_row = rows.next() + except StopIteration: + # No data given + return + + if isinstance(first_row, dict): + # data is a list of dicts + # set default headers based on first row keys + if self.headers is None: + self.headers = sorted(first_row.keys()) + + # add data + self.extend(imap(dict, chain([first_row], rows))) + + elif isinstance(first_row, (list, tuple)): + # data is a list of lists + # use first row to determine headers + if self.nheaders == 0: + if self.headers is None: + self.headers = range(len(first_row)) + rows = chain([first_row], rows) + else: + self.headers = list(first_row) + + # add data + self.extend(dict(zip(self.headers, row)) for row in rows) + + # guess any types not specified + if len(self) > 0: + for key in self.headers: + row = self[0] + if key not in self.types: + self.types[key] = type(row[key]) + def 
clear(self, headers=None, delim="\t", nheaders=1, types=None): - """Clears the contents of the table""" - - self[:] = [] + """Clear the contents of the table.""" + + # clear table info self.headers = copy.copy(headers) if types is None: self.types = {} @@ -229,83 +234,67 @@ def clear(self, headers=None, delim="\t", nheaders=1, types=None): self.comments = [] self.delim = delim self.nheaders = nheaders - - + + # clear data + self[:] = [] + def new(self, headers=None): """ - return a new table with the same info but no data - - headers - if specified, only a subset of the headers will be copied + Return a new table with the same info but no data. + + headers: if specified, only a subset of the headers will be copied. """ - if headers is None: headers = self.headers - + tab = type(self)(headers=headers) - + tab.types = util.subdict(self.types, headers) tab.comments = copy.copy(self.comments) tab.delim = self.delim tab.nheaders = self.nheaders - + return tab - - + #=================================================================== # Input/Output - # - + def read(self, filename, delim="\t", nheaders=1, - headers=None, types=None, guess_types=True): - for row in self.read_iter(filename, delim=delim, nheaders=nheaders, - headers=headers, types=types, - guess_types=guess_types): - self.append(row) - - + headers=None, types=None, guess_types=True): + self.extend(self.read_iter( + filename, delim=delim, nheaders=nheaders, + headers=headers, types=types, + guess_types=guess_types)) + return self + def read_iter(self, filename, delim="\t", nheaders=1, - headers=None, types=None, guess_types=True): - """Reads a character delimited file and returns a list of dictionaries - - notes: - Lines that start with '#' are treated as comments and are skiped - Blank lines are skipped. - - If the first comment starts with '#Types:' the following tokens - are interpreted as the data type of the column and values in that - column are automatically converted. - - supported datatypes: - - string - - int - - float - - bool - - unknown (no conversion is done, left as a string) + headers=None, types=None, guess_types=True): + """ + Reads a character delimited file and yields a dict for each row. + Blank lines are skipped. Lines that start with a single '#' + are treated as comments. Lines starting with '##' are treated as + directives. 
""" - infile = util.open_stream(filename) - + # remember filename for later saving if isinstance(filename, str): self.filename = filename - # clear table self.clear(headers, delim, nheaders, types) - # temps for reading only - self.tmptypes = None + self._tmptypes = None + first_row = True - # line number for error reporting lineno = 0 - - + try: for line in infile: - line = line.rstrip() + line = line.rstrip('\n') lineno += 1 # skip blank lines @@ -320,7 +309,7 @@ def read_iter(self, filename, delim="\t", nheaders=1, # split row into tokens tokens = line.split(delim) - + # if no headers read yet, use this line as a header if not self.headers: # parse headers @@ -330,14 +319,15 @@ def read_iter(self, filename, delim="\t", nheaders=1, else: # default headers are numbers self.headers = range(len(tokens)) - assert len(tokens) == len(self.headers), tokens # populate types - if not self.types: - if self.tmptypes: - assert len(self.tmptypes) == len(self.headers) - self.types = dict(zip(self.headers, self.tmptypes)) + if first_row: + first_row = False + if self._tmptypes: + # use explicit types + assert len(self._tmptypes) == len(self.headers) + self.types = dict(zip(self.headers, self._tmptypes)) else: # default types if guess_types: @@ -347,37 +337,27 @@ def read_iter(self, filename, delim="\t", nheaders=1, else: for header in self.headers: self.types.setdefault(header, str) - + # parse data row = {} - for i in xrange(len(tokens)): - key = self.headers[i] - t = self.types[key] - if t is bool: - row[key] = str2bool(tokens[i]) - else: - row[key] = t(tokens[i]) - - # return completed row + for header, token in izip(self.headers, tokens): + type_object = self.types[header] + if type_object is bool: + type_object = str2bool + row[header] = type_object(token) + + # yield completed row yield row - - + except Exception, e: # report error in parsing input file raise TableException(str(e), self.filename, lineno) - #raise - - + # clear temps - del self.tmptypes - - raise StopIteration + del self._tmptypes - - def _parse_header(self, tokens): """Parse the tokens as headers""" - self.headers = tokens # check that headers are unique @@ -387,26 +367,26 @@ def _parse_header(self, tokens): raise TableException("Duplicate header '%s'" % header) check.add(header) - - - def write(self, filename=sys.stdout, delim="\t"): + def write(self, filename=sys.stdout, delim="\t", comments=False, + nheaders=None): """Write a table to a file or stream. - + If 'filename' is a string it will be opened as a file. If 'filename' is a stream it will be written to directly. """ - # remember filename for later saving if isinstance(filename, str): self.filename = filename - + out = util.open_stream(filename, "w") - - self.write_header(out, delim=delim) - + + self.write_header(out, delim=delim, comments=comments, + nheaders=(nheaders if nheaders is not None + else self.nheaders)) + # tmp variable types = self.types - + # write data for row in self: # code is inlined here for speed @@ -416,38 +396,38 @@ def write(self, filename=sys.stdout, delim="\t"): rowstr.append(types[header].__str__(row[header])) else: rowstr.append('') - print >>out, delim.join(rowstr) - - - def write_header(self, out=sys.stdout, delim="\t"): - # ensure all info is complete + out.write(delim.join(rowstr)) + out.write('\n') + + def write_header(self, out=sys.stdout, delim="\t", comments=False, + nheaders=None): + # ensure all info is complete. + # introspect types or use str by default. 
for key in self.headers: if key not in self.types: if len(self) > 0: self.types[key] = type(self[0][key]) else: self.types[key] = str - - + # ensure types are in directives if DIR_TYPES not in self.comments: - self.comments = [DIR_TYPES] + self.comments - + self.comments.insert(0, DIR_TYPES) # write comments - for line in self.comments: - if isinstance(line, str): - print >>out, line - else: - self._write_directive(line, out, delim) - - + if comments: + for line in self.comments: + if isinstance(line, str): + out.write(line) + out.write('\n') + else: + self._write_directive(line, out, delim) + # write header - if self.nheaders > 0: - print >>out, delim.join(self.headers) + if nheaders > 0: + out.write(delim.join(self.headers)) + out.write('\n') - - def write_row(self, out, row, delim="\t"): rowstr = [] types = self.types @@ -459,81 +439,64 @@ def write_row(self, out, row, delim="\t"): out.write(delim.join(rowstr)) out.write("\n") - # NOTE: back-compat - writeRow = write_row - - def save(self): - """Writes the table to the last used filename for the read() or write() - function""" - + """ + Writes the table to the last used filename. + """ if self.filename is not None: self.write(self.filename) else: raise Exception("Table has no filename") - - + #=================================================================== # Input/Output: Directives - # - + def _determine_directive(self, line): - if line.startswith("#Types:") or \ - line.startswith("#types:") or \ - line.startswith("##types:"): - # backwards compatible + if line.startswith("##types:"): return DIR_TYPES - else: return None - - - + def _read_directive(self, line): """Attempt to read a line with a directive""" - + directive = self._determine_directive(line) - if directive is None: return False - - rest = line[line.index(":")+1:] + + rest = line[line.index(":")+1:] self.comments.append(directive) - + if directive == DIR_TYPES: - self.tmptypes = map(parse_type, rest.rstrip().split(self.delim)) + self._tmptypes = map( + parse_type, rest.rstrip('\n').split(self.delim)) return True - else: return False - - + def _write_directive(self, line, out, delim): """Write a directive""" - + if line == DIR_TYPES: out.write("##types:" + delim.join(format_type(self.types[h]) for h in self.headers) + "\n") - + else: raise "unknown directive:", line - #=================================================================== # Table manipulation - # - + def add(self, **kargs): """Add a row to the table - + tab.add(col1=val1, col2=val2, col3=val3) """ self.append(kargs) - - + def add_col(self, header, coltype=None, default=NULL, pos=None, data=None): """Add a column to the table. You must populate column data yourself. 
- + header - name of the column coltype - type of the values in that column default - default value of the column @@ -542,72 +505,67 @@ def add_col(self, header, coltype=None, default=NULL, pos=None, data=None): # ensure header is unique if header in self.headers: raise Exception("header '%s' is already in table" % header) - + # default column position is last column if pos is None: pos = len(self.headers) - + # default coltype is guessed from data if coltype is None: if data is None: raise Exception("must specify data or coltype") else: coltype = type(data[0]) - + # default value is inferred from column type if default is NULL: default = coltype() - + # update table info self.headers.insert(pos, header) self.types[header] = coltype - + # add data if data is not None: for i in xrange(len(self)): self[i][header] = data[i] - def remove_col(self, *cols): """Removes a column from the table""" - + for col in cols: self.headers.remove(col) del self.types[col] - + for row in self: del row[col] - - + def rename_col(self, oldname, newname): """Renames a column""" - + # change header col = self.headers.index(oldname) - if col == -1: raise Exception("column '%s' is not in table" % oldname) - + self.headers[col] = newname - + # change info self.types[newname] = self.types[oldname] del self.types[oldname] - + # change data for row in self: row[newname] = row[oldname] del row[oldname] - def get_matrix(self, rowheader="rlabels"): """Returns mat, rlabels, clabels - + where mat is a copy of the table as a 2D list rlabels are the row labels clabels are the column labels """ - # get labels if rowheader is not None and rowheader in self.headers: rlabels = self.cget(rowheader) @@ -621,20 +579,32 @@ def get_matrix(self, rowheader="rlabels"): mat = [] for row in self: mat.append(util.mget(row, clabels)) - + return mat, rlabels, clabels - - + + def as_lists(self, cols=None): + """Iterate over rows as lists""" + if cols is None: + cols = self.headers + for row in self: + yield [row[header] for header in cols] + + def as_tuples(self, cols=None): + """Iterate over rows as lists""" + if cols is None: + cols = self.headers + for row in self: + yield tuple(row[header] for header in cols) + def filter(self, cond): """Returns a table with a subset of rows such that cond(row) == True""" tab = self.new() - + for row in self: if cond(row): tab.append(row) - - return tab + return tab def map(self, func, headers=None): """Returns a new table with each row mapped by function 'func'""" @@ -651,23 +621,21 @@ def map(self, func, headers=None): # try order new headers the same way as old headers headers = first_row.keys() lookup = util.list2lookup(self.headers) - top = len(headers) + top = len(headers) headers.sort(key=lambda x: (lookup.get(x, top), x)) - + tab = type(self)( - itertools.chain([first_row], (func(x) for x in self[1:])), + chain([first_row], (func(x) for x in self[1:])), headers=headers) tab.delim = self.delim tab.nheaders = self.nheaders - - return tab + return tab def uniq(self, key=None, col=None): """ Returns a copy of this table with consecutive repeated rows removed """ - tab = self.new() if len(self) == 0: @@ -689,68 +657,61 @@ def uniq(self, key=None, col=None): if key_row != last_row: tab.append(row) last_row = key_row - return tab - - + def groupby(self, key=None): - """Groups the row of the table into separate tables based on the + """Groups the row of the table into separate tables based on the function key(row). 
Returns a dict where the keys are the values retruned from key(row) and the values are tables. - + Ex: tab = Table([{'name': 'matt', 'major': 'CS'}, {'name': 'mike', 'major': 'CS'}, {'name': 'alex', 'major': 'bio'}]) lookup = tab.groupby(lambda x: x['major']) - + lookup ==> {'CS': Table([{'name': 'matt', 'major': 'CS'}, {'name': 'mike', 'major': 'CS'}]), 'bio': Table([{'name': 'alex', 'major': 'bio'}])} - + Can also use a column name such as: tab.groupby('major') - + """ - - groups = {} - + if isinstance(key, str): keystr = key key = lambda x: x[keystr] - + if key is None: raise Exception("must specify keyfunc") - - + for row in self: key2 = key(row) - + # add new table if necessary if key2 not in groups: groups[key2] = self.new() - + groups[key2].append(row) - + return groups - - + def lookup(self, *keys, **options): """Returns a lookup dict based on a column 'key' or multiple keys - + extra options: default=None uselast=False # allow multiple rows, just use last """ - options.setdefault("default", None) options.setdefault("uselast", False) lookup = util.Dict(dim=len(keys), default=options["default"]) uselast = options["uselast"] - + for row in self: keys2 = util.mget(row, keys) ptr = lookup @@ -759,51 +720,49 @@ def lookup(self, *keys, **options): if not uselast and keys2[-1] in ptr: raise Exception("duplicate key '%s'" % str(keys2[-1])) ptr[keys2[-1]] = row - + lookup.insert = False return lookup - - + def get(self, rows=None, cols=None): """Returns a table with a subset of the rows and columns""" - + # determine rows and cols if rows is None: rows = range(len(self)) - + if cols is None: cols = self.headers - + tab = self.new(cols) - - # copy data + + # copy data for i in rows: - row = {} + row = self[i] + row2 = {} for j in cols: - row[j] = self[i][j] - tab.append(row) - + row2[j] = row[j] + tab.append(row2) + return tab - - + def cget(self, *cols): """Returns columns of the table as separate lists""" - + ret = [] - + for col in cols: newcol = [] ret.append(newcol) - + for row in self: newcol.append(row[col]) - + if len(ret) == 1: return ret[0] - else: + else: return ret - def get_row(self, *rows): """Returns row(s) as list(s)""" @@ -816,23 +775,18 @@ def get_row(self, *rows): # return multiple rows (or zero) return [[self[i][j] for j in self.headers] for i in rows] - - - - - + def sort(self, cmp=None, key=None, reverse=False, col=None): """Sorts the table inplace""" - + if col is not None: key = lambda row: row[col] elif cmp is None and key is None: # sort by first column key = lambda row: row[self.headers[0]] - + list.sort(self, cmp=cmp, key=key, reverse=reverse) - - + def __getitem__(self, key): if isinstance(key, slice): # return another table if key is a slice @@ -841,155 +795,151 @@ def __getitem__(self, key): return tab else: return list.__getitem__(self, key) - - + def __getslice__(self, a, b): # for python version compatibility return self.__getitem__(slice(a, b)) - def __repr__(self): - s = StringIO.StringIO("w") + s = StringIO() self.write_pretty(s) return s.getvalue() - - + def write_pretty(self, out=sys.stdout, spacing=2): mat2, rlabels, clabels = self.get_matrix(rowheader=None) mat = [] - + # get headers mat.append(clabels) - + # get data mat.extend(mat2) - + util.printcols(mat, spacing=spacing, out=out) - def __str__(self): - s = StringIO.StringIO("w") + s = StringIO() self.write(s) return s.getvalue() - - - #=========================================================================== -# convenience functions -# +# Convenience functions def read_table(filename, 
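A short sketch of the row and column helpers above (add, add_col, filter, groupby, lookup), with invented data and the rasmus.tablelib import path assumed:

    from rasmus.tablelib import Table

    tab = Table([{'name': 'matt', 'major': 'CS'},
                 {'name': 'alex', 'major': 'bio'}])
    tab.add(name='mike', major='CS')            # append a row by keyword
    tab.add_col('year', int, data=[1, 2, 3])    # new column with explicit data

    cs_only = tab.filter(lambda row: row['major'] == 'CS')
    by_major = tab.groupby('major')             # {'CS': Table(...), 'bio': Table(...)}
    alex = tab.lookup('name')['alex']           # index rows by a unique column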
delim="\t", headers=None, nheaders=1, types=None, guess_types=True): """Read a Table from a file written in PTF""" - + table = Table() table.read(filename, delim=delim, headers=headers, nheaders=nheaders, types=types, guess_types=guess_types) return table -# NOTE: back-compat -readTable = read_table - -def iter_table(filename, delim="\t", nheaders=1): - """Iterate through the rows of a Table from a file written in PTF""" - +def iter_table(filename, delim="\t", nheaders=1, types=None, guess_types=True): + """Iterate through the rows of a Table from a file.""" table = Table() - return table.read_iter(filename, delim=delim, nheaders=nheaders) + return table.read_iter(filename, delim=delim, nheaders=nheaders, + types=types, guess_types=guess_types) -# NOTE: back-compat -iterTable = iter_table +def histtab(items, headers=None, item="item", count="count", percent="percent", + cols=None): + """Make a histogram table.""" + if cols is not None: + # items is a Table. + items = items.as_tuples(cols=cols) + if headers is None: + headers = cols + [count, percent] + + if headers is None: + headers = [item, count, percent] -def histtab(items, headers=["item", "count", "percent"]): h = util.hist_dict(items) tab = Table(headers=headers) tot = float(sum(h.itervalues())) + hist_items = h.items() - if len(headers) == 2: - for key, val in h.items(): - tab.append({headers[0]: key, - headers[1]: val}) - - elif len(headers) == 3: - for key, val in h.items(): - tab.append({headers[0]: key, - headers[1]: val, - headers[2]: val / tot}) - + if cols is not None: + for key, val in hist_items: + row = dict(zip(cols, key)) + row[count] = val + tab.append(row) else: - raise Exception("Wrong number of headers (2 or 3 only)") - - tab.sort(col=headers[1], reverse=True) - + for key, val in hist_items: + tab.append({item: key, + count: val}) + + if percent is not None: + for i, (key, val) in enumerate(hist_items): + tab[i][percent] = val / tot + + tab.sort(col=count, reverse=True) + return tab -def join_tables(* args, **kwargs): +def join_tables(*args, **kwargs): """Join together tables into one table. 
Each argument is a tuple (table_i, key_i, cols_i) - - key_i is either a column name or a function that maps a + + key_i is either a column name or a function that maps a table row to a unique key """ - + if len(args) == 0: return Table() - + # determine common keys tab, key, cols = args[0] if isinstance(key, str): keys = tab.cget(key) - lookups = [tab.lookup(key)] + lookups = [tab.lookup(key)] else: keys = map(key, tab) lookup = {} for row in tab: lookup[key(row)] = row lookups = [lookup] - + keyset = set(keys) - for tab, key, cols in args[1:]: if isinstance(key, str): keyset = keyset & set(tab.cget(key)) - lookups.append(tab.lookup(key)) + lookups.append(tab.lookup(key)) else: keyset = keyset & set(map(key, tab)) lookup = {} for row in tab: lookup[key(row)] = row - + lookups.append(lookup) - + keys = filter(lambda x: x in keyset, keys) - - + # build new table if "headers" not in kwargs: headers = util.concat(*util.cget(args, 2)) else: headers = kwargs["headers"] tab = Table(headers=headers) - + for key in keys: row = {} for (tab2, key2, cols), lookup in zip(args, lookups): row.update(util.subdict(lookup[key], cols)) tab.append(row) - + return tab def showtab(tab, name='table'): """Show a table in a new xterm""" - + name = name.replace("'", "") tmp = util.tempfile(".", "tmp", ".tab") tab.write_pretty(file(tmp, "w")) @@ -999,13 +949,6 @@ def showtab(tab, name='table'): def sqlget(dbfile, query, maxrows=None, headers=None, headernum=False): """Get a table from a sqlite file""" - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError: - try: - from sqlite3 import dbapi2 as sqlite - except ImportError: - import sqlite # open database if hasattr(dbfile, "cursor"): @@ -1016,13 +959,13 @@ def sqlget(dbfile, query, maxrows=None, headers=None, headernum=False): con = sqlite.connect(dbfile, isolation_level="DEFERRED") cur = con.cursor() auto_close = True - + cur.execute(query) # infer header names if headers is None and not headernum: headers = [x[0] for x in cur.description] - + if maxrows is not None: lst = [] try: @@ -1041,14 +984,6 @@ def sqlget(dbfile, query, maxrows=None, headers=None, headernum=False): def sqlexe(dbfile, sql): - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError: - try: - from sqlite3 import dbapi2 as sqlite - except ImportError: - import sqlite - # open database if hasattr(dbfile, "cursor"): con = dbfile @@ -1071,7 +1006,7 @@ def sql_create_table(cur, table_name, tab, overwrite=True): def issubclass2(t1, t2): if type(t1) != type: return False - return issubclass(t1, t2) + return issubclass(t1, t2) # drop old table if needed if overwrite: @@ -1080,10 +1015,8 @@ def issubclass2(t1, t2): # build columns cols = [] for header in tab.headers: - t = tab.types[header] - if issubclass2(t, basestring): cols.append("%s TEXT" % header) elif issubclass2(t, int): @@ -1099,24 +1032,12 @@ def issubclass2(t1, t2): cols = ",".join(cols) # create table - cur.execute("""CREATE TABLE %s (%s);""" % - (table_name, cols)) + cur.execute("""CREATE TABLE %s (%s);""" % (table_name, cols)) - - -#def sql_insert_rows(cur, headers, types, rows def sqlput(dbfile, table_name, tab, overwrite=True, create=True): """Insert a table into a sqlite file""" - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError: - try: - from sqlite3 import dbapi2 as sqlite - except ImportError: - import sqlite - # open database if hasattr(dbfile, "cursor"): con = dbfile @@ -1132,18 +1053,17 @@ def sqlput(dbfile, table_name, tab, overwrite=True, create=True): filename = tab tab = Table() it 
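A sketch of the call convention described above, with invented tables; rows are matched on the shared key values, and the rasmus.tablelib import path is assumed:

    from rasmus.tablelib import Table, join_tables

    tab1 = Table([{'id': 0, 'x': 1}, {'id': 1, 'x': 3}])
    tab2 = Table([{'id': 0, 'y': 6}, {'id': 1, 'y': 7}])

    joined = join_tables((tab1, 'id', ['id', 'x']),
                         (tab2, 'id', ['y']))
    # headers default to the concatenation of the requested column
    # lists: ['id', 'x', 'y']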
= tab.read_iter(filename) - + try: # force a reading of the headers row = it.next() - rows = itertools.chain([row], it) + rows = chain([row], it) except StopIteration: rows = [] pass else: rows = tab - if create: sql_create_table(cur, table_name, tab, overwrite=overwrite) @@ -1151,20 +1071,18 @@ def sqlput(dbfile, table_name, tab, overwrite=True, create=True): def issubclass2(t1, t2): if type(t1) != type: return False - return issubclass(t1, t2) - + return issubclass(t1, t2) text = set() for header in tab.headers: t = tab.types[header] - + if issubclass2(t, basestring) or not ( - issubclass2(t, int) or - issubclass2(t, float) or - issubclass2(t, bool)): + issubclass2(t, int) or + issubclass2(t, float) or + issubclass2(t, bool)): text.add(header) - # insert rows for row in rows: vals = [] @@ -1175,52 +1093,49 @@ def issubclass2(t1, t2): vals.append(tab.types[header].__str__(row[header])) vals = ",".join(vals) cur.execute("INSERT INTO %s VALUES (%s);" % (table_name, vals)) - + con.commit() if auto_close: con.close() - + #=========================================================================== # Matrix functions -# def matrix2table(mat, rlabels=None, clabels=None, rowheader="rlabels"): """ convert a matrix into a table - + use table.get_matrix() to convert back to a matrix - """ - if clabels is None: clabels = range(len(mat[0])) nheaders = 0 else: nheaders = 1 - + if rlabels is None: tab = Table(headers=clabels) else: tab = Table(headers=[rowheader] + clabels) tab.nheaders = nheaders - - + for i, row in enumerate(mat): if rlabels is not None: row2 = {rowheader: rlabels[i]} else: row2 = {} - + for j in xrange(len(mat[i])): row2[clabels[j]] = mat[i][j] - + tab.append(row2) - + return tab -def write_matrix(filename, mat, rlabels=None, clabels=None, rowheader="rlabels"): +def write_matrix(filename, mat, rlabels=None, clabels=None, + rowheader="rlabels"): tab = matrix2table(mat, rlabels=rlabels, clabels=clabels, @@ -1228,238 +1143,7 @@ def write_matrix(filename, mat, rlabels=None, clabels=None, rowheader="rlabels") tab.write(filename) - def read_matrix(filename, rowheader="rlabels"): - tab = read_table(filename) + tab = read_table(filename) mat, rlabels, clabels = tab.get_matrix(rowheader=rowheader) return mat, rlabels, clabels - - -#=========================================================================== -# testing -# - - -if __name__ == "__main__": - import StringIO - - - - ################################################# - text="""\ -##types:str int int -# -# hello -# -name 0 1 -matt 123 3 -alex 456 2 -mike 789 1 -""" - - tab = read_table(StringIO.StringIO(text), nheaders=0) - - print tab - print tab[0][1] - - - tab.add_col('extra', bool, False) - for row in tab: - row['extra'] = True - - - - ################################################# - text="""\ -##types:str int int -name num num2 -matt 123 3 -alex 456 2 -mike 789 1 -""" - - tab = read_table(StringIO.StringIO(text)) - tab.sort() - - print repr(tab) - print tab - print tab.cget('name', 'num') - - - ################################################# - # guess types - text="""\ -name num num2 status -matt 11123 3.0 false -alex 456 2.0 true -mike 789 1.0 false -""" - - tab = read_table(StringIO.StringIO(text)) - tab.sort() - - print repr(tab) - - - -''' - ################################################# - # catch parse error - if 0: - text="""\ -##types:str int int -name num num -matt 123 0 -alex 456 2 -mike 789 1 -""" - - tab = readTable(StringIO.StringIO(text)) - tab.sort() - - print repr(tab) - print tab - print tab.cget('name', 
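A sketch of the matrix helpers, with invented labels; get_matrix reverses matrix2table:

    from rasmus.tablelib import matrix2table

    mat = [[1, 2], [3, 4]]
    tab = matrix2table(mat, rlabels=['r1', 'r2'], clabels=['a', 'b'])
    mat2, rlabels, clabels = tab.get_matrix()   # back to a 2D list plus labels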
'num') - - - ################################################# - # timing - if 0: - from rasmus import util - - text=["##types:" + "int\t" * 99 + "int", - "\t".join(map(str, range(100))) ] - - for i in range(10000): - text.append("1\t" * 99 + "1") - text = "\n".join(text) - - stream = StringIO.StringIO(text) - - util.tic("read table") - tab = readTable(stream) - util.toc() - - - ################################################# - # specialized types - if 1: - text="""\ -##types:str int strand_type -name num strand -matt 123 + -alex 456 - -mike 789 + -john 0 + -""" - - - - - class strand_type: - def __init__(self, text=None): - if text is None: - self.val = True - else: - if text == "+": - self.val = True - elif text == "-": - self.val = False - else: - raise Exception("cannot parse '%s' as strand_type" % - str(text)) - - - def __str__(self): - if self.val: - return "+" - else: - return "-" - - - def strand_parser(text=None): - if text is None: - return True - else: - if text == "+": - return True - elif text == "-": - return False - else: - raise Exception("cannot parse '%s' as strand_type" % - str(text)) - - def strand_formatter(val): - if val: - return "+" - else: - return "-" - - strand_type = TableType(strand_parser, strand_formatter) - - - stream = StringIO.StringIO(text) - tab = readTable(stream, type_lookup=[["strand_type", strand_type]]) - print tab.types - print tab - - ################################################# - # quoted strings - if 1: - text=\ -r"""##types:str bool quoted_string -name num blah -matt True hello\tthere -alex False hello\nthere -mike True hello\\there -john False hello\n\\\nthere -""" - - stream = StringIO.StringIO(text) - tab = readTable(stream) - print tab.types - print tab - - - ################################################# - # python data structures/code - if 1: - def eval2(text=None): - if text is None: - return None - else: - return eval(text) - - python_type = TableType(eval2, str) - - - - tab = Table(headers=["name", "list"], - types={"list": python_type}, - type_lookup=[["python", python_type]]) - - - tab.append({"name": "matt", "list": [1,2,3]}) - tab.append({"name": "mike", "list": [4,5,6]}) - tab.append({"name": "alex", "list": [7,8,9]}) - - tab.write() - - ################################################## - # join tables - if 1: - tab1 = Table([[0, 1, 2], - [1, 3, 4], - [2, 5, 6], - [3, 7, 8]], - headers=['a', 'b', 'c']) - tab2 = Table([[0, 6, 6], - [1, 7, 7], - [3, 8, 8]], - headers=['a2', 'b2', 'c2']) - - tab3 = joinTables((tab1, lambda x: x['a']+1, ['c', 'b']), (tab2, 'a2', ['b2'])) - - print tab3 - -''' diff --git a/argweaver/deps/rasmus/testing.py b/argweaver/deps/rasmus/testing.py index 166a7e63..3ca3b547 100644 --- a/argweaver/deps/rasmus/testing.py +++ b/argweaver/deps/rasmus/testing.py @@ -1,7 +1,11 @@ -import sys, os, shutil, unittest import optparse +import os +import shutil +import sys +import unittest from itertools import izip + from . import util from . 
import stats @@ -13,6 +17,7 @@ def clean_dir(path): if os.path.exists(path): shutil.rmtree(path) + def makedirs(path): if not os.path.exists(path): os.makedirs(path) @@ -24,8 +29,6 @@ def make_clean_dir(path): os.makedirs(path) - - def fequal(f1, f2, rel=.0001, eabs=1e-12): """assert whether two floats are approximately equal""" @@ -58,7 +61,7 @@ def integrate(func, a, b, step): def eq_sample_pdf(samples, pdf, ndivs=20, start=-util.INF, end=util.INF, pval=.05, step=None): - """Returns true if a sample matches a distribution""" + """Asserts a sample matches a probability density distribution""" if step is None: step = (max(samples) - min(samples)) / float(ndivs) @@ -71,13 +74,34 @@ def eq_sample_pdf(samples, pdf, assert p >= pval, p +def eq_sample_pmf(samples, pmf, pval=.05): + """Asserts a sample matches a probability mass distribution""" + import scipy.stats + + hist = util.hist_dict(samples) + total = sum(hist.itervalues()) + observed = [] + expected = [] + for sample, count in hist.iteritems(): + if count >= 5: + observed.append(count) + expected.append(pmf(sample) * total) + + chi2, p = scipy.stats.chisquare( + scipy.array(observed), scipy.array(expected)) + assert p >= pval, p + + _do_pause = True + + def pause(text="press enter to continue: "): """Pause until the user presses enter""" if _do_pause: sys.stderr.write(text) raw_input() + def set_pausing(enabled=True): global _do_pause _do_pause = enabled @@ -115,7 +139,6 @@ def test_main(): conf, args = o.parse_args() - if conf.list_tests: list_tests(1) return @@ -125,7 +148,6 @@ def test_main(): else: set_pausing(False) - # process unittest arguments argv = [sys.argv[0]] diff --git a/argweaver/deps/rasmus/treelib.py b/argweaver/deps/rasmus/treelib.py index 0ba9230b..28d28621 100644 --- a/argweaver/deps/rasmus/treelib.py +++ b/argweaver/deps/rasmus/treelib.py @@ -1,7 +1,7 @@ # -# Tree data structures +# Tree data structures # -# Contains special features for representing phylogeny. +# Contains special features for representing phylogeny. # See compbio.phylo for more. 
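A sketch of the testing helpers above, assuming the bundled package is importable as rasmus and that scipy is installed (eq_sample_pmf calls scipy.stats.chisquare). With the default pval=.05, the assertion is expected to fail roughly 5% of the time even when the pmf is correct:

    import random
    from rasmus import testing

    testing.fequal(0.1 + 0.2, 0.3)     # approximate float equality

    # 600 fair die rolls checked against the uniform pmf
    samples = [random.randint(1, 6) for _ in xrange(600)]
    testing.eq_sample_pmf(samples, lambda x: 1 / 6.0)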
# # @@ -9,15 +9,13 @@ # python libs import copy -import math -import random import sys -import os import StringIO # rasmus libs try: from rasmus import util + util except ImportError: import util try: @@ -29,9 +27,11 @@ # ply parsing support try: from rasmus import treelib_parser + treelib_parser except ImportError: try: import treelib_parser + treelib_parser except ImportError: treelib_parser = None @@ -43,7 +43,7 @@ class TreeNode (object): """A class for nodes in a rooted Tree - + Contains fields for branch length 'dist' and custom data 'data' """ @@ -54,36 +54,34 @@ def __init__(self, name=None): self.dist = 0 self.data = {} - def __iter__(self): """Iterate through child nodes""" return iter(self.children) - - + def copy(self, parent=None, copyChildren=True): """Returns a copy of a TreeNode and all of its children""" - + node = TreeNode(self.name) node.name = self.name node.dist = self.dist node.parent = parent node.data = copy.copy(self.data) - + if copyChildren: for child in self.children: node.children.append(child.copy(node)) - + return node - + def is_leaf(self): """Returns True if the node is a leaf (no children)""" return len(self.children) == 0 - + def recurse(self, func, *args): """Applies a function 'func' to the children of a node""" for child in self.children: func(child, *args) - + def leaves(self): """Returns the leaves beneath the node in traversal order""" leaves = [] @@ -94,13 +92,13 @@ def walk(node): for child in node.children: walk(child) walk(self) - + return leaves - + def leaf_names(self): """Returns the leaf names beneath the node in traversal order""" return [x.name for x in self.leaves()] - + def write_data(self, out): """Writes the data of the node to the file stream 'out'""" out.write(str(self.dist)) @@ -111,12 +109,11 @@ def __repr__(self): return "" % self.name - class BranchData (object): """A class for managing branch specific data for a Tree - - By default, this class implements bootstrap data for TreeNode's. - + + By default, this class implements bootstrap data for TreeNode's. + To incorporate new kinds of branch data, do the following. Subclass this class (say, MyBranchData). Create Tree's with Tree(branch_data=MyBranchData()). 
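Following the recipe in the BranchData docstring, a hypothetical subclass carrying a per-branch "support" value might look like the sketch below; the SupportData name and the "support" key are invented, and a complete implementation would also override split_branch_data and merge_branch_data:

    from rasmus import treelib

    class SupportData (treelib.BranchData):
        def get_branch_data(self, node):
            # copy the branch-specific value out of node.data
            if "support" in node.data:
                return {"support": node.data["support"]}
            return {}

        def set_branch_data(self, node, data):
            # write the branch-specific value back into node.data
            if "support" in data:
                node.data["support"] = data["support"]

    tree = treelib.Tree(branch_data=SupportData())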
This will ensure your new branch data @@ -133,19 +130,19 @@ def get_branch_data(self, node): return {"boot": node.data["boot"]} else: return {} - + def set_branch_data(self, node, data): """Set the branch specific data from 'data' to node.data""" if "boot" in data: node.data["boot"] = data["boot"] - + def split_branch_data(self, node): """Split a branch's data into two copies""" if "boot" in node.data: return {"boot": node.data["boot"]}, {"boot": node.data["boot"]} else: return {}, {} - + def merge_branch_data(self, data1, data2): """Merges the branch data from two neighboring branches into one""" if "boot" in data1 and "boot" in data2: @@ -153,13 +150,12 @@ def merge_branch_data(self, data1, data2): return {"boot": data1["boot"]} else: return {} - - + class Tree (object): """ Basic rooted tree - + Well suited for phylogenetic trees """ @@ -171,62 +167,55 @@ def __init__(self, nextname=1, branch_data=BranchData()): self.data = {} self.branch_data = branch_data - def copy(self): """Returns a copy of the tree""" - tree = Tree(nextname = self.nextname) - + tree = Tree(nextname=self.nextname) + # copy structure - if self.root != None: + if self.root is not None: # copy all nodes tree.root = self.root.copy() - + # set all names def walk(node): tree.nodes[node.name] = node for child in node.children: walk(child) walk(tree.root) - + # copy extra data tree.copy_data(self) tree.copy_node_data(self) - - return tree + return tree #========================================= # iterators - + def __iter__(self): """Iterate through nodes of tree""" return self.nodes.itervalues() - def __len__(self): """Returns number of nodes in tree""" return len(self.nodes) - def __getitem__(self, key): """Returns node by name""" return self.nodes[key] - def __setitem__(self, key, node): """Adds a node to the tree""" node.name = key self.add(node) - def __contains__(self, name): """Returns True if tree has node with name 'name'""" return name in self.nodes - def preorder(self, node=None, is_leaf=lambda x: x.is_leaf()): """Iterate through nodes in pre-order traversal""" - + if node is None: node = self.root @@ -240,10 +229,9 @@ def preorder(self, node=None, is_leaf=lambda x: x.is_leaf()): for child in reversed(node.children): queue.append(child) - def postorder(self, node=None, is_leaf=lambda x: x.is_leaf()): """Iterate through nodes in post-order traversal""" - + if node is None: node = self.root @@ -259,7 +247,6 @@ def postorder(self, node=None, is_leaf=lambda x: x.is_leaf()): yield node stack.pop() - def inorder(self, node=None, is_leaf=lambda x: x.is_leaf()): """Iterate through nodes with in-order traversal""" @@ -267,7 +254,7 @@ def inorder(self, node=None, is_leaf=lambda x: x.is_leaf()): node = self.root stack = [[node, 0]] - + while len(stack) > 0: node, i = stack[-1] @@ -282,25 +269,23 @@ def inorder(self, node=None, is_leaf=lambda x: x.is_leaf()): # left has been visited # yield current node then visit right yield node - + # recurse into children stack.append([node.children[i], 0]) stack[-2][1] += 1 else: stack.pop() - - + #============================= # structure functions - def make_root(self, name = None): + def make_root(self, name=None): """Create a new root node""" if name is None: name = self.new_name() self.root = TreeNode(name) return self.add(self.root) - def add(self, node): """Add a node to the tree Does not add node to any specific location (use add_child instead). 
@@ -308,7 +293,6 @@ def add(self, node): self.nodes[node.name] = node return node - def add_child(self, parent, child): """Add a child node to an existing node 'parent' in the tree""" assert parent != child @@ -318,57 +302,51 @@ def add_child(self, parent, child): parent.children.append(child) return child - def new_node(self, name=None): """Add a new node with name 'name' to the tree""" if name is None: name = self.new_name() return self.add(TreeNode(name)) - def remove(self, node): """ Removes a node from a tree. Notifies parent (if it exists) that node has been removed. """ - + if node.parent: node.parent.children.remove(node) del self.nodes[node.name] - - + def remove_tree(self, node): """ Removes subtree rooted at 'node' from tree. Notifies parent (if it exists) that node has been removed. """ - + def walk(node): if node.name in self.nodes: del self.nodes[node.name] for child in node.children: walk(child) walk(node) - + if node.parent: node.parent.children.remove(node) - - + def rename(self, oldname, newname): """Rename a node in the tree""" node = self.nodes[oldname] del self.nodes[oldname] self.nodes[newname] = node node.name = newname - - + def new_name(self): """Returns a new node name that should be unique in the tree""" name = self.nextname self.nextname += 1 return name - def unique_name(self, name, names, sep="_"): """Create a new unique name not already in names""" i = 1 @@ -379,19 +357,17 @@ def unique_name(self, name, names, sep="_"): names.add(name2) return name2 - def add_tree(self, parent, childTree): """Add a subtree to the tree.""" - + # Merge nodes and change the names of childTree names if they conflict # with existing names - self.merge_names(childTree) + self.merge_names(childTree) self.add_child(parent, childTree.root) - - + def replace_tree(self, node, childTree): """Remove node and replace it with the root of childTree""" - + # merge nodes and change the names of childTree names if they conflict # with existing names self.merge_names(childTree) @@ -401,12 +377,11 @@ def replace_tree(self, node, childTree): parent.children[index] = childTree.root childTree.root.parent = parent del self.nodes[node.name] - - + def merge_names(self, tree2): """Merge the node names from tree2 into this tree. 
Change any names that conflict""" - + for name in tree2.nodes: if name in self.nodes: name2 = self.new_name() @@ -418,14 +393,12 @@ def merge_names(self, tree2): if name >= self.nextname: self.nextname = name + 1 self.nodes[name] = tree2.nodes[name] - - + def clear(self): """Clear all nodes from tree""" self.nodes = {} self.root = None - - + def leaves(self, node=None): """Return the leaves of the tree in order""" if node is None: @@ -433,28 +406,24 @@ def leaves(self, node=None): if node is None: return [] return node.leaves() - - - def leaf_names(self, node = None): + + def leaf_names(self, node=None): """Returns the leaf names of the tree in order""" return map(lambda x: x.name, self.leaves(node)) - - + #=============================== # data functions def has_data(self, dataname): """Does the tree contain 'dataname' in its extra data""" return dataname in self.default_data - - + def copy_data(self, tree): """Copy tree data to another""" self.branch_data = tree.branch_data self.default_data = copy.copy(tree.default_data) self.data = copy.copy(tree.data) - - + def copy_node_data(self, tree): """Copy node data to another tree""" for name, node in self.nodes.iteritems(): @@ -462,14 +431,12 @@ def copy_node_data(self, tree): node.data = copy.copy(tree.nodes[name].data) self.set_default_data() - def set_default_data(self): """Set default values in each node's data""" for node in self.nodes.itervalues(): for key, val in self.default_data.iteritems(): node.data.setdefault(key, val) - - + def clear_data(self, *keys): """Clear tree data""" for node in self.nodes.itervalues(): @@ -479,12 +446,11 @@ def clear_data(self, *keys): for key in keys: if key in node.data: del node.data[key] - - + #====================================================================== # branch data functions - # forward branch data calles to branch data manager - + # forward branch data calles to branch data manager + def get_branch_data(self, node): """Returns branch specific data from a node""" return self.branch_data.get_branch_data(node) @@ -492,30 +458,28 @@ def get_branch_data(self, node): def set_branch_data(self, node, data): """Set the branch specific data from 'data' to node.data""" return self.branch_data.set_branch_data(node, data) - + def split_branch_data(self, node): """Split a branch's data into two copies""" return self.branch_data.split_branch_data(node) - + def merge_branch_data(self, data1, data2): - """Merges the branch data from two neighboring branches into one""" + """Merges the branch data from two neighboring branches into one""" return self.branch_data.merge_branch_data(data1, data2) - - + #======================================================================= # input and output - # - + def read_data(self, node, data): """Default data reader: reads optional bootstrap and branch length""" # also parse nhx comments data = read_nhx_data(node, data) - + if ":" in data: boot, dist = data.split(":") node.dist = float(dist) - + if len(boot) > 0: if boot.isdigit(): node.data["boot"] = int(boot) @@ -533,11 +497,10 @@ def read_data(self, node, data): # treat as name if data: node.name = data - - + def write_data(self, node): """Default data writer: writes optional bootstrap and branch length""" - + string = "" if "boot" in node.data and \ not node.is_leaf() and \ @@ -553,33 +516,29 @@ def write_data(self, node): string += ":%f" % node.dist return string - - + def read_newick(self, filename, readData=None): """ Reads newick tree format from a file stream - + You can specify a specialized node data reader 
with 'readData' """ return read_tree(filename, read_data=readData, tree=self) - def write(self, out=sys.stdout, writeData=None, oneline=False, rootData=False): """Write the tree in newick notation""" - self.write_newick(out, writeData=writeData, + self.write_newick(out, writeData=writeData, oneline=oneline, rootData=rootData) - - + def write_newick(self, out=sys.stdout, writeData=None, oneline=False, - rootData=False): + rootData=False): """Write the tree in newick notation""" - write_newick(self, util.open_stream(out, "w"), + write_newick(self, util.open_stream(out, "w"), writeData=writeData, oneline=oneline, rootData=rootData) - - + def get_one_line_newick(self, root_data=False, writeData=None): """Get a presentation of the tree in a oneline string newick format""" stream = StringIO.StringIO() @@ -588,7 +547,6 @@ def get_one_line_newick(self, root_data=False, writeData=None): return stream.getvalue() - #============================================================================ # Input/Output functions @@ -606,14 +564,14 @@ def read_newick(infile, read_data=None, tree=None): def iter_trees(treefile): """read multiple trees from a tree file""" - + infile = util.open_stream(treefile) yield read_tree(infile) try: while True: yield read_tree(infile) - except Exception, e: + except Exception: pass @@ -632,8 +590,7 @@ def iter_stream(infile): infile = iter_stream(infile) else: infile = iter(infile) - - running = True + word = [] for c in infile: if c == "": @@ -645,7 +602,7 @@ def iter_stream(infile): if word: yield "".join(word) word[:] = [] - + elif c in ";(),:[]": # special tokens if word: @@ -666,7 +623,7 @@ def iter_stream(infile): else: # word token word.append(c) - + if word: yield "".join(word) word[:] = [] @@ -680,7 +637,7 @@ def parse_newick(infile, read_data=None, tree=None): read_data -- an optional function for reading node data fields tree -- an optional tree to populate """ - + # node stack ancestors = [] @@ -706,7 +663,7 @@ def parse_newick(infile, read_data=None, tree=None): token = tokens.next() empty = False - if token == '(': # new branchset + if token == '(': # new branchset if data: read_data(node, "".join(data)) data = [] @@ -717,7 +674,7 @@ def parse_newick(infile, read_data=None, tree=None): ancestors.append(node) node = child - elif token == ',': # another branch + elif token == ',': # another branch if data: read_data(node, "".join(data)) data = [] @@ -729,16 +686,16 @@ def parse_newick(infile, read_data=None, tree=None): parent.children.append(child) node = child - elif token == ')': # optional name next + elif token == ')': # optional name next if data: read_data(node, "".join(data)) data = [] node = ancestors.pop() - elif token == ':': # optional length next + elif token == ':': # optional length next data.append(token) - elif token == ';': # end of tree + elif token == ';': # end of tree if data: read_data(node, "".join(data)) data = [] @@ -747,20 +704,17 @@ def parse_newick(infile, read_data=None, tree=None): else: if prev_token in '(,': node.name = token - + elif prev_token in '):': data.append(token) else: data.append(token) - + except StopIteration: if empty: raise Exception("Empty tree") - except Exception, e: - raise # Exception("Malformed newick: " + repr(e)) - # setup node names names = set() for node in nodes: @@ -776,19 +730,18 @@ def parse_newick(infile, read_data=None, tree=None): break tree.set_default_data() - return tree def write_newick(tree, out=sys.stdout, writeData=None, oneline=False, rootData=False): """Write the tree in newick notation""" - 
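Typical newick reading and writing, with an invented tree string; which parser read_tree uses internally (the ply-based treelib_parser or the pure-python tokenizer) is assumed to depend on what is importable:

    import StringIO
    from rasmus import treelib

    tree = treelib.read_tree(StringIO.StringIO("((a:1,b:2)x:3,c:4)root;"))
    print tree.leaf_names()            # ['a', 'b', 'c']
    print tree.get_one_line_newick()   # the same tree as a one-line string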
write_newick_node(tree, tree.root, util.open_stream(out, "w"), + write_newick_node(tree, tree.root, util.open_stream(out, "w"), writeData=writeData, oneline=oneline, rootData=rootData) - -def write_newick_node(tree, node, out=sys.stdout, + +def write_newick_node(tree, node, out=sys.stdout, depth=0, writeData=None, oneline=False, rootData=False): """Write the node in newick format to the out file stream""" @@ -810,7 +763,7 @@ def write_newick_node(tree, node, out=sys.stdout, else: out.write("(\n") for child in node.children[:-1]: - write_newick_node(tree, child, out, depth+1, + write_newick_node(tree, child, out, depth+1, writeData=writeData, oneline=oneline) if oneline: out.write(",") @@ -820,7 +773,7 @@ def write_newick_node(tree, node, out=sys.stdout, writeData=writeData, oneline=oneline) if oneline: out.write(")") - else: + else: out.write("\n" + (" " * depth) + ")") # don't print data for root node @@ -870,7 +823,7 @@ def read_newick_ply(filename, readData=None, tree=None): # get parse tree text = util.read_until(util.open_stream(filename), ";")[0] + ";" expr = treelib_parser.yacc.parse(text) - + # walk the parse tree and build the tree names = set() @@ -888,7 +841,7 @@ def walk(expr): if node.name is None: node.name = tree.new_name() - + # ensure unique name node.name = tree.unique_name(node.name, names) @@ -910,15 +863,15 @@ def walk(expr): return tree - + def read_newick_recursive(filename, tree=None): """ Reads a big newick file with a custom parser DEPRECATED """ - - infile = util.open_stream(filename) #file(filename) + + infile = util.open_stream(filename) opens = [0] names = set() @@ -928,20 +881,25 @@ def read_newick_recursive(filename, tree=None): def readchar(): while True: char = infile.read(1) - if not char or char not in " \t\n": break - if char == "(": opens[0] += 1 - if char == ")": opens[0] -= 1 + if not char or char not in " \t\n": + break + if char == "(": + opens[0] += 1 + if char == ")": + opens[0] -= 1 return char def read_until(chars): token = "" while True: - #char = readchar() while True: char = infile.read(1) - if not char or char not in " \t\n": break - if char == "(": opens[0] += 1 - if char == ")": opens[0] -= 1 + if not char or char not in " \t\n": + break + if char == "(": + opens[0] += 1 + if char == ")": + opens[0] -= 1 if char in chars or char == "": return token, char @@ -950,12 +908,14 @@ def read_until(chars): def read_dist(): word = "" while True: - #char = readchar() while True: char = infile.read(1) - if not char or char not in " \t\n": break - if char == "(": opens[0] += 1 - if char == ")": opens[0] -= 1 + if not char or char not in " \t\n": + break + if char == "(": + opens[0] += 1 + if char == ")": + opens[0] -= 1 if not char in "-0123456789.e": return float(word) @@ -965,12 +925,14 @@ def read_dist(): def read_name(): token = "" while True: - #char = readchar() while True: char = infile.read(1) - if not char or char not in " \t\n": break - if char == "(": opens[0] += 1 - if char == ")": opens[0] -= 1 + if not char or char not in " \t\n": + break + if char == "(": + opens[0] += 1 + if char == ")": + opens[0] -= 1 if char in ":)," or char == "": return token, char @@ -989,7 +951,7 @@ def read_item(): if char == ":": node.dist = read_dist() return node - else: + else: #word, char = read_until(":),") word, char = read_name() word = char1 + word.rstrip() @@ -1001,7 +963,6 @@ def read_item(): node.dist = read_dist() return node - def read_root(): word, char = read_until("(") @@ -1031,7 +992,7 @@ def read_parent_tree(treefile, labelfile=None, 
labels=None, tree=None): labels = util.read_strings(labelfile) elif labels is None: - nitems = (len(lines) + 1)/ 2 + nitems = (len(lines) + 1) / 2 labels = map(str, range(nitems)) tree.make_root() @@ -1051,7 +1012,7 @@ def read_parent_tree(treefile, labelfile=None, labels=None, tree=None): if parentid == -1: # keep track of all roots tree.add_child(tree.root, child) - else: + else: if not parentid in tree.nodes: parent = TreeNode(parentid) tree.add(parent) @@ -1080,7 +1041,7 @@ def read_parent_tree(treefile, labelfile=None, labels=None, tree=None): def write_parent_tree(treefile, tree, labels=None): """Writes tree to the parent array format""" - + ids = {} if labels is None: @@ -1100,7 +1061,7 @@ def walk(node): # build ptree array ptree = [0] * len(ids) for node, idname in ids.iteritems(): - if node.parent != None: + if node.parent is not None: ptree[idname] = ids[node.parent] else: ptree[idname] = -1 @@ -1117,16 +1078,17 @@ def parse_nhx_comment(comment): if "=" in pair: yield pair.split("=") + def format_nhx_comment(data): """Format a NHX comment""" return "[&&NHX:" + ":".join("%s=%s" % (k, v) for k, v in data.iteritems()) + "]" - + def parse_nhx_data(text): """Parse the data field of an NHX file""" data = None - + if "[" in text: data = {} i = text.find("[") @@ -1152,7 +1114,7 @@ def read_nhx_data(node, text): def write_nhx_data(node): """Write data function for writing th data field of an NHX file""" - + text = Tree().write_data(node) if node.data: text += format_nhx_comment(node.data) @@ -1164,8 +1126,9 @@ def write_nhx_data(node): def assert_tree(tree): """Assert that the tree data structure is internally consistent""" - + visited = set() + def walk(node): assert node.name in tree.nodes assert node.name not in visited @@ -1176,15 +1139,15 @@ def walk(node): assert child.parent == node node.recurse(walk) walk(tree.root) - - assert tree.root.parent is None - assert len(tree.nodes) == len(visited), "%d %d" % (len(tree.nodes), len(visited)) + assert tree.root.parent is None + assert len(tree.nodes) == len(visited), ( + "%d %d" % (len(tree.nodes), len(visited))) def lca(nodes): """Returns the Least Common Ancestor (LCA) of a list of nodes""" - + if len(nodes) == 1: return nodes[0] elif len(nodes) > 2: @@ -1193,17 +1156,17 @@ def lca(nodes): node1, node2 = nodes set1 = set([node1]) set2 = set([node2]) - + while True: if node1 in set2: return node1 if node2 in set1: return node2 - if node1.parent != None: + if node1.parent is not None: node1 = node1.parent - if node2.parent != None: + if node2.parent is not None: node2 = node2.parent - + set1.add(node1) set2.add(node2) else: @@ -1213,38 +1176,38 @@ def lca(nodes): def find_dist(tree, name1, name2): """Returns the branch distance between two nodes in a tree""" - if not name1 in tree.nodes or \ - not name2 in tree.nodes: + if (not name1 in tree.nodes or + not name2 in tree.nodes): raise Exception("nodes '%s' and '%s' are not in tree" % (name1, name2)) - + # find root path for node1 node1 = tree.nodes[name1] - path1 = [node1] + path1 = [node1] while node1 != tree.root: node1 = node1.parent path1.append(node1) - + # find root path for node2 node2 = tree.nodes[name2] path2 = [node2] while node2 != tree.root: node2 = node2.parent path2.append(node2) - + # find when paths diverge i = 1 while i <= len(path1) and i <= len(path2) and (path1[-i] == path2[-i]): i += 1 - + dist = 0 for j in range(i, len(path1)+1): dist += path1[-j].dist for j in range(i, len(path2)+1): dist += path2[-j].dist - + return dist - + def descendants(node, lst=None): 
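The lca and find_dist helpers can be exercised on a small invented tree (rasmus import path assumed):

    import StringIO
    from rasmus import treelib

    tree = treelib.read_tree(StringIO.StringIO("((a:1,b:2)x:3,c:4)r;"))
    ancestor = treelib.lca([tree["a"], tree["b"]])   # the internal node above a and b
    print treelib.find_dist(tree, "a", "c")          # 1 + 3 + 4 = 8.0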
"""Return a list of all the descendants beneath a node""" @@ -1260,7 +1223,7 @@ def count_descendants(node, sizes=None): """Returns a dict with number of leaves beneath each node""" if sizes is None: sizes = {} - + if len(node.children) > 0: sizes[node] = 0 for child in node.children: @@ -1268,78 +1231,79 @@ def count_descendants(node, sizes=None): sizes[node] += sizes[child] else: sizes[node] = 1 - + return sizes def subtree(tree, node): """Return a copy of a subtree of 'tree' rooted at 'node'""" - + # make new tree - tree2 = Tree(nextname = tree.new_name()) - + tree2 = Tree(nextname=tree.new_name()) + # copy nodes and data tree2.root = node.copy() tree2.copy_data(tree) - + # add nodes def walk(node): tree2.add(node) node.recurse(walk) walk(tree2.root) - + return tree2 def max_disjoint_subtrees(tree, subroots): - """Returns a list of rooted subtrees with atmost one node from + """Returns a list of rooted subtrees with atmost one node from the list 'subroots' """ - + marks = {} - + # mark the path from each subroot to the root for subroot in subroots: ptr = subroot - while ptr != None: + while ptr is not None: lst = marks.setdefault(ptr, []) lst.append(subroot) ptr = ptr.parent # subtrees are those trees with nodes that have at most one mark subroots2 = [] + def walk(node): marks.setdefault(node, []) - if len(marks[node]) < 2 and \ - (not node.parent or len(marks[node.parent]) >= 2): + if (len(marks[node]) < 2 and + (not node.parent or len(marks[node.parent]) >= 2)): subroots2.append(node) node.recurse(walk) walk(tree.root) - + return subroots2 def tree2graph(tree): """Convert a tree to a graph data structure (sparse matrix)""" mat = {} - - # init all rows of adjacency matrix to + + # init all rows of adjacency matrix to for name in tree.nodes: mat[name] = {} - + for name, node in tree.nodes.iteritems(): for child in node.children: mat[name][child.name] = child.dist - + if node.parent: mat[name][node.parent.name] = node.dist - + return mat def graph2tree(mat, root, closedset=None): """Convert a graph to a tree data structure""" - + if closedset is None: closedset = set() tree = Tree() @@ -1353,38 +1317,38 @@ def walk(name): child_node = walk(child) child_node.dist = mat[name][child] tree.add_child(node, child_node) - return node + return node tree.root = walk(root) - + tree.nextname = max(name for name in tree.nodes if isinstance(name, int)) - + return tree def remove_single_children(tree, simplify_root=True): """ Remove all nodes from the tree that have exactly one child - + Branch lengths are added together when node is removed. 
""" - + # find single children removed = [node for node in tree if len(node.children) == 1 and node.parent] - + # actually remove children for node in removed: newnode = node.children[0] - + # add distance newnode.dist += node.dist - + # change parent and child pointers newnode.parent = node.parent index = node.parent.children.index(node) node.parent.children[index] = newnode - + # remove old node del tree.nodes[node.name] @@ -1396,21 +1360,20 @@ def remove_single_children(tree, simplify_root=True): tree.remove(oldroot) tree.root.parent = None tree.root.dist += oldroot.dist - - return removed + return removed def remove_exposed_internal_nodes(tree, leaves=None): """ Remove all leaves that were originally internal nodes - + leaves -- a list of original leaves that should stay - + if leaves is not specified, only leaves with strings as names will be kept """ - - if leaves != None: + + if leaves is not None: stay = set(leaves) else: # use the fact that the leaf name is a string to determine @@ -1419,7 +1382,7 @@ def remove_exposed_internal_nodes(tree, leaves=None): for leaf in tree.leaves(): if isinstance(leaf.name, basestring): stay.add(leaf) - + # post order traverse tree def walk(node): # keep a list of children to visit, since they may remove themselves @@ -1435,15 +1398,15 @@ def subtree_by_leaves(tree, leaves=None, keep_single=False, simplify_root=True): """ Remove any leaf not in leaves set - + leaves -- a list of leaves that should stay keep_single -- if False, remove all single child nodes simplify_root -- if True, basal branch is removed when removing single children nodes """ - - stay = set(leaves) - + + stay = set(leaves) + # post order traverse tree def walk(node): # keep a list of children to visit, since they may remove themselves @@ -1465,7 +1428,7 @@ def walk(node): def subtree_by_leaf_names(tree, leaf_names, keep_single=False, newCopy=False): """Returns a subtree with only the leaves specified""" - + if newCopy: tree = tree.copy() return subtree_by_leaves(tree, [tree.nodes[x] for x in leaf_names], @@ -1497,22 +1460,21 @@ def walk(node): return leaves walk(tree.root) - # reorder tree to match tree2 leaf_lookup = util.list2lookup(tree2.leaf_names()) def mean(lst): return sum(lst) / float(len(lst)) - - def walk(node): + + def walk2(node): if node.is_leaf(): return set([node.name]) else: leaf_sets = [] for child in node.children: - leaf_sets.append(walk(child)) + leaf_sets.append(walk2(child)) scores = [mean(util.mget(leaf_lookup, l)) for l in leaf_sets] rank = util.sortindex(scores) @@ -1524,7 +1486,7 @@ def walk(node): ret = ret.union(l) return ret - walk(tree.root) + walk2(tree.root) def set_tree_topology(tree, tree2): @@ -1533,10 +1495,9 @@ def set_tree_topology(tree, tree2): trees must have nodes with the same names """ - nodes = tree.nodes nodes2 = tree2.nodes - + for node in tree: node2 = nodes2[node.name] @@ -1545,7 +1506,7 @@ def set_tree_topology(tree, tree2): node.parent = nodes[node2.parent.name] else: node.parent = None - + # set children if node.is_leaf(): assert node2.is_leaf() @@ -1556,7 +1517,6 @@ def set_tree_topology(tree, tree2): tree.root = nodes[tree2.root.name] - #============================================================================= # Rerooting functions # @@ -1567,10 +1527,9 @@ def is_rooted(tree): return len(tree.root.children) <= 2 - def unroot(tree, newCopy=True): """Return an unrooted copy of tree""" - + if newCopy: tree = tree.copy() @@ -1586,7 +1545,7 @@ def unroot(tree, newCopy=True): nodes[0].dist = 0 tree.set_branch_data(nodes[0], {}) 
nodes[0].parent = None - + # replace root del tree.nodes[tree.root.name] tree.root = nodes[0] @@ -1597,16 +1556,15 @@ def reroot(tree, newroot, onBranch=True, newCopy=True): """ Change the rooting of a tree """ - + # TODO: remove newCopy (or assert newCopy=False) if newCopy: tree = tree.copy() - # handle trivial case - if (not onBranch and tree.root.name == newroot) or \ - (onBranch and newroot in [x.name for x in tree.root.children] and \ - len(tree.root.children) == 2): + if ((not onBranch and tree.root.name == newroot) or + (onBranch and newroot in [x.name for x in tree.root.children] and + len(tree.root.children) == 2)): return tree assert not onBranch or newroot != tree.root.name, "No branch specified" @@ -1616,7 +1574,7 @@ def reroot(tree, newroot, onBranch=True, newCopy=True): # handle trivial case if not onBranch and tree.root.name == newroot: return tree - + if onBranch: # add new root in middle of branch newNode = TreeNode(tree.new_name()) @@ -1627,12 +1585,12 @@ def reroot(tree, newroot, onBranch=True, newCopy=True): tree.set_branch_data(node1, rootdata1) newNode.dist = rootdist / 2.0 tree.set_branch_data(newNode, rootdata2) - + node2 = node1.parent node2.children.remove(node1) tree.add_child(newNode, node1) tree.add_child(node2, newNode) - + ptr = node2 ptr2 = newNode newRoot = newNode @@ -1641,11 +1599,10 @@ def reroot(tree, newroot, onBranch=True, newCopy=True): ptr2 = tree.nodes[newroot] ptr = ptr2.parent newRoot = ptr2 - + newRoot.parent = None - + # reverse parent child relationship of all nodes on path node1 to root - oldroot = tree.root nextDist = ptr2.dist nextData = tree.get_branch_data(ptr2) ptr2.dist = 0 @@ -1653,21 +1610,21 @@ def reroot(tree, newroot, onBranch=True, newCopy=True): nextPtr = ptr.parent ptr.children.remove(ptr2) tree.add_child(ptr2, ptr) - + tmp = ptr.dist tmpData = tree.get_branch_data(ptr) ptr.dist = nextDist tree.set_branch_data(ptr, nextData) nextDist = tmp nextData = tmpData - + ptr2 = ptr ptr = nextPtr - + if nextPtr is None: break tree.root = newRoot - + return tree @@ -1696,11 +1653,10 @@ def midpoint_root(tree): dists.append((tmp[-1][0] + tmp[-2][0], node, tmp[-1][2], tmp[-1][1], tmp[-2][2], tmp[-2][1])) - + maxdist, top, child1, leaf1, child2, leaf2 = max(dists) middist = maxdist / 2.0 - # find longer part of path if depths[child1][0] + child1.dist >= middist: ptr = leaf1 @@ -1709,7 +1665,7 @@ def midpoint_root(tree): # find branch that contains midpoint dist = 0.0 - while ptr != top: + while ptr != top: if ptr.dist + dist >= middist: # reroot tree reroot(tree, ptr.name, onBranch=True, newCopy=False) @@ -1723,9 +1679,8 @@ def midpoint_root(tree): dist += ptr.dist ptr = ptr.parent - - - assert 0 # shouldn't get here + + raise AssertionError("Could not find branch with midpoint") #============================================================================= @@ -1759,16 +1714,16 @@ def walk(node): t = walk(child) # ensure branch lengths are ultrametrix - if t2: + if t2: assert abs(t - t2)/t < esp, (node.name, t, t2) t2 = t times[node] = t return t + node.dist walk(root) - + return times -get_tree_timestamps = get_tree_ages # backwards compatiability +get_tree_timestamps = get_tree_ages # backwards compatiability def set_dists_from_ages(tree, times): @@ -1781,24 +1736,23 @@ def set_dists_from_ages(tree, times): node.dist = times[node.parent] - times[node] else: node.dist = 0.0 -set_dists_from_timestamps = set_dists_from_ages # backwards compatiability +set_dists_from_timestamps = set_dists_from_ages # backwards compatiability def 
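A sketch of the age helpers renamed above (get_tree_ages, formerly get_tree_timestamps); they expect an ultrametric tree, and the default arguments are assumed here (tree string invented):

    import StringIO
    from rasmus import treelib

    tree = treelib.read_tree(StringIO.StringIO("((a:1,b:1)x:2,c:3)r;"))
    ages = treelib.get_tree_ages(tree)   # dict: node -> age above the leaves
    # ages[tree.root] == 3.0 and every leaf has age 0.0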
check_ages(tree, times): """Asserts that timestamps are consistent with tree""" - + for node in tree: if node.parent: - if times[node.parent] - times[node] < 0.0 or \ - abs(((times[node.parent] - times[node]) - - node.dist)/node.dist) > .001: + if (times[node.parent] - times[node] < 0.0 or + abs(((times[node.parent] - times[node]) - + node.dist)/node.dist) > .001): draw_tree_names(tree, maxlen=7, minlen=7) util.printcols([(a.name, b) for a, b in times.items()]) print print node.name, node.dist, times[node.parent] - times[node] raise Exception("negative time span") -check_timestamps = check_ages # backwards compatiability - +check_timestamps = check_ages # backwards compatiability #============================================================================= @@ -1809,13 +1763,13 @@ def tree2parent_table(tree, data_cols=[]): This parent table will have a special numbering for the internal nodes, such that their id is also their row in the table. - + parent table is a standard format of the Compbio Lab as of 02/01/2007. It is a list of triples (node_name, parent_name, dist, ...) - + * parent_name indicates the parent of the node. If the node is a root (has no parent), then parent_name is -1 - + * dist is the distance between the node and its parent. * additional columns can be added using the data_cols argument. The @@ -1833,7 +1787,7 @@ def tree2parent_table(tree, data_cols=[]): for col in data_cols: row.append(node.data[col]) ptable.append(row) - + return ptable @@ -1842,10 +1796,10 @@ def parent_table2tree(ptable, data_cols=[], convert_names=True): if convert_names is True, names that are strings that look like integers are converted to ints. - + See tree2parent_table for details """ - + tree = Tree() parents = {} @@ -1857,7 +1811,7 @@ def parent_table2tree(ptable, data_cols=[], convert_names=True): name = int(name) if parent.isdigit() or parent == "-1": parent = int(parent) - + node = TreeNode(name) node.dist = row[2] tree.add(node) @@ -1865,7 +1819,7 @@ def parent_table2tree(ptable, data_cols=[], convert_names=True): for col, val in zip(data_cols, row[3:]): node.data[col] = val - + # link up parents for node, parent_name in parents.iteritems(): if parent_name == -1: @@ -1873,9 +1827,8 @@ def parent_table2tree(ptable, data_cols=[], convert_names=True): else: parent = tree.nodes[parent_name] tree.add_child(parent, node) - - return tree + return tree def tree2parent_table_ordered(tree, leaf_names=None): @@ -1883,28 +1836,28 @@ def tree2parent_table_ordered(tree, leaf_names=None): This parent table will have a special numbering for the internal nodes, such that their id is also their row in the table. - + parent table is a standard format of the Compbio Lab as of 02/01/2007. It is a list of triples (node_name, parent_name, dist) - + * If the node is a leaf node_name is the leaf name (a string) * If the node is internal node_name is an int representing which row (0-based) the node is in the table. - - * parent_name indicates the parent of the node. If the parent is root, a + + * parent_name indicates the parent of the node. If the parent is root, a -1 is used as the parent_name. - + * dist is the distance between the node and its parent. 
- + Arguments: - leaf_names -- specifies that a tree with only a subset of the leaves + leaf_names -- specifies that a tree with only a subset of the leaves should be used - + NOTE: root is not given a row, because root does not have a distance the nodeid of the root is -1 """ - - if leaf_names != None: + + if leaf_names is not None: tree = subtree_by_leaf_names(tree, leaf_names, newCopy=True) else: leaf_names = tree.leaf_names() @@ -1917,7 +1870,7 @@ def tree2parent_table_ordered(tree, leaf_names=None): nodeids[tree.nodes[leaf]] = leaf nodes.append(tree.nodes[leaf]) nodeid += 1 - + # assign a numbering to the internal nodes for node in tree: if node.is_leaf(): @@ -1933,79 +1886,74 @@ def tree2parent_table_ordered(tree, leaf_names=None): parentTable = [] for node in nodes: parentTable.append([nodeids[node], nodeids[node.parent], node.dist]) - + return parentTable - + def parent_table2tree_ordered(ptable): """Converts a parent table to a Tree - + See tree2parentTable for details """ - # TODO: allow named internal nodes - + tree = Tree() - + # create nodes maxint = 0 - for name, parent_name, dist in parentTable: + for name, parent_name, dist in ptable: node = TreeNode(name) node.dist = dist tree.add(node) - + if isinstance(name, int): maxint = max(name, maxint) - + # make a root node tree.nextname = maxint + 1 tree.make_root() # link up parents - for name, parent_name, dist in parentTable: + for name, parent_name, dist in ptable: if parent_name == -1: parent = tree.root else: parent = tree.nodes[parent_name] tree.add_child(parent, tree.nodes[name]) - + return tree def write_parent_table(ptable, out=sys.stdout): """Writes a parent table to out - + out can be a filename or file stream """ - out = util.open_stream(out, "w") for row in ptable: out.write("\t".join(map(str, row)) + "\n") - def read_parent_table(filename): """Reads a parent table from the file 'filename' - + filename can also be an open file stream """ - infile = util.open_stream(filename) ptable = [] - + for line in infile: row = line.rstrip("\n").split("\t") name, parent, dist = row[:3] - + if name.is_digit(): name = int(name) if parent.is_digit() or parent == "-1": parent = int(parent) - + ptable.append([name, parent, float(dist)] + row[3:]) - - return ptable + return ptable #============================================================================= @@ -2013,17 +1961,17 @@ def read_parent_table(filename): def make_ptree(tree): """Make parent tree array from tree""" - + nodes = [] nodelookup = {} ptree = [] - + def walk(node): for child in node.children: walk(child) nodes.append(node) walk(tree.root) - + def leafsort(a, b): if a.is_leaf(): if b.is_leaf(): @@ -2035,81 +1983,80 @@ def leafsort(a, b): return 1 else: return 0 - + # bring leaves to front nodes.sort(cmp=leafsort) nodelookup = util.list2lookup(nodes) - + for node in nodes: if node == tree.root: ptree.append(-1) else: ptree.append(nodelookup[node.parent]) - + assert nodes[-1] == tree.root - + return ptree, nodes, nodelookup - #============================================================================= # Tree visualization - + def layout_tree(tree, xscale, yscale, minlen=-util.INF, maxlen=util.INF, - rootx=0, rooty=0): + rootx=0, rooty=0): """\ Determines the x and y coordinates for every branch in the tree. 
- + Branch lengths are determined by node.dist - """ - """ - /----- ] + + """ + /----- ] | ] nodept[node] ---+ node ] | | \--------- """ - + # first determine sizes and nodepts coords = {} sizes = {} # number of descendants (leaves have size 1) nodept = {} # distance between node y-coord and top bracket y-coord + def walk(node): # calculate new y-coordinate for node - + # compute node sizes sizes[node] = 0 for child in node.children: sizes[node] += walk(child) - + if node.is_leaf(): sizes[node] = 1 nodept[node] = yscale - 1 else: top = nodept[node.children[0]] - bot = (sizes[node] - sizes[node.children[-1]])*yscale + \ - nodept[node.children[-1]] + bot = ((sizes[node] - sizes[node.children[-1]]) * yscale + + nodept[node.children[-1]]) nodept[node] = (top + bot) / 2.0 - + return sizes[node] walk(tree.root) - + # determine x, y coordinates - def walk(node, x, y): - xchildren = x+min(max(node.dist*xscale, minlen), maxlen) + def walk2(node, x, y): + xchildren = x+min(max(node.dist*xscale, minlen), maxlen) coords[node] = [xchildren, y + nodept[node]] - + if not node.is_leaf(): ychild = y for child in node.children: - walk(child, xchildren, ychild) + walk2(child, xchildren, ychild) ychild += sizes[child] * yscale - walk(tree.root, rootx, rooty) - - return coords + walk2(tree.root, rootx, rooty) + return coords def layout_tree_hierarchical(tree, xscale, yscale, @@ -2118,59 +2065,61 @@ def layout_tree_hierarchical(tree, xscale, yscale, use_dists=True): """\ Determines the x and y coordinates for every branch in the tree. - + Leaves are drawn to line up. Best used for hierarchical clustering. """ - + """ - /----- ] + /----- ] | ] nodept[node] ---+ node ] | | \--------- """ - + # first determine sizes and nodepts coords = {} sizes = {} # number of descendants (leaves have size 1) depth = {} # how deep in tree is node nodept = {} # distance between node y-coord and top bracket y-coord + def walk(node): # calculate new y-coordinate for node - + # recurse: compute node sizes - sizes[node] = 0 + sizes[node] = 0 for child in node.children: sizes[node] += walk(child) - + if node.is_leaf(): sizes[node] = 1 nodept[node] = yscale - 1 depth[node] = 0 else: top = nodept[node.children[0]] - bot = (sizes[node] - sizes[node.children[-1]])*yscale + \ - nodept[node.children[-1]] + bot = ((sizes[node] - sizes[node.children[-1]]) * yscale + + nodept[node.children[-1]]) nodept[node] = (top + bot) / 2.0 depth[node] = max(depth[child] + 1 for child in node.children) - + return sizes[node] walk(tree.root) - + # determine x, y coordinates maxdepth = depth[tree.root] - def walk(node, x, y): + + def walk2(node, x, y): xchildren = x + xscale * (maxdepth - depth[node]) coords[node] = [xchildren, y + nodept[node]] - + if not node.is_leaf(): ychild = y for child in node.children: - walk(child, x, ychild) + walk2(child, x, ychild) ychild += sizes[child] * yscale - walk(tree.root, rootx, rooty) - + walk2(tree.root, rootx, rooty) + return coords @@ -2179,7 +2128,6 @@ def layout_tree_vertical(layout, offset=None, root=0, leaves=None, """ Make layout vertical """ - if offset is None: if leaves is not None: for node in layout: @@ -2197,7 +2145,6 @@ def layout_tree_vertical(layout, offset=None, root=0, leaves=None, return layout - #============================================================================= # Tree color map @@ -2216,16 +2163,16 @@ def walk(node): node.color = color_mix(colors) walk(tree.root) return func - - + + def color_mix(colors): """Mixes together several color vectors into one""" - + sumcolor = [0, 0, 0] for c 
in colors: sumcolor[0] += c[0] sumcolor[1] += c[1] - sumcolor[2] += c[2] + sumcolor[2] += c[2] for i in range(3): sumcolor[i] /= float(len(colors)) return sumcolor @@ -2233,7 +2180,7 @@ def color_mix(colors): def make_expr_mapping(maps, default_color=(0, 0, 0)): """Returns a function that maps strings matching an expression to a value - + maps -- a list of pairs (expr, value) """ @@ -2245,16 +2192,16 @@ def make_expr_mapping(maps, default_color=(0, 0, 0)): exacts[key] = val else: exps.append((key, val)) - + # create mapping function def mapping(key): if key in exacts: - return exacts[key] - + return exacts[key] + # return default color if not isinstance(key, str): return default_color - + # eval expressions first in order of appearance for exp, val in exps: if exp[-1] == "*": @@ -2263,23 +2210,23 @@ def mapping(key): elif exp[0] == "*": if key.endswith(exp[1:]): return val - + raise Exception("Cannot map key '%s' to any value" % key) return mapping def read_tree_color_map(filename): """Reads a tree colormap from a file""" - + infile = util.open_stream(filename) maps = [] - + for line in infile: expr, red, green, blue = line.rstrip().split("\t") maps.append([expr, map(float, (red, green, blue))]) - + name2color = make_expr_mapping(maps) - + def leafmap(node): return name2color(node.name) @@ -2287,80 +2234,78 @@ def leafmap(node): #========================================================================= -# Draw Tree ASCII art +# Draw Tree ASCII art def draw_tree(tree, labels={}, scale=40, spacing=2, out=sys.stdout, - canvas=None, x=0, y=0, display=True, labelOffset=-1, - minlen=1,maxlen=10000): + canvas=None, x=0, y=0, display=True, labelOffset=-1, + minlen=1, maxlen=10000): """ Print a ASCII Art representation of the tree """ if canvas is None: canvas = textdraw.TextCanvas() - + xscale = scale yscale = spacing - # determine node sizes sizes = {} nodept = {} + def walk(node): if node.is_leaf(): sizes[node] = 1 - nodept[node] = yscale - 1 + nodept[node] = yscale - 1 else: sizes[node] = 0 for child in node.children: sizes[node] += walk(child) if not node.is_leaf(): top = nodept[node.children[0]] - bot = (sizes[node] - sizes[node.children[-1]])*yscale + \ - nodept[node.children[-1]] + bot = ((sizes[node] - sizes[node.children[-1]]) * yscale + + nodept[node.children[-1]]) nodept[node] = (top + bot) / 2 return sizes[node] walk(tree.root) - - - def walk(node, x, y): + + def walk2(node, x, y): # calc coords - xchildren = int(x+min(max(node.dist*xscale,minlen),maxlen)) - + xchildren = int(x + min(max(node.dist * xscale, minlen), maxlen)) + # draw branch canvas.line(x, y+nodept[node], xchildren, y+nodept[node], '-') if node.name in labels: branchlen = xchildren - x lines = str(labels[node.name]).split("\n") labelwidth = max(map(len, lines)) - - labellen = min(labelwidth, - max(int(branchlen-1),0)) - canvas.text(x + 1 + (branchlen - labellen)/2., - y+nodept[node]+labelOffset, + + labellen = min(labelwidth, + max(int(branchlen - 1), 0)) + canvas.text(x + 1 + (branchlen - labellen)/2., + y+nodept[node]+labelOffset, labels[node.name], width=labellen) - + if node.is_leaf(): - canvas.text(xchildren +1, y+yscale-1, str(node.name)) + canvas.text(xchildren + 1, y + yscale - 1, str(node.name)) else: top = y + nodept[node.children[0]] - bot = y + (sizes[node]-sizes[node.children[-1]]) * yscale + \ - nodept[node.children[-1]] - + bot = (y + (sizes[node]-sizes[node.children[-1]]) * yscale + + nodept[node.children[-1]]) + # draw children canvas.line(xchildren, top, xchildren, bot, '|') - + ychild = y for child in 
node.children: - walk(child, xchildren, ychild) + walk2(child, xchildren, ychild) ychild += sizes[child] * yscale - canvas.set(xchildren, y+nodept[node], '+') canvas.set(xchildren, top, '/') canvas.set(xchildren, bot, '\\') canvas.set(x, y+nodept[node], '+') - walk(tree.root, x+0, 0) - + walk2(tree.root, x+0, 0) + if display: canvas.display(out) @@ -2369,7 +2314,7 @@ def draw_tree_lens(tree, *args, **kargs): labels = {} for node in tree.nodes.values(): labels[node.name] = "%f" % node.dist - + draw_tree(tree, labels, *args, **kargs) @@ -2386,8 +2331,9 @@ def draw_tree_boot_lens(tree, *args, **kargs): if isinstance(node.data["boot"], int): labels[node.name] = "(%d) %f" % (node.data["boot"], node.dist) else: - labels[node.name] = "(%.2f) %f" % (node.data["boot"], node.dist) - + labels[node.name] = "(%.2f) %f" % ( + node.data["boot"], node.dist) + draw_tree(tree, labels, *args, **kargs) @@ -2396,7 +2342,7 @@ def draw_tree_names(tree, *args, **kargs): for node in tree.nodes.values(): if not node.is_leaf(): labels[node.name] = "%s" % node.name - + draw_tree(tree, labels, *args, **kargs) @@ -2406,10 +2352,7 @@ def draw_tree_name_lens(tree, *args, **kargs): if not node.is_leaf(): labels[node.name] = "%s " % node.name else: - labels[node.name] ="" + labels[node.name] = "" labels[node.name] += "%f" % node.dist - - draw_tree(tree, labels, *args, **kargs) - - + draw_tree(tree, labels, *args, **kargs) diff --git a/argweaver/deps/rasmus/util.py b/argweaver/deps/rasmus/util.py index 5e3e292e..3c1d4b5f 100644 --- a/argweaver/deps/rasmus/util.py +++ b/argweaver/deps/rasmus/util.py @@ -2,15 +2,14 @@ Common Utilities - file: util.py + file: util.py authors: Matt Rasmussen date: 11/30/2005 - - Provides basic functional programming functions for manipulating lists and + + Provides basic functional programming functions for manipulating lists and dicts. 
Also provides common utilities (timers, plotting, histograms) - -""" +""" # python libs @@ -30,14 +29,13 @@ # Note: I had trouble using 1e1000 directly, because bytecode had trouble # representing infinity (possibly) -INF = float("1e1000") - +INF = float("1e1000") class Bundle (dict): """ A small class for creating a closure of variables - handy for nested functions that need to assign to variables in an + handy for nested functions that need to assign to variables in an outer scope Example: @@ -47,50 +45,45 @@ def func1(): def func2(): this.var1 += 1 func2() - print this.var1 + print this.var1 func1() - + will produce: 1 - """ - def __init__(self, **variables): for key, val in variables.iteritems(): setattr(self, key, val) dict.__setitem__(self, key, val) - + def __setitem__(self, key, val): setattr(self, key, val) dict.__setitem__(self, key, val) - class Dict (dict): """My personal nested Dictionary with default values""" - - + def __init__(self, items=None, dim=1, default=None, insert=True): """ - items -- items to initialize Dict (can be dict, list, iter) - dim -- number of dimensions of the dictionary - default -- default value of a dictionary item - insert -- if True, insert missing keys + items: items to initialize Dict (can be dict, list, iter) + dim: number of dimensions of the dictionary + defaul: default value of a dictionary item + insert: if True, insert missing keys """ - + if items is not None: dict.__init__(self, items) else: dict.__init__(self) - + self._dim = dim self._null = default self._insert = insert - + # backwards compatiability self.data = self - - + def __getitem__(self, i): if not i in self: if self._dim > 1: @@ -102,16 +95,15 @@ def __getitem__(self, i): return ret return dict.__getitem__(self, i) - def has_keys(self, *keys): if len(keys) == 0: return True elif len(keys) == 1: - return dict.has_key(self, keys[0]) + return keys[0] in self else: - return dict.has_key(self, keys[0]) and \ - self[keys[0]].has_keys(*keys[1:]) - + return (keys[0] in self and + self[keys[0]].has_keys(*keys[1:])) + def write(self, out=sys.stdout): def walk(node, path): if node.dim == 1: @@ -123,24 +115,23 @@ def walk(node, path): else: for i in node: walk(node[i], path + [i]) - + print >>out, "< DictMatrix " walk(self, []) print >>out, ">" - class PushIter (object): """Wrap an iterator in another iterator that allows one to push new items onto the front of the iteration stream""" - + def __init__(self, it): self._it = iter(it) self._queue = [] def __iter__(self): return self - + def next(self): """Returns the next item in the iteration stream""" if len(self._queue) > 0: @@ -161,12 +152,10 @@ def peek(self, default=None): self.push(next) return next - - #============================================================================= -# list and dict functions for functional programming +# List and dict functions for functional programming def equal(* vals): """Returns True if all arguments are equal""" @@ -180,8 +169,7 @@ def equal(* vals): def remove(lst, *vals): - """Returns a copy of list 'lst' with values 'vals' removed - """ + """Returns a copy of list 'lst' with values 'vals' removed""" delset = set(vals) return [i for i in lst if i not in delset] @@ -191,15 +179,14 @@ def remove(lst, *vals): def reverse(lst): - """Returns a reversed copy of a list - """ + """Returns a reversed copy of a list""" lst2 = list(lst) lst2.reverse() return lst2 def replace(lst, old_item, new_item, replace_all=False): - """Replace an item in a list""" + """Replace an item in a list.""" if replace_all: for 
i in range(len(lst)): if lst[i] == old_item: @@ -210,15 +197,15 @@ def replace(lst, old_item, new_item, replace_all=False): def cget(mat, *i): - """Returns the column(s) '*i' of a 2D list 'mat' - - mat -- matrix or 2D list - *i -- columns to extract from matrix - - NOTE: If one column is given, the column is returned as a list. - If multiple columns are given, a list of columns (also lists) is returned - """ - + """ + Returns the column(s) '*i' of a 2D list 'mat' + + mat: matrix or 2D list + *i: columns to extract from matrix + + NOTE: If one column is given, the column is returned as a list. + If multiple columns are given, a list of columns (also lists) is returned + """ if len(i) == 1: return [row[i[0]] for row in mat] else: @@ -227,18 +214,16 @@ def cget(mat, *i): def mget(lst, ind): - """Returns a list 'lst2' such that lst2[i] = lst[ind[i]] - - Or in otherwords, get the subsequence of 'lst' """ - return [lst[i] for i in ind] + Returns a list 'lst2' such that lst2[i] = lst[ind[i]] + In otherwords, get the subsequence of 'lst'. + """ + return [lst[i] for i in ind] def concat(* lists): - """Concatenates several lists into one - """ - + """Concatenates several lists into one.""" lst = [] for l in lists: lst.extend(l) @@ -248,18 +233,17 @@ def concat(* lists): def flatten(lst, depth=INF): """ Flattens nested lists/tuples into one list - + depth -- specifies how deep flattening should occur """ - flat = [] - + for elm in lst: if hasattr(elm, "__iter__") and depth > 0: flat.extend(flatten(elm, depth-1)) else: flat.append(elm) - + return flat @@ -280,13 +264,12 @@ def subdict(dic, keys): def revdict(dic, allowdups=False): """ - Reverses a dict 'dic' such that the keys become values and the + Reverses a dict 'dic' such that the keys become values and the values become keys. - - allowdups -- if True, one of several key-value pairs with the same value + + allowdups -- if True, one of several key-value pairs with the same value will be arbitrarily choosen. Otherwise an expection is raised """ - dic2 = {} if allowdups: for key, val in dic.iteritems(): @@ -295,31 +278,31 @@ def revdict(dic, allowdups=False): for key, val in dic.iteritems(): assert key not in dic2, "duplicate value '%s' in dict" % val dic2[val] = key - + return dic2 def list2lookup(lst): """ - Creates a dict where each key is lst[i] and value is i + Create a dict where each key is lst[i] and value is i. """ return dict((elm, i) for i, elm in enumerate(lst)) def mapdict(dic, key=lambda x: x, val=lambda x: x): """ - Creates a new dict where keys and values are mapped + Create a new dict where keys and values are mapped. """ dic2 = {} for k, v in dic.iteritems(): dic2[key(k)] = val(v) - return dic2 def mapwindow(func, size, lst): - """Apply a function 'func' to a sliding window of size 'size' within - a list 'lst'""" + """ + Apply a function 'func' to a sliding window of size 'size'. + """ lst2 = [] lstlen = len(lst) radius = int(size // 2) @@ -332,19 +315,19 @@ def mapwindow(func, size, lst): def groupby(func, lst, multi=False): - """Places i and j of 'lst' into the same group if func(i) == func(j). - - func -- is a function of one argument that maps items to group objects - lst -- is a list of items - multi -- if True, func must return a list of keys (key1, ..., keyn) for - item a. 
groupby will return a nested dict 'dct' such that - dct[key1]...[keyn] == a - - returns: - a dictionary such that the keys are groups and values are items found in - that group - """ - + """ + Places i and j of 'lst' into the same group if func(i) == func(j). + + func: a function of one argument that maps items to group objects + lst: a list of items + multi: if True, func must return a list of keys (key1, ..., keyn) for + item a. groupby will return a nested dict 'dct' such that + dct[key1]...[keyn] == a + + Returns: + a dictionary such that the keys are groups and values are items + found in that group. + """ if not multi: dct = defaultdict(lambda: []) for i in lst: @@ -357,18 +340,17 @@ def groupby(func, lst, multi=False): for key in keys[:-1]: d = d.setdefault(key, {}) d.setdefault(keys[-1], []).append(i) - + return dct def iter_groups2(items, key): """ - Iterates over groups of consecutive values x from 'items' that have equal key(x)""" - + Iterate over groups of consecutive values x from 'items' with equal key(x) + """ def iter_subgroup(): pass - NULL = object() last_key = NULL group = [] @@ -378,7 +360,7 @@ def iter_subgroup(): if k != last_key: if group: yield group - + # start new group group = [] last_key = k @@ -390,9 +372,8 @@ def iter_subgroup(): def iter_groups(items, key): """ - Iterates over groups of consecutive values x from 'items' that have equal key(x)""" - - + Iterates over groups of consecutive values x from 'items' with equal key(x) + """ NULL = object() last_key = NULL group = [] @@ -402,48 +383,42 @@ def iter_groups(items, key): if k != last_key: if group: yield group - + # start new group group = [] last_key = k group.append(item) + # yield last group if group: yield group - def unique(lst): """ Returns a copy of 'lst' with only unique entries. The list is stable (the first occurance is kept). 
""" - found = set() - + lst2 = [] for i in lst: if i not in found: lst2.append(i) found.add(i) - + return lst2 def mapapply(funcs, lst): """ - apply each function in 'funcs' to one element in 'lst' + Apply each function in 'funcs' to one element in 'lst' """ - - lst2 = [] - for func, item in izip(funcs, lst): - lst2.append(func(item)) - return lst2 + return [func(item) for func, item in izip(funcs, lst)] def cumsum(vals): - """Returns a cumalative sum of vals (as a list)""" - + """Return a cumalative sum of vals (as a list).""" lst = [] tot = 0 for v in vals: @@ -451,10 +426,9 @@ def cumsum(vals): lst.append(tot) return lst + def icumsum(vals): - """Returns a cumalative sum of vals (as an iterator)""" - - lst = [] + """Return a cumalative sum of vals (as an iterator).""" tot = 0 for v in vals: tot += v @@ -463,13 +437,12 @@ def icumsum(vals): def frange(start, end, step): """ - Generates a range of floats - - start -- begining of range - end -- end of range - step -- step size + Generates a range of floats + + start: begining of range + end: end of range + step: step size """ - i = 0 val = start while val < end: @@ -479,14 +452,12 @@ def frange(start, end, step): def ilen(iterator): - """ - Returns the size of an iterator - """ + """Return the size of an iterator.""" return sum(1 for i in iterator) def exc_default(func, val, exc=Exception): - """Specify a default value for when an exception occurs""" + """Specify a default value for when an exception occurs.""" try: return func() except exc: @@ -496,8 +467,10 @@ def exc_default(func, val, exc=Exception): #============================================================================= # simple matrix functions -def make_matrix(nrows, ncols, val = 0): - +def make_matrix(nrows, ncols, val=0): + """ + Create a new matrix with 'nrows' rows and 'ncols' columns. + """ return [[val for i in xrange(ncols)] for j in xrange(nrows)] @@ -505,22 +478,21 @@ def make_matrix(nrows, ncols, val = 0): def transpose(mat): """ Transpose a matrix - + Works better than zip() in that rows are lists not tuples """ - assert equal(* map(len, mat)), "rows are not equal length" - + mat2 = [] - + for j in xrange(len(mat[0])): row2 = [] mat2.append(row2) for row in mat: row2.append(row[j]) - + return mat2 - + def submatrix(mat, rows=None, cols=None): """ @@ -528,43 +500,41 @@ def submatrix(mat, rows=None, cols=None): Rows and columns will appear in the order as indicated in 'rows' and 'cols' """ - - if rows == None: + if rows is None: rows = xrange(len(mat)) - if cols == None: + if cols is None: cols = xrange(len(mat[0])) - + mat2 = [] - + for i in rows: newrow = [] mat2.append(newrow) for j in cols: newrow.append(mat[i][j]) - + return mat2 def map2(func, *matrix): """ Maps a function onto the elements of a matrix - + Also accepts multiple matrices. 
Thus matrix addition is - + map2(add, matrix1, matrix2) - + """ - matrix2 = [] - + for i in xrange(len(matrix[0])): - row2 = [] + row2 = [] matrix2.append(row2) for j in xrange(len(matrix[0][i])): args = [x[i][j] for x in matrix] row2.append(func(* args)) - + return matrix2 @@ -582,24 +552,23 @@ def max2(matrix): def range2(width, height): """Iterates over the indices of a matrix - + Thus list(range2(3, 2)) returns [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)] """ - for i in xrange(width): for j in xrange(height): yield i, j #============================================================================= -# list counting and finding functions +# List counting and finding functions def count(func, lst): """ Counts the number of times func(x) is True for x in list 'lst' - + See also: counteq(a, lst) count items equal to a countneq(a, lst) count items not equal to a @@ -614,22 +583,39 @@ def count(func, lst): n += 1 return n -def counteq(a, lst): return count(eqfunc(a), lst) -def countneq(a, lst): return count(neqfunc(a), lst) -def countle(a, lst): return count(lefunc(a), lst) -def countlt(a, lst): return count(ltfunc(a), lst) -def countge(a, lst): return count(gefunc(a), lst) -def countgt(a, lst): return count(gtfunc(a), lst) + +def counteq(a, lst): + return count(eqfunc(a), lst) + + +def countneq(a, lst): + return count(neqfunc(a), lst) + + +def countle(a, lst): + return count(lefunc(a), lst) + + +def countlt(a, lst): + return count(ltfunc(a), lst) + + +def countge(a, lst): + return count(gefunc(a), lst) + + +def countgt(a, lst): + return count(gtfunc(a), lst) def find(func, *lsts): """ Returns the indices 'i' of 'lst' where func(lst[i]) == True - + if N lists are passed, N arguments are passed to 'func' at a time. - Thus, find(func, list1, list2) returns the list of indices 'i' where + Thus, find(func, list1, list2) returns the list of indices 'i' where func(list1[i], list2[i]) == True - + See also: findeq(a, lst) find items equal to a findneq(a, lst) find items not equal to a @@ -638,9 +624,8 @@ def find(func, *lsts): findge(a, lst) find items greater than or equal to a findgt(a, lst) find items greater than a """ - pos = [] - + if len(lsts) == 1: # simple case, one list lst = lsts[0] @@ -650,41 +635,56 @@ def find(func, *lsts): else: # multiple lists given assert equal(* map(len, lsts)), "lists are not same length" - - nvars = len(lsts) + for i in xrange(len(lsts[0])): if func(* [x[i] for x in lsts]): pos.append(i) - + return pos -def findeq(a, lst): return find(eqfunc(a), lst) -def findneq(a, lst): return find(neqfunc(a), lst) -def findle(a, lst): return find(lefunc(a), lst) -def findlt(a, lst): return find(ltfunc(a), lst) -def findge(a, lst): return find(gefunc(a), lst) -def findgt(a, lst): return find(gtfunc(a), lst) + +def findeq(a, lst): + return find(eqfunc(a), lst) + + +def findneq(a, lst): + return find(neqfunc(a), lst) + + +def findle(a, lst): + return find(lefunc(a), lst) + + +def findlt(a, lst): + return find(ltfunc(a), lst) + + +def findge(a, lst): + return find(gefunc(a), lst) + + +def findgt(a, lst): + return find(gtfunc(a), lst) def islands(lst): - """Takes a iterable and returns islands of equal consecutive items - + """Takes a iterable and returns islands of equal consecutive items + Return value is a dict with the following format - + counts = {elm1: [(start,end), (start,end), ...], elm2: [(start,end), (start,end), ...] 
...} - - where for each (start,end) in counts[elm1] we have lst[start:end] only + + where for each (start,end) in counts[elm1] we have lst[start:end] only containing elm1 - + """ - counts = {} - NULL = object() # unique NULL + NULL = object() # unique NULL last = NULL start = 0 - + for i, x in enumerate(lst): if x != last and last != NULL: counts.setdefault(last, []).append((start, i)) @@ -692,52 +692,51 @@ def islands(lst): last = x if last != NULL: counts.setdefault(last, []).append((start, i+1)) - - return counts + return counts def binsearch(lst, val, cmp=cmp, order=1, key=None): """Performs binary search for val in lst - + if val in lst: Returns (i, i) where lst[i] == val - if val not in lst + if val not in lst Returns index i,j where lst[i] < val < lst[j] - + runs in O(log n) - lst -- sorted lst to search - val -- value to find - cmp -- comparison function (default: cmp) - order -- sort order of lst (1=ascending (default), -1=descending) + lst: sorted lst to search + val: value to find + cmp: comparison function (default: cmp) + order: sort order of lst (1=ascending (default), -1=descending) """ #TODO: make a funtion based linear search - + assert order == 1 or order == -1 - + if key is not None: - cmp = lambda a,b: cmp(key(a), key(b)) + cmp = lambda a, b: cmp(key(a), key(b)) low = 0 top = len(lst) - 1 - + if len(lst) == 0: return None, None - + if cmp(lst[-1], val) * order == -1: return (top, None) - + if cmp(lst[0], val) * order == 1: return (None, low) - + while top - low > 1: ptr = (top + low) // 2 - + comp = cmp(lst[ptr], val) * order - + if comp == 0: # have we found val exactly? return ptr, ptr @@ -746,8 +745,7 @@ def binsearch(lst, val, cmp=cmp, order=1, key=None): low = ptr else: top = ptr - - + # check top and low for exact hits if cmp(lst[low], val) == 0: return low, low @@ -757,19 +755,17 @@ def binsearch(lst, val, cmp=cmp, order=1, key=None): return low, top - - - #============================================================================= -# max and min functions +# Max and min functions + def argmax(lst, key=lambda x: x): """ Find the index 'i' in 'lst' with maximum lst[i] - - lst -- list to search - key -- function to apply to each lst[i]. - argmax(lst, key=func) --> argmax(map(key, lst)) + + lst: list to search + key: function to apply to each lst[i]. + argmax(lst, key=func) --> argmax(map(key, lst)) """ it = iter(lst) @@ -786,12 +782,11 @@ def argmax(lst, key=lambda x: x): def argmin(lst, key=lambda x: x): """ Find the index 'i' in 'lst' with minimum lst[i] - - lst -- list to search - key -- function to apply to each lst[i]. - argmin(lst, key=func) --> argmin(map(key, lst)) + + lst: list to search + key: function to apply to each lst[i]. + argmin(lst, key=func) --> argmin(map(key, lst)) """ - it = iter(lst) low = 0 lowval = key(it.next()) @@ -803,67 +798,51 @@ def argmin(lst, key=lambda x: x): return low -''' -def argmin_old(lst, key=lambda x: x): - """ - Find the index 'i' in 'lst' with minimum lst[i] - - lst -- list to search - key -- function to apply to each lst[i]. - argmin(lst, key=func) --> argmin(map(key, lst)) - """ - - assert len(lst) > 0 - low = 0 - lowval = key(lst[0]) - for i in xrange(1, len(lst)): - val = key(lst[i]) - if val < lowval: - low = i - lowval = val - return low +#============================================================================= +# math functions -def argmax_old(lst, key=lambda x: x): - """ - Find the index 'i' in 'lst' with maximum lst[i] - - lst -- list to search - key -- function to apply to each lst[i]. 
- argmax(lst, key=func) --> argmax(map(key, lst)) - """ - - assert len(lst) > 0 - top = 0 - topval = key(lst[0]) - for i in xrange(1, len(lst)): - val = key(lst[i]) - if val > topval: - top = i - topval = val - return top -''' - +def prod(lst): + """Computes the product of a list of numbers.""" + p = 1.0 + for i in lst: + p *= i + return p -#============================================================================= -# math functions -# # comparison function factories # -# These functions will return convenient comparison functions. +# These functions will return convenient comparison functions. # # example: # filter(ltfunc(4), lst) ==> returns all values in lst less than 4 # count(ltfunc(4), lst) ==> returns the number of values in lst < 4 -# -def eqfunc(a): return lambda x: x == a -def neqfunc(a): return lambda x: x != a -def ltfunc(a): return lambda x: x < a -def gtfunc(a): return lambda x: x > a -def lefunc(a): return lambda x: x <= a -def gefunc(a): return lambda x: x >= a + +def eqfunc(a): + return lambda x: x == a + + +def neqfunc(a): + return lambda x: x != a + + +def ltfunc(a): + return lambda x: x < a + + +def gtfunc(a): + return lambda x: x > a + + +def lefunc(a): + return lambda x: x <= a + + +def gefunc(a): + return lambda x: x >= a + + def withinfunc(a, b, ainc=True, binc=True): if ainc: if binc: @@ -881,15 +860,31 @@ def sign(num): """Returns the sign of a number""" return cmp(num, 0) + def lg(num): """Retruns the log_2 of a number""" return math.log(num, 2) -def add(a, b): return a + b -def sub(a, b): return a - b -def mul(a, b): return a * b -def idiv(a, b): return a / b -def div(a, b): return a / float(b) + +def add(a, b): + return a + b + + +def sub(a, b): + return a - b + + +def mul(a, b): + return a * b + + +def idiv(a, b): + return a / b + + +def div(a, b): + return a / float(b) + def safediv(a, b, default=INF): try: @@ -897,32 +892,35 @@ def safediv(a, b, default=INF): except ZeroDivisionError: return default + def safelog(x, base=math.e, default=-INF): try: return math.log(x, base) except (OverflowError, ValueError): return default - -def invcmp(a, b): return cmp(b, a) + + +def invcmp(a, b): + return cmp(b, a) + def clamp(x, low=None, high=None): """Clamps a value 'x' between the values 'low' and 'high' If low == None, then there is no lower bound If high == None, then there is no upper bound """ - if high is not None and x > high: return high elif low is not None and x < low: return low else: - return x + return x + def clampfunc(low=None, high=None): return lambda x: clamp(x, low, high) - def compose2(f, g): """ Compose two functions into one @@ -930,15 +928,15 @@ def compose2(f, g): compose2(f, g)(x) <==> f(g(x)) """ return lambda *args, **kargs: f(g(*args, **kargs)) - + def compose(*funcs): - """Composes two or more functions into one function - - example: - compose(f,g,h,i)(x) <==> f(g(h(i(x)))) """ + Composes two or more functions into one function. 
+ Example: + compose(f,g,h,i)(x) <==> f(g(h(i(x)))) + """ funcs = reversed(funcs) f = funcs.next() for g in funcs: @@ -949,8 +947,8 @@ def compose(*funcs): def overlap(a, b, x, y, inc=True): """ Returns True if range [a,b] overlaps [x,y] - - inc -- if True, treat [a,b] and [x,y] as inclusive + + inc: if True, treat [a,b] and [x,y] as inclusive """ if inc: return (y >= a) and (x <= b) @@ -958,26 +956,22 @@ def overlap(a, b, x, y, inc=True): return (y > a) and (x < b) - #============================================================================= # regex -# def match(pattern, text): """ A quick way to do pattern matching. - - remember: to name tokens use (?Ppattern) + + NOTE: name tokens using (?Ppattern) """ - m = re.match(pattern, text) - if m is None: return {} else: return m.groupdict() - + def evalstr(text): """Replace expressions in a string (aka string interpolation) @@ -985,28 +979,28 @@ def evalstr(text): >>> name = 'Matt' >>> evalstr("My name is ${name} and my age is ${12+12}") 'My name is Matt and my age is 24' - + "${!expr}" expands to "${expr}" - + """ - + # get environment of caller frame = sys._getframe(1) global_dict = frame.f_globals local_dict = frame.f_locals - + # find all expression to replace m = re.finditer("\$\{(?P[^\}]*)\}", text) - + # build new string try: strs = [] last = 0 for x in m: expr = x.groupdict()['expr'] - - strs.append(text[last:x.start()]) - + + strs.append(text[last:x.start()]) + if expr.startswith("!"): strs.append("${" + expr[1:] + "}") else: @@ -1015,89 +1009,61 @@ def evalstr(text): strs.append(text[last:len(text)]) except Exception, e: raise Exception("evalstr: " + str(e)) - + return "".join(strs) #============================================================================= # common Input/Output + def read_ints(filename): - """Read a list of integers from a file (one int per line) - - filename may also be a stream - """ - + """Read a list of integers from a file (one int per line).""" infile = open_stream(filename) - vec = [] - for line in infile: - vec.append(int(line)) - return vec - + return [int(line) for line in infile] def read_floats(filename): - """Read a list of floats from a file (one float per line) - - filename may also be a stream - """ + """Read a list of floats from a file (one float per line).""" infile = open_stream(filename) - vec = [] - for line in infile: - vec.append(float(line)) - return vec + return [float(line) for line in infile] def read_strings(filename): - """Read a list of strings from a file (one string per line) - - filename may also be a stream - """ + """Read a list of strings from a file (one string per line).""" infile = open_stream(filename) - vec = [line.rstrip("\n") for line in infile] - return vec + return [line.rstrip("\n") for line in infile] -def read_dict(filename, delim="\t", keytype=str, valtype=str): - """Read a dict from a file - - filename may also be a stream - """ - +def read_dict(filename, delim="\t", key=str, val=str): + """Read a dict from a file.""" infile = open_stream(filename) dct = {} - + for line in infile: - tokens = line.rstrip("\n").split(delim) - assert len(tokens) >= 2, line - dct[keytype(tokens[0])] = valtype(tokens[1]) - + tokens = line.rstrip("\n").split(delim, 1) + assert len(tokens) == 2, line + dct[key(tokens[0])] = val(tokens[1]) + return dct def write_list(filename, lst): - """Write a list of anything (ints, floats, strings, etc) to a file. 
- - filename may also be a stream - """ + """Write a list of anything (ints, floats, strings, etc) to a file.""" out = open_stream(filename, "w") for i in lst: print >>out, i def write_dict(filename, dct, delim="\t"): - """Write a dictionary to a file - - filename may also be a stream - """ - + """Write a dictionary to a file.""" out = open_stream(filename, "w") for k, v in dct.iteritems(): out.write("%s%s%s\n" % (str(k), delim, str(v))) - class IgnoreCloseFile (object): + """Wrap a stream such that close() is ignored.""" def __init__(self, stream): self.__stream = stream @@ -1112,22 +1078,18 @@ def close(self): pass - -def open_stream(filename, mode = "r", ignore_close=True): +def open_stream(filename, mode="r", ignore_close=True): """Returns a file stream depending on the type of 'filename' and 'mode' - - The following types for 'filename' are handled: - - stream - returns 'filename' unchanged - iterator - returns 'filename' unchanged - URL string - opens http pipe - '-' - opens stdin or stdout, depending on 'mode' - other string - opens file with name 'filename' - - mode is standard mode for file(): r,w,a,b - ignore_close -- if True and filename is a stream, then close() calls on - the returned stream will be ignored. + filename: the following types for 'filename' are handled: + stream - returns 'filename' unchanged + iterator - returns 'filename' unchanged + URL string - opens http pipe + '-' - opens stdin or stdout, depending on 'mode' + other string - opens file with name 'filename' + mode: standard mode for file(): r,w,a,b + ignore_close: if True and filename is a stream, then close() calls on + the returned stream will be ignored. """ is_stream = False @@ -1136,19 +1098,19 @@ def open_stream(filename, mode = "r", ignore_close=True): if hasattr(filename, "read") or hasattr(filename, "write"): stream = filename is_stream = True - + # if mode is reading and filename is an iterator elif "r" in mode and hasattr(filename, "next"): stream = filename is_stream = True - + # if filename is a string then open it elif isinstance(filename, basestring): # open URLs if filename.startswith("http://"): import urllib2 stream = urllib2.urlopen(filename) - + # open stdin and stdout elif filename == "-": if "w" in mode: @@ -1159,11 +1121,11 @@ def open_stream(filename, mode = "r", ignore_close=True): is_stream = True else: raise Exception("stream '-' can only be opened with modes r/w") - + # open regular file else: stream = open(filename, mode) - + # cannot handle other types for filename else: raise Exception("unknown filename type '%s'" % type(filename)) @@ -1174,32 +1136,29 @@ def open_stream(filename, mode = "r", ignore_close=True): return stream - #============================================================================= # Delimited files -# class DelimReader: """Reads delimited files""" def __init__(self, filename, delim="\t", types=None, parse=False): - """Constructor for DelimReader - - arguments: - filename -- filename or stream to read from - delim -- delimiting character - types -- types of columns - pars -- if True, fields are automatically parsed """ - + Constructor for DelimReader. 
+ + filename: filename or stream to read from + delim: delimiting character + types: types of columns + pars: if True, fields are automatically parsed + """ self.infile = open_stream(filename) self.delim = delim self.types = types self.parse = parse - + def __iter__(self): return self - + def next(self): line = self.infile.next() row = line.rstrip("\n").split(self.delim) @@ -1220,9 +1179,10 @@ def iter_delim(filename, delim="\t", types=None, parse=False): """Iterate through a tab delimited file""" return DelimReader(filename, delim, types, parse) + def write_delim(filename, data, delim="\t"): """Write a 2D list into a file using a delimiter""" - + out = open_stream(filename, "w") for row in data: out.write(delim.join(str(x) for x in row)) @@ -1231,8 +1191,9 @@ def write_delim(filename, data, delim="\t"): def guess_type(text): - """Guesses the type of a value encoded in a string""" - + """ + Guess the type of a value encoded in a string. + """ # int try: int(text) @@ -1252,8 +1213,9 @@ def guess_type(text): def autoparse(text): - """Guesses the type of a value encoded in a string and parses""" - + """ + Guesse the type of a value encoded in a string and parses + """ # int try: return int(text) @@ -1268,24 +1230,20 @@ def autoparse(text): # string return text - #============================================================================= -# printing functions -# +# Printing functions def default_justify(val): - if isinstance(val, int) or \ - isinstance(val, float): + if isinstance(val, (int, float)): return "right" else: return "left" def default_format(val): - if isinstance(val, int) and \ - not isinstance(val, bool): + if isinstance(val, int) and not isinstance(val, bool): return int2pretty(val) elif isinstance(val, float): if abs(val) < 1e-4: @@ -1296,81 +1254,79 @@ def default_format(val): return str(val) -def printcols(data, width=None, spacing=1, format=default_format, +def printcols(data, width=None, spacing=1, format=default_format, justify=default_justify, out=sys.stdout, colwidth=INF, overflow="!"): - """Prints a list or matrix in aligned columns - - data - a list or matrix - width - maxium number of characters per line (default: 75 for lists) - spacing - number of spaces between columns (default: 1) - out - stream to print to (default: sys.stdout) - """ - + """ + Print a list or matrix in aligned columns. 
+ + data: a list or matrix + width: maxium number of characters per line (default: 75 for lists) + spacing: number of spaces between columns (default: 1) + out: stream to print to (default: sys.stdout) + """ if len(data) == 0: return - + if isinstance(data[0], (list, tuple)): # matrix printing has default width of unlimited if width is None: width = 100000 - + mat = data else: # list printing has default width 75 if width is None: width = 75 - + ncols = int(width / (max(map(lambda x: len(format(x)), data))+spacing)) mat = list2matrix(data, ncols=ncols, bycols=True) - - + # turn all entries into strings matstr = map2(format, mat) - + # overflow for row in matstr: for j in xrange(len(row)): if len(row[j]) > colwidth: row[j] = row[j][:colwidth-len(overflow)] + overflow - + # ensure every row has same number of columns maxcols = max(map(len, matstr)) for row in matstr: if len(row) < maxcols: row.extend([""] * (maxcols - len(row))) - - + # find the maximum width char in each column maxwidths = map(max, map2(len, zip(* matstr))) - - + # print out matrix with whitespace padding for i in xrange(len(mat)): fields = [] for j in xrange(len(mat[i])): just = justify(mat[i][j]) - + if just == "right": - fields.append((" " * (maxwidths[j] - len(matstr[i][j]))) + \ - matstr[i][j] + \ + fields.append((" " * (maxwidths[j] - len(matstr[i][j]))) + + matstr[i][j] + (" " * spacing)) else: - # do left by default - fields.append(matstr[i][j] + - (" " * (maxwidths[j] - len(matstr[i][j]) + spacing))) + # do left by default + fields.append( + matstr[i][j] + + (" " * (maxwidths[j] - len(matstr[i][j]) + spacing))) out.write("".join(fields)[:width] + "\n") def list2matrix(lst, nrows=None, ncols=None, bycols=True): - """Turn a list into a matrix by wrapping its entries""" - + """Turn a list into a matrix by wrapping its entries.""" + mat = [] - - if nrows == None and ncols == None: + + if nrows is None and ncols is None: nrows = int(math.sqrt(len(lst))) ncols = int(math.ceil(len(lst) / float(nrows))) - elif nrows == None: + elif nrows is None: nrows = int(math.ceil(len(lst) / float(min(ncols, len(lst))))) else: ncols = int(math.ceil(len(lst) / float(min(nrows, len(lst))))) @@ -1384,17 +1340,17 @@ def list2matrix(lst, nrows=None, ncols=None, bycols=True): k = i*ncols + j if k < len(lst): mat[-1].append(lst[k]) - + return mat def printwrap(text, width=80, prefix="", out=sys.stdout): - """Prints text with wrapping""" - if width == None: + """Print text with wrapping.""" + if width is None: out.write(text) out.write("\n") return - + pos = 0 while pos < len(text): out.write(prefix) @@ -1403,32 +1359,33 @@ def printwrap(text, width=80, prefix="", out=sys.stdout): pos += width - def int2pretty(num): - """Returns a pretty-printed version of an int""" - - string = str(num) + """Return a pretty-printed version of an int.""" + string = str(abs(num)) parts = [] l = len(string) for i in xrange(0, l, 3): t = l - i s = t - 3 - if s < 0: s = 0 + if s < 0: + s = 0 parts.append(string[s:t]) parts.reverse() - return ",".join(parts) + if num < 0: + return "-" + ",".join(parts) + else: + return ",".join(parts) def pretty2int(string): - """Parses a pretty-printed version of an int into an int""" + """Parse a pretty-printed version of an int into an int.""" return int(string.replace(",", "")) def str2bool(val): - """Correctly converts the strings "True" and "False" to the - booleans True and False """ - + Convert the strings "True" and "False" to the booleans True and False. 
+ """ if val == "True": return True elif val == "False": @@ -1437,17 +1394,16 @@ def str2bool(val): raise Exception("unknown string for bool '%s'" % val) - def print_dict(dic, key=lambda x: x, val=lambda x: x, - num=None, cmp=cmp, order=None, reverse=False, - spacing=4, out=sys.stdout, - format=default_format, - justify=default_justify): - """Prints a dictionary in two columns""" - - if num == None: + num=None, cmp=cmp, order=None, reverse=False, + spacing=4, out=sys.stdout, + format=default_format, + justify=default_justify): + """Print a dictionary in two columns.""" + + if num is None: num = len(dic) - + dic = mapdict(dic, key=key, val=val) items = dic.items() @@ -1455,8 +1411,8 @@ def print_dict(dic, key=lambda x: x, val=lambda x: x, items.sort(key=order, reverse=reverse) else: items.sort(cmp, reverse=reverse) - - printcols(items[:num], spacing=spacing, out=out, format=format, + + printcols(items[:num], spacing=spacing, out=out, format=format, justify=justify) @@ -1464,12 +1420,11 @@ def print_row(*args, **kargs): """ Prints a delimited row of values - out -- output stream (default: sys.stdout) - delim -- delimiter (default: '\t') - newline -- newline character (default: '\n') - format -- formatting function (default: str) + out: output stream (default: sys.stdout) + delim: delimiter (default: '\t') + newline: newline character (default: '\n') + format: formatting function (default: str) """ - out = kargs.get("out", sys.stdout) delim = kargs.get("delim", "\t") newline = kargs.get("newline", "\n") @@ -1479,11 +1434,11 @@ def print_row(*args, **kargs): #============================================================================= # Parsing -# -def read_word(infile, delims = [" ", "\t", "\n"]): + +def read_word(infile, delims=" \t\n"): word = "" - + while True: char = infile.read(1) if char == "": @@ -1491,13 +1446,14 @@ def read_word(infile, delims = [" ", "\t", "\n"]): if char not in delims: word += char break - + while True: char = infile.read(1) if char == "" or char in delims: return word word += char + def read_until(stream, chars): token = "" while True: @@ -1506,6 +1462,7 @@ def read_until(stream, chars): return token, char token += char + def read_while(stream, chars): token = "" while True: @@ -1514,6 +1471,7 @@ def read_while(stream, chars): return token, char token += char + def skip_comments(infile): for line in infile: if line.startswith("#") or line.startswith("\n"): @@ -1521,36 +1479,35 @@ def skip_comments(infile): yield line - class IndentStream: """ Makes any stream into an indent stream. 
- + Indent stream auto indents every line written to it """ - + def __init__(self, stream): self.stream = open_stream(stream, "w") self.linestart = True self.depth = 0 - + def indent(self, num=2): self.depth += num - + def dedent(self, num=2): self.depth -= num if self.depth < 0: self.depth = 0 - + def write(self, text): lines = text.split("\n") - + for line in lines[:-1]: if self.linestart: self.stream.write(" "*self.depth) self.linestart = True self.stream.write(line + "\n") - + if len(lines) > 0: if text.endswith("\n"): self.linestart = True @@ -1559,17 +1516,13 @@ def write(self, text): self.linestart = False - - - - #============================================================================= # file/directory functions def list_files(path, ext=""): """Returns a list of files in 'path' ending with 'ext'""" - + files = filter(lambda x: x.endswith(ext), os.listdir(path)) files.sort() return [os.path.join(path, x) for x in files] @@ -1582,57 +1535,58 @@ def tempfile(path, prefix, ext): fd, filename = temporaryfile.mkstemp(ext, prefix) os.close(fd) """ - + import warnings warnings.filterwarnings("ignore", ".*", RuntimeWarning) - filename = os.tempnam(path, "____") + filename = os.tempnam(path, "____") filename = filename.replace("____", prefix) + ext warnings.filterwarnings("default", ".*", RuntimeWarning) - + return filename def deldir(path): """Recursively remove a directory""" - - # This function is slightly more complicated because of a + + # This function is slightly more complicated because of a # strange behavior in AFS, that creates .__afsXXXXX files - + dirs = [] - + def cleandir(arg, path, names): for name in names: filename = os.path.join(path, name) if os.path.isfile(filename): os.remove(filename) dirs.append(path) - + # remove files os.path.walk(path, cleandir, "") - + # remove directories for i in xrange(len(dirs)): # AFS work around afsFiles = list_files(dirs[-i]) for f in afsFiles: os.remove(f) - + while True: try: if os.path.exists(dirs[-i]): os.rmdir(dirs[-i]) - except Exception, e: + except Exception: continue break def replace_ext(filename, oldext, newext): """Safely replaces a file extension new a new one""" - + if filename.endswith(oldext): return filename[:-len(oldext)] + newext else: - raise Exception("file '%s' does not have extension '%s'" % (filename, oldext)) + raise Exception("file '%s' does not have extension '%s'" % + (filename, oldext)) def makedirs(filename): @@ -1640,25 +1594,23 @@ def makedirs(filename): Makes a path of directories. 
Does not fail if filename already exists """ - if not os.path.isdir(filename): os.makedirs(filename) #============================================================================= # sorting -# def sortindex(lst, cmp=cmp, key=None, reverse=False): """Returns the sorted indices of items in lst""" ind = range(len(lst)) - + if key is None: compare = lambda a, b: cmp(lst[a], lst[b]) else: compare = lambda a, b: cmp(key(lst[a]), key(lst[b])) - + ind.sort(compare, reverse=reverse) return ind @@ -1666,25 +1618,25 @@ def sortindex(lst, cmp=cmp, key=None, reverse=False): def sortranks(lst, cmp=cmp, key=None, reverse=False): """Returns the ranks of items in lst""" return invperm(sortindex(lst, cmp, key, reverse)) - + def sort_many(lst, *others, **args): """Sort several lists based on the sorting of 'lst'""" args.setdefault("reverse", False) - if "key" in args: + if "key" in args: ind = sortindex(lst, key=args["key"], reverse=args["reverse"]) elif "cmp" in args: ind = sortindex(lst, cmp=args["cmp"], reverse=args["reverse"]) else: ind = sortindex(lst, reverse=args["reverse"]) - + lsts = [mget(lst, ind)] - + for other in others: lsts.append(mget(other, ind)) - + return lsts @@ -1695,11 +1647,9 @@ def invperm(perm): inv[perm[i]] = i return inv - #============================================================================= # histograms, distributions -# def one_norm(vals): """Normalize values so that they sum to 1""" @@ -1708,21 +1658,21 @@ def one_norm(vals): def bucket_size(array, ndivs=None, low=None, width=None): - """Determine the bucket size needed to divide the values in array into + """Determine the bucket size needed to divide the values in array into 'ndivs' evenly sized buckets""" - + if low is None: low = min(array) - + if ndivs is None: if width is None: ndivs = 20 else: ndivs = int(math.ceil(max((max(array) - low) / float(width), 1))) - + if width is None: width = (max(array) - low) / float(ndivs) - + return ndivs, low, width @@ -1730,7 +1680,6 @@ def bucket_bin(item, ndivs, low, width): """ Return the bin for an item """ - assert item >= low, Exception("negative bucket index") return min(int((item - low) / width), ndivs-1) @@ -1742,11 +1691,11 @@ def bucket(array, ndivs=None, low=None, width=None, key=lambda x: x): # set bucket sizes ndivs, low, width = bucket_size(keys, ndivs, low, width) - + # init histogram h = [[] for i in xrange(ndivs)] x = [] - + # bin items for i in array: if i >= low: @@ -1758,14 +1707,14 @@ def bucket(array, ndivs=None, low=None, width=None, key=lambda x: x): def hist(array, ndivs=None, low=None, width=None): """Create a histogram of 'array' with 'ndivs' buckets""" - + # set bucket sizes ndivs, low, width = bucket_size(array, ndivs, low, width) - + # init histogram h = [0] * ndivs x = [] - + # count items for i in array: if i >= low: @@ -1777,66 +1726,63 @@ def hist(array, ndivs=None, low=None, width=None): return (x, h) -def hist2(array1, array2, +def hist2(array1, array2, ndivs1=None, ndivs2=None, low1=None, low2=None, width1=None, width2=None): """Perform a 2D histogram""" - - + # set bucket sizes ndivs1, low1, width1 = bucket_size(array1, ndivs1, low1, width1) ndivs2, low2, width2 = bucket_size(array2, ndivs2, low2, width2) - + # init histogram h = [[0] * ndivs1 for i in xrange(ndivs2)] labels = [] - - for j,i in zip(array1, array2): + + for j, i in zip(array1, array2): if j > low1 and i > low2: - h[bucket_bin(i, ndivs2, low2, width2)] \ - [bucket_bin(j, ndivs1, low1, width1)] += 1 - + index1 = bucket_bin(i, ndivs2, low2, width2) + index2 = bucket_bin(j, 
ndivs1, low1, width1) + h[index1][index2] += 1 + for i in range(ndivs2): labels.append([]) - for j in range(ndivs1): + for j in range(ndivs1): labels[-1].append([j * width1 + low1, i * width2 + low2]) return labels, h - + def histbins(bins): """Adjust the bins from starts to centers, this is useful for plotting""" - + bins2 = [] - + if len(bins) == 1: bins2 = [bins[0]] else: for i in range(len(bins) - 1): bins2.append((bins[i] + bins[i+1]) / 2.0) bins2.append(bins[-1] + (bins[-1] - bins[-2]) / 2.0) - + return bins2 - + def distrib(array, ndivs=None, low=None, width=None): """Find the distribution of 'array' using 'ndivs' buckets""" - + # set bucket sizes ndivs, low, width = bucket_size(array, ndivs, low, width) - + h = hist(array, ndivs, low, width) - area = 0 - total = float(sum(h[1])) return (h[0], map(lambda x: (x/total)/width, h[1])) def hist_int(array): - """Returns a histogram of integers as a list of counts""" - - hist = [0] * (max(array) + 1) + """Return a histogram of integers as a list of counts.""" + hist = [0] * (max(array) + 1) negative = [] for i in array: if (i >= 0): @@ -1847,12 +1793,13 @@ def hist_int(array): def hist_dict(array): - """Returns a histogram of any items as a dict. - - The keys of the returned dict are elements of 'array' and the values - are the counts of each element in 'array'. """ - + Return a histogram of any items as a dict. + + The keys of the returned dict are elements of 'array' and the values + are the counts of each element in 'array'. + """ + hist = {} for i in array: if i in hist: @@ -1863,25 +1810,24 @@ def hist_dict(array): def print_hist(array, ndivs=20, low=None, width=None, - cols=75, spacing=2, out=sys.stdout): + cols=75, spacing=2, out=sys.stdout): data = list(hist(array, ndivs, low=low, width=width)) - + # find max bar maxwidths = map(max, map2(compose(len, str), data)) - maxbar = cols- sum(maxwidths) - 2 * spacing - + maxbar = cols - sum(maxwidths) - 2 * spacing + # make bars bars = [] maxcount = max(data[1]) for count in data[1]: bars.append("*" * int(count * maxbar / float(maxcount))) data.append(bars) - - printcols(zip(* data), spacing=spacing, out=out) + printcols(zip(* data), spacing=spacing, out=out) -# import common functions from other files, +# import common functions from other files, # so that only util needs to be included try: @@ -1907,4 +1853,3 @@ def print_hist(array, ndivs=20, low=None, width=None, from plotting import * except ImportError: pass -
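
Illustrative usage sketch: the snippet below exercises a few of the updated rasmus.util helpers touched by this patch — groupby, printcols, and print_hist — using the signatures as they appear in the diff above. The import path (argweaver.deps.rasmus) and the sample data are assumptions for illustration only; the code uses Python 2 syntax to match the library.

    import random

    # assumed import path, based on where this patch places the bundled copy
    from argweaver.deps.rasmus import util

    # groupby(func, lst) buckets items by func(item) and returns a dict of lists
    words = ["apple", "avocado", "banana", "cherry", "cacao"]
    groups = util.groupby(lambda w: w[0], words)

    # printcols() prints a list of rows in aligned columns
    util.printcols([(letter, ", ".join(items))
                    for letter, items in sorted(groups.items())])

    # print_hist() draws a text histogram; here over 1000 normal draws
    vals = [random.gauss(0.0, 1.0) for _ in xrange(1000)]
    util.print_hist(vals, ndivs=10)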