Skip to content

Commit

Permalink
Added ability to set CD-HIT memory limit on the command line - Issue s…
Browse files Browse the repository at this point in the history
  • Loading branch information
kpepper committed Mar 12, 2019
1 parent da79fad commit 70c82f5
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 9 deletions.
26 changes: 20 additions & 6 deletions ariba/cdhit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@ def __init__(
seq_identity_threshold=0.9,
threads=1,
length_diff_cutoff=0.0,
memory_limit=None,
verbose=False,
min_cluster_number=0
):

if not os.path.exists(infile):
raise Error('File not found: "' + infile + '". Cannot continue')

if (memory_limit is not None) and (memory_limit < 0):
raise Error('Input parameter cdhit_max_memory is set to an invalid value. Cannot continue')

self.infile = os.path.abspath(infile)
self.seq_identity_threshold = seq_identity_threshold
self.threads = threads
self.length_diff_cutoff = length_diff_cutoff
self.memory_limit = memory_limit
self.verbose = verbose
self.min_cluster_number = min_cluster_number
extern_progs = external_progs.ExternalProgs(fail_on_error=True, using_spades=False)
Expand Down Expand Up @@ -133,24 +138,33 @@ def _get_clusters_from_bak_file(filename, min_cluster_number=0):
return clusters


def run(self):
tmpdir = tempfile.mkdtemp(prefix='tmp.run_cd-hit.', dir=os.getcwd())
cdhit_fasta = os.path.join(tmpdir, 'cdhit')
cluster_info_outfile = cdhit_fasta + '.bak.clstr'

def get_run_cmd(self, output_file):
cmd = ' '.join([
self.cd_hit_est,
'-i', self.infile,
'-o', cdhit_fasta,
'-o', output_file,
'-c', str(self.seq_identity_threshold),
'-T', str(self.threads),
'-s', str(self.length_diff_cutoff),
'-d 0',
'-bak 1',
])

# Add in cdhit memory allocation if one has been specified
if self.memory_limit is not None:
cmd = ' '.join([cmd, '-M', str(self.memory_limit)])

return cmd


def run(self):
tmpdir = tempfile.mkdtemp(prefix='tmp.run_cd-hit.', dir=os.getcwd())
cdhit_fasta = os.path.join(tmpdir, 'cdhit')
cluster_info_outfile = cdhit_fasta + '.bak.clstr'
cmd = self.get_run_cmd(cdhit_fasta)
common.syscall(cmd, verbose=self.verbose)
clusters = self._get_clusters_from_bak_file(cluster_info_outfile, self.min_cluster_number)
common.rmtree(tmpdir)
return clusters


5 changes: 4 additions & 1 deletion ariba/ref_preparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self,
genetic_code=11,
cdhit_min_id=0.9,
cdhit_min_length=0.0,
cdhit_max_memory=None,
run_cdhit=True,
clusters_file=None,
threads=1,
Expand All @@ -40,6 +41,7 @@ def __init__(self,
self.genetic_code = genetic_code
self.cdhit_min_id = cdhit_min_id
self.cdhit_min_length = cdhit_min_length
self.cdhit_max_memory = cdhit_max_memory
self.run_cdhit = run_cdhit
self.clusters_file = clusters_file
self.threads = threads
Expand Down Expand Up @@ -193,6 +195,7 @@ def run(self, outdir):
seq_identity_threshold=self.cdhit_min_id,
threads=self.threads,
length_diff_cutoff=self.cdhit_min_length,
memory_limit=self.cdhit_max_memory,
nocluster=not self.run_cdhit,
verbose=self.verbose,
clusters_file=self.clusters_file,
Expand All @@ -214,4 +217,4 @@ def run(self, outdir):
print(' grep REMOVE', os.path.join(outdir, '01.filter.check_genes.log'), file=sys.stderr)

if number_of_bad_variants_logged > 0:
print('WARNING. Problem with at least one variant. Problem variants are rmoved. Please see the file', os.path.join(outdir, '01.filter.check_metadata.log'), 'for details.', file=sys.stderr)
print('WARNING. Problem with at least one variant. Problem variants are removed. Please see the file', os.path.join(outdir, '01.filter.check_metadata.log'), 'for details.', file=sys.stderr)
3 changes: 2 additions & 1 deletion ariba/reference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def write_cluster_allocation_file(clusters, outfile):
pyfastaq.utils.close(f_out)


def cluster_with_cdhit(self, outprefix, seq_identity_threshold=0.9, threads=1, length_diff_cutoff=0.0, nocluster=False, verbose=False, clusters_file=None):
def cluster_with_cdhit(self, outprefix, seq_identity_threshold=0.9, threads=1, length_diff_cutoff=0.0, memory_limit=None, nocluster=False, verbose=False, clusters_file=None):
clusters = {}
ReferenceData._write_sequences_to_files(self.sequences, self.metadata, outprefix)
ref_types = ('noncoding', 'noncoding.varonly', 'gene', 'gene.varonly')
Expand All @@ -454,6 +454,7 @@ def cluster_with_cdhit(self, outprefix, seq_identity_threshold=0.9, threads=1, l
seq_identity_threshold=seq_identity_threshold,
threads=threads,
length_diff_cutoff=length_diff_cutoff,
memory_limit=memory_limit,
verbose=verbose,
min_cluster_number = min_cluster_number,
)
Expand Down
1 change: 1 addition & 0 deletions ariba/tasks/prepareref.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def run(options):
genetic_code=options.genetic_code,
cdhit_min_id=options.cdhit_min_id,
cdhit_min_length=options.cdhit_min_length,
cdhit_max_memory=options.cdhit_max_memory,
run_cdhit=not options.no_cdhit,
clusters_file=options.cdhit_clusters,
threads=options.threads,
Expand Down
36 changes: 36 additions & 0 deletions ariba/tests/cdhit_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
import os
import re
from ariba import cdhit, external_progs


modules_dir = os.path.dirname(os.path.abspath(cdhit.__file__))
data_dir = os.path.join(modules_dir, 'tests', 'data')
extern_progs = external_progs.ExternalProgs()
Expand All @@ -13,6 +15,13 @@ def test_init_fail_infile_missing(self):
cdhit.Runner('oopsnotafile', 'out')


def test_init_fail_invalid_memory(self):
'''test_init_fail_invalid_memory'''
infile = os.path.join(data_dir, 'cdhit_test_run.in.fa')
with self.assertRaises(cdhit.Error):
cdhit.Runner(infile, memory_limit=-10)


def test_get_clusters_from_bak_file(self):
'''test _get_clusters_from_bak_file'''
infile = os.path.join(data_dir, 'cdhit_test_get_clusters_from_bak_file.in')
Expand Down Expand Up @@ -162,3 +171,30 @@ def test_run_get_clusters_from_file_with_renaming(self):
'1': {'seq3'},
}
self.assertEqual(clusters, expected_clusters)


def test_get_run_cmd_with_default_memory(self):
'''test_get_run_cmd_with_default_memory'''
fa_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict_rename.in.fa')
r = cdhit.Runner(fa_infile)
run_cmd = r.get_run_cmd('foo/bar/file.out')
match = re.search('^.+cd-hit-est -i .+ -o foo/bar/file.out -c 0.9 -T 1 -s 0.0 -d 0 -bak 1$', run_cmd)
self.assertTrue(match)


def test_get_run_cmd_with_non_default_memory(self):
'''test_get_run_cmd_with_non_default_memory'''
fa_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict_rename.in.fa')
r = cdhit.Runner(fa_infile, memory_limit=900)
run_cmd = r.get_run_cmd('foo/bar/file.out')
match = re.search('^.+cd-hit-est -i .+ -c 0.9 -T 1 -s 0.0 -d 0 -bak 1 -M 900$', run_cmd)
self.assertTrue(match)


def test_get_run_cmd_with_unlimited_memory(self):
'''test_get_run_cmd_with_unlimited_memory'''
fa_infile = os.path.join(data_dir, 'cdhit_test_run_get_clusters_from_dict_rename.in.fa')
r = cdhit.Runner(fa_infile, memory_limit=0)
run_cmd = r.get_run_cmd('foo/bar/file.out')
match = re.search('^.+cd-hit-est -i .+ -c 0.9 -T 1 -s 0.0 -d 0 -bak 1 -M 0$', run_cmd)
self.assertTrue(match)
3 changes: 2 additions & 1 deletion scripts/ariba
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ cdhit_group = subparser_prepareref.add_argument_group('cd-hit options')
cdhit_group.add_argument('--no_cdhit', action='store_true', help='Do not run cd-hit. Each input sequence is put into its own "cluster". Incompatible with --cdhit_clusters.')
cdhit_group.add_argument('--cdhit_clusters', help='File specifying how the sequences should be clustered. Will be used instead of running cdhit. Format is one cluster per line. Sequence names separated by whitespace. Incompatible with --no_cdhit', metavar='FILENAME')
cdhit_group.add_argument('--cdhit_min_id', type=float, help='Sequence identity threshold (cd-hit option -c) [%(default)s]', default=0.9, metavar='FLOAT')
cdhit_group.add_argument('--cdhit_min_length', type=float, help='length difference cutoff (cd-hit option -s) [%(default)s]', default=0.0, metavar='FLOAT')
cdhit_group.add_argument('--cdhit_min_length', type=float, help='Length difference cutoff (cd-hit option -s) [%(default)s]', default=0.0, metavar='FLOAT')
cdhit_group.add_argument('--cdhit_max_memory', type=int, help='Memory limit in MB (cd-hit option -M) [%(default)s]. Use 0 for unlimited.', metavar='INT')

other_prep_group = subparser_prepareref.add_argument_group('other options')
other_prep_group.add_argument('--min_gene_length', type=int, help='Minimum allowed length in nucleotides of reference genes [%(default)s]', metavar='INT', default=6)
Expand Down

0 comments on commit 70c82f5

Please sign in to comment.