Skip to content

Commit

Permalink
Merge pull request biolab#2936 from cemsbr/oversampling
Browse files Browse the repository at this point in the history
Allow oversampling in Data Sampler widget
  • Loading branch information
astaric authored Mar 3, 2018
2 parents b09b282 + ec30b92 commit 9249ec4
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
42 changes: 31 additions & 11 deletions Orange/widgets/data/owdatasampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import sklearn.model_selection as skl

from Orange.widgets import widget, gui
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.data import Table
from Orange.data.sql.table import SqlTable
Expand All @@ -24,6 +24,8 @@ class OWDataSampler(OWWidget):
category = "Data"
keywords = ["data", "sample"]

_MAX_SAMPLE_SIZE = 2 ** 31 - 1

class Inputs:
data = Input("Data", Table)

Expand Down Expand Up @@ -52,6 +54,7 @@ class Outputs:

class Warning(OWWidget.Warning):
could_not_stratify = Msg("Stratification failed\n{}")
bigger_sample = Msg('Sample is bigger than input')

class Error(OWWidget.Error):
too_many_folds = Msg("Number of folds exceeds data size")
Expand All @@ -74,10 +77,10 @@ def __init__(self):
callback=self.sampling_type_changed)

def set_sampling_type(i):
def f():
def set_sampling_type_i():
self.sampling_type = i
self.sampling_type_changed()
return f
return set_sampling_type_i

gui.appendRadioButton(sampling, "Fixed proportion of data:")
self.sampleSizePercentageSlider = gui.hSlider(
Expand All @@ -91,8 +94,9 @@ def f():
ibox = gui.indentedBox(sampling)
self.sampleSizeSpin = gui.spin(
ibox, self, "sampleSizeNumber", label="Instances: ",
minv=1, maxv=2 ** 31 - 1,
callback=set_sampling_type(self.FixedSize))
minv=1, maxv=self._MAX_SAMPLE_SIZE,
callback=set_sampling_type(self.FixedSize),
controlWidth=90)
gui.checkBox(
ibox, self, "replacement", "Sample with replacement",
callback=set_sampling_type(self.FixedSize),
Expand Down Expand Up @@ -132,7 +136,6 @@ def f():
spin.setSuffix(" %")
self.sql_box.setVisible(False)


self.options_box = gui.vBox(self.controlArea, "Options")
self.cb_seed = gui.checkBox(
self.options_box, self, "use_seed",
Expand Down Expand Up @@ -162,6 +165,7 @@ def fold_changed(self):
self.sampling_type = self.CrossValidation

def settings_changed(self):
self._update_sample_max_size()
self.indices = None

@Inputs.data
Expand All @@ -179,7 +183,7 @@ def set_data(self, dataset):
('~', dataset.approx_len()) if sql else
('', len(dataset)))))
if not sql:
self.sampleSizeSpin.setMaximum(len(dataset))
self._update_sample_max_size()
self.updateindices()
else:
self.dataInfoLabel.setText('No data on input.')
Expand All @@ -188,6 +192,13 @@ def set_data(self, dataset):
self.clear_messages()
self.commit()

def _update_sample_max_size(self):
    """Refresh the upper bound of the "Instances" spin box.

    The bound is ``_MAX_SAMPLE_SIZE`` when there is no (or empty) input
    data or when sampling with replacement is enabled; otherwise it is
    the number of instances in the input data.
    """
    # ``len`` is only evaluated when data is present and replacement is off.
    upper_bound = (
        self._MAX_SAMPLE_SIZE
        if not self.data or self.replacement
        else len(self.data)
    )
    self.sampleSizeSpin.setMaximum(upper_bound)

def commit(self):
if self.data is None:
sample = other = None
Expand Down Expand Up @@ -231,6 +242,7 @@ def commit(self):

def updateindices(self):
self.Error.clear()
self.Warning.clear()
repl = True
data_length = len(self.data)
num_classes = len(self.data.domain.class_var.values) \
Expand Down Expand Up @@ -260,8 +272,14 @@ def updateindices(self):
self.indices = None
return

# By the above, we can safely assume there is data
if self.sampling_type == self.FixedSize and repl and size and \
size > len(self.data):
# This should only be possible when using replacement
self.Warning.bigger_sample()

stratified = (self.stratify and
type(self.data) == Table and
isinstance(self.data, Table) and
self.data.domain.has_discrete_class)
try:
self.indices = self.sample(data_length, size, stratified)
Expand All @@ -277,7 +295,8 @@ def sample(self, data_length, size, stratified):
random_state=rnd)
elif self.sampling_type == self.FixedProportion:
self.indice_gen = SampleRandomP(
self.sampleSizePercentage / 100, stratified=stratified, random_state=rnd)
self.sampleSizePercentage / 100, stratified=stratified,
random_state=rnd)
elif self.sampling_type == self.Bootstrap:
self.indice_gen = SampleBootstrap(data_length, random_state=rnd)
else:
Expand Down Expand Up @@ -323,7 +342,8 @@ def __init__(self, folds=10, stratified=False, random_state=None):
Args:
folds (int): Number of folds
stratified (bool): Return stratified indices (if applicable).
random_state (Random): An initial state for replicable random behavior
random_state (Random): An initial state for replicable random
behavior
Returns:
tuple-of-arrays: A tuple of array indices one for each fold.
Expand All @@ -349,7 +369,7 @@ def __call__(self, table):

class SampleRandomN(Reprable):
def __init__(self, n=0, stratified=False, replace=False,
random_state=None):
random_state=None):
self.n = n
self.stratified = stratified
self.replace = replace
Expand Down
32 changes: 32 additions & 0 deletions Orange/widgets/data/tests/test_owdatasampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,38 @@ def test_no_intersection_in_outputs(self):
self.assertEqual(len(self.zoo), len(sample) + len(other))
self.assertNoIntersection(sample, other)

def test_bigger_size_with_replacement(self):
    """Allow a sample size bigger than the input with replacement."""
    # Two input rows, three requested: only valid because replacement is on.
    self.send_signal('Data', self.iris[:2])
    sample_size = self.set_fixed_sample_size(3, with_replacement=True)
    self.assertEqual(3, sample_size, 'Should be able to set a bigger size '
                                     'with replacement')

def test_bigger_size_without_replacement(self):
    """Without replacement the requested size is lowered to the input size."""
    two_rows = self.iris[:2]
    self.send_signal('Data', two_rows)
    # Ask for three instances; the helper reports what the spin box kept.
    shown_size = self.set_fixed_sample_size(3)
    self.assertEqual(2, shown_size)

def test_bigger_output_warning(self):
    """The ``bigger_sample`` warning is shown when the sample exceeds the input."""
    two_rows = self.iris[:2]
    self.send_signal('Data', two_rows)
    self.set_fixed_sample_size(3, with_replacement=True)
    warning_shown = self.widget.Warning.bigger_sample.is_shown()
    self.assertTrue(warning_shown)

def set_fixed_sample_size(self, sample_size, with_replacement=False):
    """Switch to fixed-size sampling, request ``sample_size`` and commit.

    Args:
        sample_size (int): Value written to the "Instances" spin box.
        with_replacement (bool): State given to the replacement check box.

    Returns:
        The value the spin box holds after the commit. It can differ from
        ``sample_size`` because the spin box maximum may limit it.
    """
    widget = self.widget
    self.select_sampling_type(widget.FixedSize)
    widget.controls.replacement.setChecked(with_replacement)

    size_spin = widget.sampleSizeSpin
    size_spin.setValue(sample_size)
    widget.commit()
    return size_spin.value()

def assertNoIntersection(self, sample, other):
for inst in sample:
self.assertNotIn(inst, other)
3 changes: 2 additions & 1 deletion doc/visual-programming/source/widgets/data/datasampler.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ provided and *Sample Data* is pressed.
- **Fixed sample size** returns a selected number of data instances
with the option to enable *Sample with replacement*, which always samples
from the entire dataset (does not subtract instances already in
the subset)
the subset). With replacement, you can generate more instances than
are available in the input dataset.
- `Cross Validation <https://en.wikipedia.org/wiki/Cross-validation_(statistics)>`_
partitions data instances into complementary subsets, where you can
select the number of folds (subsets) and which fold you want to
Expand Down

0 comments on commit 9249ec4

Please sign in to comment.