Skip to content

Commit

Permalink
bpo-45500: Rewrite test_dbm (GH-29002)
Browse files Browse the repository at this point in the history
* Generate test classes at import time. It allows to filter them when
  run with unittest. E.g: "./python -m unittest test.test_dbm.TestCase_gnu -v".
* Create a database class in a new directory which will be removed after
  test. It guarantees that all created files and directories be removed
  and will not conflict with other dbm tests.
* Restore dbm._defaultmod after tests. Previously it was set to the last
  dbm module (dbm.dumb) which affected other tests.
* Enable the whichdb test for dbm.dumb.
* Move test_keys to the correct test class. It does not test whichdb().
* Remove some outdated code and comments.
  • Loading branch information
serhiy-storchaka authored Oct 19, 2021
1 parent 236e301 commit 975b94b
Showing 1 changed file with 57 additions and 69 deletions.
126 changes: 57 additions & 69 deletions Lib/test/test_dbm.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
"""Test script for the dbm.open function based on testdumbdbm.py"""

import unittest
import glob
import dbm
import os
from test.support import import_helper
from test.support import os_helper

# Skip tests if dbm module doesn't exist.
dbm = import_helper.import_module('dbm')

try:
from dbm import ndbm
except ImportError:
ndbm = None

_fname = os_helper.TESTFN
dirname = os_helper.TESTFN
_fname = os.path.join(dirname, os_helper.TESTFN)

#
# Iterates over every database module supported by dbm currently available,
# setting dbm to use each in turn, and yielding that module
# Iterates over every database module supported by dbm currently available.
#
def dbm_iterator():
for name in dbm._names:
Expand All @@ -32,11 +29,12 @@ def dbm_iterator():
#
# Clean up all scratch databases we might have created during testing
#
def delete_files():
# we don't know the precise name the underlying database uses
# so we use glob to locate all names
for f in glob.glob(glob.escape(_fname) + "*"):
os_helper.unlink(f)
def cleaunup_test_dir():
os_helper.rmtree(dirname)

def setup_test_dir():
cleaunup_test_dir()
os.mkdir(dirname)


class AnyDBMTestCase:
Expand Down Expand Up @@ -144,86 +142,76 @@ def read_helper(self, f):
for key in self._dict:
self.assertEqual(self._dict[key], f[key.encode("ascii")])

def tearDown(self):
delete_files()
def test_keys(self):
with dbm.open(_fname, 'c') as d:
self.assertEqual(d.keys(), [])
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
for k, v in a:
d[k] = v
self.assertEqual(sorted(d.keys()), sorted(k for (k, v) in a))
for k, v in a:
self.assertIn(k, d)
self.assertEqual(d[k], v)
self.assertNotIn(b'xxx', d)
self.assertRaises(KeyError, lambda: d[b'xxx'])

def setUp(self):
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
dbm._defaultmod = self.module
delete_files()
self.addCleanup(cleaunup_test_dir)
setup_test_dir()


class WhichDBTestCase(unittest.TestCase):
def test_whichdb(self):
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
_bytes_fname = os.fsencode(_fname)
for path in [_fname, os_helper.FakePath(_fname),
_bytes_fname, os_helper.FakePath(_bytes_fname)]:
for module in dbm_iterator():
# Check whether whichdb correctly guesses module name
# for databases opened with "module" module.
# Try with empty files first
name = module.__name__
if name == 'dbm.dumb':
continue # whichdb can't support dbm.dumb
delete_files()
f = module.open(path, 'c')
f.close()
fnames = [_fname, os_helper.FakePath(_fname),
_bytes_fname, os_helper.FakePath(_bytes_fname)]
for module in dbm_iterator():
# Check whether whichdb correctly guesses module name
# for databases opened with "module" module.
name = module.__name__
setup_test_dir()
dbm._defaultmod = module
# Try with empty files first
with module.open(_fname, 'c'): pass
for path in fnames:
self.assertEqual(name, self.dbm.whichdb(path))
# Now add a key
f = module.open(path, 'w')
# Now add a key
with module.open(_fname, 'w') as f:
f[b"1"] = b"1"
# and test that we can find it
self.assertIn(b"1", f)
# and read it
self.assertEqual(f[b"1"], b"1")
f.close()
for path in fnames:
self.assertEqual(name, self.dbm.whichdb(path))

@unittest.skipUnless(ndbm, reason='Test requires ndbm')
def test_whichdb_ndbm(self):
# Issue 17198: check that ndbm which is referenced in whichdb is defined
db_file = '{}_ndbm.db'.format(_fname)
with open(db_file, 'w'):
self.addCleanup(os_helper.unlink, db_file)
db_file_bytes = os.fsencode(db_file)
self.assertIsNone(self.dbm.whichdb(db_file[:-3]))
self.assertIsNone(self.dbm.whichdb(os_helper.FakePath(db_file[:-3])))
self.assertIsNone(self.dbm.whichdb(db_file_bytes[:-3]))
self.assertIsNone(self.dbm.whichdb(os_helper.FakePath(db_file_bytes[:-3])))

def tearDown(self):
delete_files()
with open(_fname + '.db', 'wb'): pass
_bytes_fname = os.fsencode(_fname)
fnames = [_fname, os_helper.FakePath(_fname),
_bytes_fname, os_helper.FakePath(_bytes_fname)]
for path in fnames:
self.assertIsNone(self.dbm.whichdb(path))

def setUp(self):
delete_files()
self.filename = os_helper.TESTFN
self.d = dbm.open(self.filename, 'c')
self.d.close()
self.addCleanup(cleaunup_test_dir)
setup_test_dir()
self.dbm = import_helper.import_fresh_module('dbm')

def test_keys(self):
self.d = dbm.open(self.filename, 'c')
self.assertEqual(self.d.keys(), [])
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
for k, v in a:
self.d[k] = v
self.assertEqual(sorted(self.d.keys()), sorted(k for (k, v) in a))
for k, v in a:
self.assertIn(k, self.d)
self.assertEqual(self.d[k], v)
self.assertNotIn(b'xxx', self.d)
self.assertRaises(KeyError, lambda: self.d[b'xxx'])
self.d.close()


def load_tests(loader, tests, pattern):
classes = []
for mod in dbm_iterator():
classes.append(type("TestCase-" + mod.__name__,
(AnyDBMTestCase, unittest.TestCase),
{'module': mod}))
for c in classes:
tests.addTest(loader.loadTestsFromTestCase(c))
return tests

for mod in dbm_iterator():
assert mod.__name__.startswith('dbm.')
suffix = mod.__name__[4:]
testname = f'TestCase_{suffix}'
globals()[testname] = type(testname,
(AnyDBMTestCase, unittest.TestCase),
{'module': mod})


if __name__ == "__main__":
unittest.main()

0 comments on commit 975b94b

Please sign in to comment.