Skip to content

Commit

Permalink
Fix problem where tree search is truncated incorrectly. (#244)
Browse files Browse the repository at this point in the history
* add (failing) test for --best-only and similarity calculation
* some simple Python API improvements + tests
* support MinHash(..., scaled=val)
* pass tests with new scaled convenience functions
* reorder scaled args
* remove unnecessary imports, refactor search a bit
* fix SBT internal nodes to have metadata; added max_n_below to metadata
* move search mechanics into sourmash_lib.search
* move gather code into sourmash_lib.search
* ugly fix to get --output-unassigned working!
* add in --traverse-directory to search and gather, along with tests
* add -U to pip install
* add --randomize to sourmash compute
* add a nicer sbt API
* Fixed a bug when loading minhashes with track_abundance, and added a bunch of tests for pickling
* Fix division problem
  • Loading branch information
ctb authored and luizirber committed Oct 25, 2017
1 parent 2b14671 commit 2c2552a
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 183 deletions.
195 changes: 33 additions & 162 deletions sourmash_lib/commands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import print_function
from __future__ import print_function, division

import argparse
import csv
Expand Down Expand Up @@ -580,7 +580,8 @@ def sbt_combine(args):


def index(args):
import sourmash_lib.sbt
import sourmash_lib.sbtmh

parser = argparse.ArgumentParser()
parser.add_argument('sbt_name', help='name to save SBT into')
parser.add_argument('signatures', nargs='+',
Expand Down Expand Up @@ -649,12 +650,14 @@ def index(args):


def search(args):
from sourmash_lib.sbtmh import search_minhashes, SearchMinHashesFindBest
from .search import search_databases

parser = argparse.ArgumentParser()
parser.add_argument('query', help='query signature')
parser.add_argument('databases', help='signatures/SBTs to search',
nargs='+')
parser.add_argument('--traverse-directory', action='store_true',
help='search all signatures underneath directories.')
parser.add_argument('-q', '--quiet', action='store_true',
help='suppress non-error output')
parser.add_argument('--threshold', default=0.08, type=float,
Expand Down Expand Up @@ -699,66 +702,24 @@ def search(args):
query.minhash.scaled, int(args.scaled))
query.minhash = query.minhash.downsample_scaled(args.scaled)

# set up the search function(s)
search_fn = search_minhashes

# similarity vs containment
query_similarity = lambda x: query.similarity(x, downsample=True)
if args.containment:
query_similarity = lambda x: query.contained_by(x, downsample=True)

# set up the search databases
databases = sourmash_args.load_sbts_and_sigs(args.databases,
query_ksize, query_moltype)
query_ksize, query_moltype,
args.traverse_directory)

if not len(databases):
error('Nothing found to search!')
sys.exit(-1)

# collect results across all the trees
SearchResult = namedtuple('SearchResult',
'similarity, match_sig, md5, filename, name')
results = []
found_md5 = set()
for (sbt_or_siglist, filename, is_sbt) in databases:
if args.best_only:
search_fn = sourmash_lib.sbtmh.SearchMinHashesFindBest().search

if is_sbt:
tree = sbt_or_siglist
notify('Searching SBT {}', filename)
for leaf in tree.find(search_fn, query, args.threshold):
similarity = query_similarity(leaf.data)
if similarity >= args.threshold and \
leaf.data.md5sum() not in found_md5:
sr = SearchResult(similarity=similarity,
match_sig=leaf.data,
md5=leaf.data.md5sum(),
filename=filename,
name=leaf.data.name())
found_md5.add(sr.md5)
results.append(sr)

else: # list of signatures
for ss in sbt_or_siglist:
similarity = query_similarity(ss)
if similarity >= args.threshold and \
ss.md5sum() not in found_md5:
sr = SearchResult(similarity=similarity,
match_sig=ss,
md5=ss.md5sum(),
filename=filename,
name=ss.name())
found_md5.add(sr.md5)
results.append(sr)

# sort results on similarity (reverse)
results.sort(key=lambda x: -x.similarity)
# do the actual search
results = search_databases(query, databases,
args.threshold, args.containment,
args.best_only)

n_matches = len(results)
if args.best_only:
notify("(truncated search because of --best-only; only trust top result")
args.num_results = 1

n_matches = len(results)
if n_matches <= args.num_results:
print_results('{} matches:'.format(len(results)))
else:
Expand All @@ -774,6 +735,9 @@ def search(args):
name = sr.match_sig._display_name(60)
print_results('{:>6} {}', pct, name)

if args.best_only:
notify("** reporting only one match because --best-only was set")

if args.output:
fieldnames = ['similarity', 'name', 'filename', 'md5']
w = csv.DictWriter(args.output, fieldnames=fieldnames)
Expand Down Expand Up @@ -867,10 +831,14 @@ def categorize(args):


def gather(args):
from .search import gather_databases

parser = argparse.ArgumentParser()
parser.add_argument('query', help='query signature')
parser.add_argument('databases', help='signatures/SBTs to search',
nargs='+')
parser.add_argument('--traverse-directory', action='store_true',
help='search all signatures underneath directories.')
parser.add_argument('-o', '--output', type=argparse.FileType('wt'),
help='output CSV containing matches to this file')
parser.add_argument('--save-matches', type=argparse.FileType('wt'),
Expand Down Expand Up @@ -919,54 +887,14 @@ def gather(args):

# set up the search databases
databases = sourmash_args.load_sbts_and_sigs(args.databases,
query_ksize, query_moltype)
query_ksize, query_moltype,
args.traverse_directory)

if not len(databases):
error('Nothing found to search!')
sys.exit(-1)

orig_query = query
orig_mins = orig_query.minhash.get_hashes()

# calculate the band size/resolution R for the genome
R_metagenome = orig_query.minhash.scaled

# define a function to do a 'best' search and get only top match.
def find_best(dblist, query):
results = []
for (sbt_or_siglist, filename, is_sbt) in dblist:
search_fn = sourmash_lib.sbtmh.SearchMinHashesFindBestIgnoreMaxHash().search

if is_sbt:
tree = sbt_or_siglist

for leaf in tree.find(search_fn, query, 0.0):
leaf_e = leaf.data.minhash
similarity = query.minhash.similarity_ignore_maxhash(leaf_e)
if similarity > 0.0:
results.append((similarity, leaf.data))
else:
for ss in sbt_or_siglist:
similarity = query.minhash.similarity_ignore_maxhash(ss.minhash)
if similarity > 0.0:
results.append((similarity, ss))

if not results:
return None, None, None

# take the best result
results.sort(key=lambda x: -x[0]) # reverse sort on similarity
best_similarity, best_leaf = results[0]
return best_similarity, best_leaf, filename


# define a function to build new signature object from set of mins
def build_new_signature(mins, template_sig):
e = template_sig.minhash.copy_and_clear()
e.add_many(mins)
return sig.SourmashSignature(e)

# xxx
# pretty-printing code.
def format_bp(bp):
bp = float(bp)
if bp < 500:
Expand All @@ -979,76 +907,21 @@ def format_bp(bp):
return '{:.1f} Gbp'.format(round(bp / 1e9, 1))
return '???'

# construct a new query that doesn't have the max_hash attribute set.
new_mins = query.minhash.get_hashes()
query = build_new_signature(new_mins, orig_query)

sum_found = 0.
found = []
GatherResult = namedtuple('GatherResult',
'intersect_bp, f_orig_query, f_match, f_unique_to_query, filename, name, md5, leaf')
while 1:
best_similarity, best_leaf, filename = find_best(databases, query)
if not best_leaf: # no matches at all!
break

# subtract found hashes from search hashes, construct new search
query_mins = set(query.minhash.get_hashes())
found_mins = best_leaf.minhash.get_hashes()

# figure out what the resolution of the banding on the genome is,
# based either on an explicit --scaled parameter, or on genome
# cardinality (deprecated)
if not best_leaf.minhash.max_hash:
error('Best hash match in sbt_gather has no max_hash')
error('Please prepare database of sequences with --scaled')
sys.exit(-1)

R_genome = best_leaf.minhash.scaled

# pick the highest R / lowest resolution
R_comparison = max(R_metagenome, R_genome)

# CTB: these could probably be replaced by minhash.downsample_scaled.
new_max_hash = sourmash_lib.MAX_HASH / float(R_comparison)
query_mins = set([ i for i in query_mins if i < new_max_hash ])
found_mins = set([ i for i in found_mins if i < new_max_hash ])
orig_mins = set([ i for i in orig_mins if i < new_max_hash ])

# calculate intersection:
intersect_mins = query_mins.intersection(found_mins)
intersect_orig_mins = orig_mins.intersection(found_mins)
intersect_bp = R_comparison * len(intersect_orig_mins)
sum_found += len(intersect_mins)

if intersect_bp < args.threshold_bp: # hard cutoff for now
notify('found less than {} in common. => exiting',
format_bp(intersect_bp))
break

# calculate fractions wrt first denominator - genome size
genome_n_mins = len(found_mins)
f_match = len(intersect_mins) / float(genome_n_mins)
f_orig_query = len(intersect_orig_mins) / float(len(orig_mins))
sum_found = 0
for result, n_intersect_mins, new_max_hash, next_query in gather_databases(query, databases,
args.threshold_bp):
# print interim result & save in a list for later use
pct_query = '{:.1f}%'.format(result.f_orig_query*100)
pct_genome = '{:.1f}%'.format(result.f_match*100)

# calculate fractions wrt second denominator - metagenome size
query_n_mins = len(orig_query.minhash.get_hashes())
f_unique_to_query = len(intersect_mins) / float(query_n_mins)
name = result.leaf._display_name(40)

if not len(found): # first result? print header.
print_results("")
print_results("overlap p_query p_match ")
print_results("--------- ------- --------")

result = GatherResult(intersect_bp=intersect_bp,
f_orig_query=f_orig_query,
f_match=f_match,
f_unique_to_query=f_unique_to_query,
filename=filename,
md5=best_leaf.md5sum(),
name=best_leaf.name(),
leaf=best_leaf)

# print interim result & save in a list for later use
pct_query = '{:.1f}%'.format(result.f_orig_query*100)
pct_genome = '{:.1f}%'.format(result.f_match*100)
Expand All @@ -1058,16 +931,14 @@ def format_bp(bp):
print_results('{:9} {:>6} {:>6} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
name)
sum_found += n_intersect_mins
found.append(result)

# construct a new query, minus the previous one.
query_mins -= set(found_mins)
query = build_new_signature(query_mins, orig_query)

# basic reporting
print_results('\nfound {} matches total;', len(found))

sum_found /= len(orig_query.minhash.get_hashes())
sum_found /= len(query.minhash.get_hashes())
print_results('the recovered matches hit {:.1f}% of the query',
sum_found * 100)
print_results('')
Expand Down
12 changes: 6 additions & 6 deletions sourmash_lib/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
"""
A trial implementation of sequence bloom trees, Solomon & Kingsford, 2015.
This is a simple in-memory version where all of the graphs are in
memory at once; to move it onto disk, the graphs would need to be
dynamically loaded for each query.
To try it out, do::
factory = GraphFactory(ksize, tablesizes, n_tables)
Expand Down Expand Up @@ -188,8 +184,7 @@ def save(self, tag):
'.'.join([basetag, basename, 'sbt'])),
'name': node.name
}
if isinstance(node, Leaf):
data['metadata'] = node.metadata
data['metadata'] = node.metadata

node.save(os.path.join(dirprefix, data['filename']))
structure[i] = data
Expand Down Expand Up @@ -381,6 +376,7 @@ def __init__(self, factory, name=None, fullpath=None):
self._factory = factory
self._data = None
self._filename = fullpath
self.metadata = dict()

def __str__(self):
return '*Node:{name} [occupied: {nb}, fpr: {fpr:.2}]'.format(
Expand All @@ -407,10 +403,14 @@ def data(self, new_data):
def load(info, dirname):
    """Rebuild an internal SBT node from its saved JSON description.

    'info' carries the node's factory, name, filename and (optionally)
    its metadata; 'dirname' is the directory the SBT was saved under.
    """
    path = os.path.join(dirname, info['filename'])
    node = Node(info['factory'], name=info['name'], fullpath=path)
    # SBTs saved before internal-node metadata existed lack this key;
    # fall back to an empty dict so attribute access stays safe.
    node.metadata = info.get('metadata', {})
    return node

def update(self, parent):
    """Propagate this node's content up into its parent.

    Merges this node's filter data into the parent's, and keeps the
    parent's 'max_n_below' metadata as the maximum observed among its
    children.
    """
    parent.data.update(self.data)
    # Default BOTH lookups to 0: a node without the 'max_n_below' key
    # (e.g. loaded from an older SBT with no internal-node metadata)
    # would otherwise contribute None, and max(0, None) raises
    # TypeError on Python 3.
    max_n_below = max(parent.metadata.get('max_n_below', 0),
                      self.metadata.get('max_n_below', 0))
    parent.metadata['max_n_below'] = max_n_below


class Leaf(object):
Expand Down
Loading

0 comments on commit 2c2552a

Please sign in to comment.