Skip to content

Commit

Permalink
New style signals
Browse files Browse the repository at this point in the history
  • Loading branch information
astaric committed Jul 26, 2017
1 parent dc48d5c commit 4702973
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 63 deletions.
2 changes: 1 addition & 1 deletion Orange/widgets/model/owadaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class OWAdaBoost(OWBaseLearner):

LEARNER = SklAdaBoostLearner

class Inputs:
class Inputs(OWBaseLearner.Inputs):
learner = Input("Learner", Learner)

#: Algorithms for classification problems
Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/model/owrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def update_model(self):
self.model.name = self.learner_name
self.model.instances = self.data
self.valid_data = True
self.send(self.OUTPUT_MODEL_NAME, self.model)
self.Outputs.model.send(self.model)

def create_learner(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions Orange/widgets/model/owsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from Orange.canvas.report import bool_str
from Orange.data import ContinuousVariable, StringVariable, Domain, Table
from Orange.modelling.linear import SGDLearner
from Orange.widgets import gui, widget
from Orange.widgets import gui
from Orange.widgets.model.owlogisticregression import create_coef_table
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.utils.signals import Output

MAXINT = 2 ** 31 - 1

Expand All @@ -25,7 +26,8 @@ class OWSGD(OWBaseLearner):

LEARNER = SGDLearner

outputs = [("Coefficients", Table, widget.Explicit)]
class Outputs(OWBaseLearner.Outputs):
coefficients = Output("Coefficients", Table)

reg_losses = (
('Squared Loss', 'squared_loss'),
Expand Down Expand Up @@ -302,7 +304,7 @@ def update_model(self):
[attr.name for attr in self.model.domain.attributes]
coeffs = Table(domain, list(zip(cfs, names)))
coeffs.name = "coefficients"
self.send("Coefficients", coeffs)
self.Outputs.coefficients.send(coeffs)


if __name__ == '__main__':
Expand Down
6 changes: 4 additions & 2 deletions Orange/widgets/model/owsvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from Orange.widgets import gui, widget
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.utils.signals import Output


class OWSVM(OWBaseLearner):
Expand All @@ -23,7 +24,8 @@ class OWSVM(OWBaseLearner):

LEARNER = SVMLearner

outputs = [("Support vectors", Table, widget.Explicit)]
class Outputs(OWBaseLearner.Outputs):
support_vectors = Output("Support vectors", Table)

#: Different types of SVMs
SVM, Nu_SVM = range(2)
Expand Down Expand Up @@ -186,7 +188,7 @@ def update_model(self):
sv = None
if self.model is not None:
sv = self.data[self.model.skl_model.support_]
self.send("Support vectors", sv)
self.Outputs.support_vectors.send(sv)

def _on_kernel_changed(self):
self._show_right_kernel()
Expand Down
12 changes: 6 additions & 6 deletions Orange/widgets/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def test_input_data_disconnect(self):
self.widget.apply_button.button.click()
self.send_signal("Data", None)
self.assertEqual(self.widget.data, None)
self.assertIsNone(self.get_output(self.model_name))
self.assertIsNone(self.get_output(self.widget.Outputs.model))

def test_input_data_learner_adequacy(self):
"""Check if error message is shown with inadequate data on input"""
Expand Down Expand Up @@ -509,12 +509,12 @@ def test_output_learner(self):

def test_output_model(self):
"""Check if model is on output after sending data and apply"""
self.assertIsNone(self.get_output(self.model_name))
self.assertIsNone(self.get_output(self.widget.Outputs.model))
self.widget.apply_button.button.click()
self.assertIsNone(self.get_output(self.model_name))
self.assertIsNone(self.get_output(self.widget.Outputs.model))
self.send_signal('Data', self.data)
self.widget.apply_button.button.click()
model = self.get_output(self.model_name)
model = self.get_output(self.widget.Outputs.model)
self.assertIsNotNone(model)
self.assertIsInstance(model, self.widget.LEARNER.__returns__)
self.assertIsInstance(model, self.model_class)
Expand All @@ -535,7 +535,7 @@ def test_output_model_name(self):
self.widget.name_line_edit.setText(new_name)
self.send_signal("Data", self.data)
self.widget.apply_button.button.click()
self.assertEqual(self.get_output(self.model_name).name, new_name)
self.assertEqual(self.get_output(self.widget.Outputs.model).name, new_name)

def _get_param_value(self, learner, param):
if isinstance(learner, Fitter):
Expand Down Expand Up @@ -592,7 +592,7 @@ def test_parameters(self):
"Mismatching setting for parameter '%s'" % parameter)

if issubclass(self.widget.LEARNER, SklModel):
model = self.get_output(self.model_name)
model = self.get_output(self.widget.Outputs.model)
if model is not None:
self.assertEqual(self._get_param_value(model, parameter), value)
self.assertFalse(self.widget.Error.active)
Expand Down
45 changes: 20 additions & 25 deletions Orange/widgets/tests/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import unittest

from Orange.classification import KNNLearner
from Orange.data import Table
from Orange.preprocess.preprocess import Preprocess
from Orange.modelling import TreeLearner
from Orange.regression import MeanLearner
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.utils.signals import Output


class TestProviderMetaClass(unittest.TestCase):
Expand All @@ -13,32 +14,26 @@ def assertChannelsEqual(self, first, second, msg=None):
self.assertEqual(set(i.name for i in first),
set(i[0] for i in second), msg)

def test_inputs(self):
inputs = [("Data", Table, "set_data"),
("Preprocessor", Preprocess, "set_preprocessor")]

class OWTestProvider(OWBaseLearner):
def test_widgets_do_not_share_outputs(self):
class WidgetA(OWBaseLearner):
name = "A"
LEARNER = KNNLearner
name = "test widget"

self.assertChannelsEqual(OWTestProvider.inputs, inputs)

def test_outputs(self):
expected_outputs = [
("Learner", KNNLearner),
("Classifier", KNNLearner.__returns__),
("Test", Table)
]
class WidgetB(OWBaseLearner):
name = "B"
LEARNER = MeanLearner

class OWTestProvider(OWBaseLearner):
LEARNER = KNNLearner
name = "test widget"
self.assertEqual(WidgetA.Outputs.learner.type, KNNLearner)
self.assertEqual(WidgetB.Outputs.learner.type, MeanLearner)

outputs = [("Test", Table)]
class WidgetC(WidgetA):
name = "C"
LEARNER = TreeLearner

self.assertChannelsEqual(OWTestProvider.outputs, expected_outputs)
class Outputs(WidgetA.Outputs):
test = Output("test", str)

def test_class_without_learner(self):
with self.assertRaises(AttributeError):
class OWTestProvider(OWBaseLearner):
name = 'test'
self.assertEqual(WidgetC.Outputs.learner.type, TreeLearner)
self.assertEqual(WidgetC.Outputs.test.name, "test")
self.assertEqual(WidgetA.Outputs.learner.type, KNNLearner)
self.assertFalse(hasattr(WidgetA.Outputs, "test"))
16 changes: 16 additions & 0 deletions Orange/widgets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import sys

from AnyQt.QtCore import QObject
Expand Down Expand Up @@ -70,3 +71,18 @@ def dumpObjectTree(obj, _indent=0):
file=sys.stderr)
for child in obj.children():
dumpObjectTree(child, _indent + 1)


def getmembers(obj, predicate=None):
"""Return all the members of an object in a list of (name, value) pairs sorted by name.
Behaves like inspect.getmembers. If a type object is passed as a predicate,
only members of that type are returned.
"""

if isinstance(predicate, type):
def mypredicate(x):
return isinstance(x, predicate)
else:
mypredicate = predicate
return inspect.getmembers(obj, mypredicate)
47 changes: 32 additions & 15 deletions Orange/widgets/utils/owlearnerwidget.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from copy import deepcopy
import numpy as np

from AnyQt.QtCore import QTimer, Qt

Expand All @@ -8,27 +8,46 @@
from Orange.preprocess.preprocess import Preprocess
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils import getmembers
from Orange.widgets.utils.signals import Output, Input
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.widget import OWWidget, WidgetMetaClass, Msg


class OWBaseLearnerMeta(WidgetMetaClass):
""" Meta class for learner widgets
OWBaseLearner declares two outputs, learner and model with
generic type (Learner and Model).
This metaclass ensures that each of the subclasses gets
its own Outputs class with output that match the corresponding
learner.
"""
def __new__(mcls, name, bases, attributes):
learner = attributes.get("LEARNER", None)
# classes with empty names are considered abstract
if attributes.get(name):
if learner is None:
raise AttributeError(
"'{}' must declare attribute LEARNER".format(name))
def __new__(cls, name, bases, attributes):
def abstract_widget():
return not attributes.get("name")

def copy_outputs(template):
result = type("Outputs", (), {})
for name, signal in getmembers(template, Output):
setattr(result, name, deepcopy(signal))
return result

obj = super().__new__(cls, name, bases, attributes)
if abstract_widget():
return obj

learner = attributes.get("LEARNER")
if not learner:
raise AttributeError(
"'{}' must declare attribute LEARNER".format(name))

outputs = attributes["Outputs"] = deepcopy(attributes["Outputs"])
outputs.learner.type = learner
outputs.model.type = learner.__returns__
outputs = obj.Outputs = copy_outputs(obj.Outputs)
outputs.learner.type = learner
outputs.model.type = learner.__returns__

return super().__new__(mcls, name, bases, attributes)
return obj


class OWBaseLearner(OWWidget, metaclass=OWBaseLearnerMeta):
Expand Down Expand Up @@ -64,8 +83,6 @@ class Inputs:
preprocessor = Input("Preprocessor", Preprocess)

class Outputs:
# Exact output types for each learner widget are set in
# OWBaseLearnerMeta meta class
learner = Output("Learner", Learner, dynamic=False)
model = Output("Model", Model, dynamic=False,
replaces=["Classifier", "Predictor"])
Expand All @@ -75,7 +92,7 @@ def __init__(self):
self.data = None
self.valid_data = False
self.learner = None
self.learner_name = self.LEARNER.name
self.learner_name = self.name
self.model = None
self.preprocessors = None
self.outdated_settings = False
Expand Down
14 changes: 4 additions & 10 deletions Orange/widgets/utils/signals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import copy
import inspect
import itertools

from Orange.canvas.registry.description import InputSignal, OutputSignal
from Orange.widgets.utils import getmembers

# increasing counter for ensuring the order of Input/Output definitions
# is preserved when going through the unordered class namespace of
Expand Down Expand Up @@ -146,12 +146,6 @@ def send(self, value, id=None):
signal_manager.send(self.widget, self.name, value, id)


def _get_members(obj, member_type):
def is_member_type(member):
return isinstance(member, member_type)
return inspect.getmembers(obj, is_member_type)


class WidgetSignalsMixin:
"""Mixin for managing widget's input and output signals"""
class Inputs:
Expand All @@ -165,7 +159,7 @@ def __init__(self):

def _bind_outputs(self):
bound_cls = self.Outputs()
for name, signal in _get_members(bound_cls, Output):
for name, signal in getmembers(bound_cls, Output):
setattr(bound_cls, name, signal.bound_signal(self))
setattr(self, "Outputs", bound_cls)

Expand Down Expand Up @@ -217,7 +211,7 @@ def signal_from_args(args, signal_type):
@classmethod
def _check_input_handlers(cls):
unbound = [signal.name
for _, signal in _get_members(cls.Inputs, Input)
for _, signal in getmembers(cls.Inputs, Input)
if not signal.handler]
if unbound:
raise ValueError("unbound signal(s) in {}: {}".
Expand Down Expand Up @@ -249,7 +243,7 @@ def get_signals(cls, direction):
return old_style

signal_class = getattr(cls, direction.title())
signals = [signal for _, signal in _get_members(signal_class, _Signal)]
signals = [signal for _, signal in getmembers(signal_class, _Signal)]
return list(sorted(signals, key=lambda s: s._seq_id))


Expand Down

0 comments on commit 4702973

Please sign in to comment.