Fix bug in low memory random sampling and add more tests (#130)
J535D165 authored Feb 23, 2020
1 parent c12ef4d commit c7ab5d7
Showing 3 changed files with 174 additions and 62 deletions.
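For context, here is a minimal usage sketch of the Random indexer whose sampling helpers this commit changes. The parameter names (n, replace, random_state) and the index() call reflect my reading of the recordlinkage API rather than anything stated in the commit, so treat this as an assumption-laden illustration, not part of the change itself.

# Hedged usage sketch (not part of this commit): how the sampling code in the
# diff below is typically reached through the public indexing API.
import pandas as pd

from recordlinkage.index import Random

df_a = pd.DataFrame({"name": ["alice", "bob", "carol", "dave"]})
df_b = pd.DataFrame({"name": ["alice", "bobby", "karol", "david"]})

# With replace=False the indexer samples candidate pairs without replacement,
# the code path fixed in this commit (parameter names are assumed).
indexer = Random(n=5, replace=False, random_state=42)
candidate_pairs = indexer.index(df_a, df_b)  # a pandas MultiIndex of record pairs
print(candidate_pairs)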
55 changes: 34 additions & 21 deletions recordlinkage/algorithms/indexing.py
@@ -2,7 +2,7 @@

import numpy as np

from recordlinkage.measures import max_pairs
from recordlinkage.measures import full_index_size


def _map_tril_1d_on_2d(indices, dims):
@@ -17,20 +17,13 @@ def _map_tril_1d_on_2d(indices, dims):
    return np.array([r, c], dtype=np.int64)


def _unique_rows_numpy(a):
    """return unique rows"""
    a = np.ascontiguousarray(a)
    unique_a = np.unique(a.view([('', a.dtype)] * a.shape[1]))
    return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1]))


def random_pairs_with_replacement(n, shape, random_state=None):
    """make random record pairs"""

    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    n_max = max_pairs(shape)
    n_max = full_index_size(shape)

    if n_max <= 0:
        raise ValueError('n_max must be larger than 0')
@@ -41,13 +34,19 @@ def random_pairs_with_replacement(n, shape, random_state=None):
    if len(shape) == 1:
        return _map_tril_1d_on_2d(indices, shape[0])
    else:
        return np.unravel_index(indices, shape)
        return np.array(np.unravel_index(indices, shape))


def random_pairs_without_replacement_small_frames(
def random_pairs_without_replacement(
        n, shape, random_state=None):
    """Return record pairs for dense sample.
    Sample random record pairs without replacement bounded by the
    maximum number of record pairs (based on shape). This algorithm is
    efficient and fast for relative small samples.
    """

    n_max = max_pairs(shape)
    n_max = full_index_size(shape)

    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)
@@ -63,16 +62,27 @@ def random_pairs_without_replacement_small_frames(
    if len(shape) == 1:
        return _map_tril_1d_on_2d(sample, shape[0])
    else:
        return np.unravel_index(sample, shape)
        return np.array(np.unravel_index(sample, shape))


def random_pairs_without_replacement_large_frames(
def random_pairs_without_replacement_low_memory(
        n, shape, random_state=None):
    """Make a sample of random pairs with replacement"""
    """Make a sample of random pairs with replacement.
    Sample random record pairs without replacement bounded by the
    maximum number of record pairs (based on shape). This algorithm
    consumes low memory and is fast for relatively small samples.
    """

    n_max = full_index_size(shape)

    n_max = max_pairs(shape)
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    if not isinstance(n, int) or n <= 0 or n > n_max:
        raise ValueError("n must be a integer satisfying 0<n<=%s" % n_max)

    sample = np.array([])
    sample = np.array([], dtype=np.int64)

    # Run as long as the number of pairs is less than the requested number
    # of pairs n.
@@ -81,14 +91,17 @@ def random_pairs_without_replacement_large_frames(
        # The number of pairs to sample (sample twice as much record pairs
        # because the duplicates are dropped).
        n_sample_size = (n - len(sample)) * 2
        sample = random_state.randint(n_max, size=n_sample_size)
        sample_sub = random_state.randint(
            n_max,
            size=n_sample_size
        )

        # concatenate pairs and deduplicate
        pairs_non_unique = np.append(sample, sample)
        sample = _unique_rows_numpy(pairs_non_unique)
        pairs_non_unique = np.append(sample, sample_sub)
        sample = np.unique(pairs_non_unique)

    # return 2d indices
    if len(shape) == 1:
        return _map_tril_1d_on_2d(sample[0:n], shape[0])
    else:
        return np.unravel_index(sample[0:n], shape)
        return np.array(np.unravel_index(sample[0:n], shape))
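Stepping back from the diff: the old loop overwrote sample with each new draw and then deduplicated np.append(sample, sample), the draw concatenated with itself, via the row-wise helper _unique_rows_numpy on a 1-D array; the fix keeps the accumulated sample, appends a fresh sample_sub, and deduplicates with np.unique. The following self-contained sketch shows the two sampling strategies as I read them from the diff; the function names are made up for illustration and the mapping of linear indices back to 2-D record pairs is omitted.

# Simplified, hypothetical sketch of the two strategies; not the library's code.
import numpy as np


def sample_dense(n, n_max, random_state):
    # Draw n unique linear pair indices in one shot; fine when n_max is small
    # enough to enumerate.
    return random_state.choice(n_max, n, replace=False)


def sample_low_memory(n, n_max, random_state):
    # Draw n unique linear pair indices without materialising all n_max
    # candidates: repeatedly over-sample (twice the shortfall, since duplicates
    # are dropped) and deduplicate with np.unique, mirroring the fixed loop.
    sample = np.array([], dtype=np.int64)
    while len(sample) < n:
        shortfall = n - len(sample)
        sample_sub = random_state.randint(n_max, size=shortfall * 2)
        sample = np.unique(np.append(sample, sample_sub))
    return sample[0:n]


rng = np.random.RandomState(42)
print(sample_dense(5, 1000, rng))        # small index: enumerate and choose
print(sample_low_memory(5, 10**9, rng))  # huge index: draw and deduplicate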
17 changes: 10 additions & 7 deletions recordlinkage/index.py
@@ -7,8 +7,8 @@
from recordlinkage import rl_logging as logging
from recordlinkage.algorithms.indexing import (
    random_pairs_with_replacement,
    random_pairs_without_replacement_large_frames,
    random_pairs_without_replacement_small_frames)
    random_pairs_without_replacement_low_memory,
    random_pairs_without_replacement)
from recordlinkage.base import BaseIndexAlgorithm
from recordlinkage.measures import full_index_size
from recordlinkage.utils import DeprecationHelper, listify, construct_multiindex
@@ -412,13 +412,16 @@ def _link_index(self, df_a, df_b):
                raise ValueError(
                    "n must be a integer satisfying 0<n<=%s" % n_max)

            # the fraction of pairs in the sample
            frac = self.n / n_max

            # large dataframes
            if n_max < 1e6:
                pairs = random_pairs_without_replacement_small_frames(
            if n_max < 1e6 or frac > 0.5:
                pairs = random_pairs_without_replacement(
                    self.n, shape, self.random_state)
            # small dataframes
            else:
                pairs = random_pairs_without_replacement_large_frames(
                pairs = random_pairs_without_replacement_low_memory(
                    self.n, shape, self.random_state)

        levels = [df_a.index.values, df_b.index.values]
@@ -447,11 +450,11 @@ def _dedup_index(self, df_a):

            # large dataframes
            if n_max < 1e6:
                pairs = random_pairs_without_replacement_small_frames(
                pairs = random_pairs_without_replacement(
                    self.n, shape, self.random_state)
            # small dataframes
            else:
                pairs = random_pairs_without_replacement_large_frames(
                pairs = random_pairs_without_replacement_low_memory(
                    self.n, shape, self.random_state)

        levels = [df_a.index.values, df_a.index.values]
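The new frac check in _link_index routes to the dense sampler not only when the candidate index is small but also whenever the requested sample would cover more than half of all possible pairs, presumably because the draw-and-deduplicate loop slows down sharply once duplicates become likely. A tiny sketch of that heuristic; the helper name is hypothetical and only the thresholds are taken from the diff.

# Hypothetical helper illustrating the strategy selection added in _link_index.
def choose_sampler(n, n_max):
    # Fraction of the full candidate index that the sample would cover.
    frac = n / n_max
    if n_max < 1e6 or frac > 0.5:
        # Small index, or a dense sample: drawing everything at once is cheap
        # and avoids a long duplicate-rejection loop.
        return "dense"
    # Huge index with a sparse sample: draw-and-deduplicate keeps memory low.
    return "low_memory"


print(choose_sampler(1_000, 500_000))      # dense: the index is small
print(choose_sampler(900_000, 1_000_000))  # dense: sample covers 90% of pairs
print(choose_sampler(1_000, 10**10))       # low_memory: sparse sample, huge index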