diff --git a/Orange/widgets/model/owadaboost.py b/Orange/widgets/model/owadaboost.py index 8b64c924c94..d711e782918 100644 --- a/Orange/widgets/model/owadaboost.py +++ b/Orange/widgets/model/owadaboost.py @@ -22,7 +22,7 @@ class OWAdaBoost(OWBaseLearner): LEARNER = SklAdaBoostLearner - class Inputs: + class Inputs(OWBaseLearner.Inputs): learner = Input("Learner", Learner) #: Algorithms for classification problems diff --git a/Orange/widgets/model/owrules.py b/Orange/widgets/model/owrules.py index 3558ed08d5e..3964d8ce961 100644 --- a/Orange/widgets/model/owrules.py +++ b/Orange/widgets/model/owrules.py @@ -333,7 +333,7 @@ def update_model(self): self.model.name = self.learner_name self.model.instances = self.data self.valid_data = True - self.send(self.OUTPUT_MODEL_NAME, self.model) + self.Outputs.model.send(self.model) def create_learner(self): """ diff --git a/Orange/widgets/model/owsgd.py b/Orange/widgets/model/owsgd.py index 5ff744822dc..fad0b5c837b 100644 --- a/Orange/widgets/model/owsgd.py +++ b/Orange/widgets/model/owsgd.py @@ -5,10 +5,11 @@ from Orange.canvas.report import bool_str from Orange.data import ContinuousVariable, StringVariable, Domain, Table from Orange.modelling.linear import SGDLearner -from Orange.widgets import gui, widget +from Orange.widgets import gui from Orange.widgets.model.owlogisticregression import create_coef_table from Orange.widgets.settings import Setting from Orange.widgets.utils.owlearnerwidget import OWBaseLearner +from Orange.widgets.utils.signals import Output MAXINT = 2 ** 31 - 1 @@ -25,7 +26,8 @@ class OWSGD(OWBaseLearner): LEARNER = SGDLearner - outputs = [("Coefficients", Table, widget.Explicit)] + class Outputs(OWBaseLearner.Outputs): + coefficients = Output("Coefficients", Table) reg_losses = ( ('Squared Loss', 'squared_loss'), @@ -302,7 +304,7 @@ def update_model(self): [attr.name for attr in self.model.domain.attributes] coeffs = Table(domain, list(zip(cfs, names))) coeffs.name = "coefficients" - self.send("Coefficients", coeffs) + self.Outputs.coefficients.send(coeffs) if __name__ == '__main__': diff --git a/Orange/widgets/model/owsvm.py b/Orange/widgets/model/owsvm.py index 3162f7a25cb..b23d7817559 100644 --- a/Orange/widgets/model/owsvm.py +++ b/Orange/widgets/model/owsvm.py @@ -8,6 +8,7 @@ from Orange.widgets import gui, widget from Orange.widgets.settings import Setting from Orange.widgets.utils.owlearnerwidget import OWBaseLearner +from Orange.widgets.utils.signals import Output class OWSVM(OWBaseLearner): @@ -23,7 +24,8 @@ class OWSVM(OWBaseLearner): LEARNER = SVMLearner - outputs = [("Support vectors", Table, widget.Explicit)] + class Outputs(OWBaseLearner.Outputs): + support_vectors = Output("Support vectors", Table) #: Different types of SVMs SVM, Nu_SVM = range(2) @@ -186,7 +188,7 @@ def update_model(self): sv = None if self.model is not None: sv = self.data[self.model.skl_model.support_] - self.send("Support vectors", sv) + self.Outputs.support_vectors.send(sv) def _on_kernel_changed(self): self._show_right_kernel() diff --git a/Orange/widgets/tests/base.py b/Orange/widgets/tests/base.py index 7f47f5101c1..afe59043f0e 100644 --- a/Orange/widgets/tests/base.py +++ b/Orange/widgets/tests/base.py @@ -453,7 +453,7 @@ def test_input_data_disconnect(self): self.widget.apply_button.button.click() self.send_signal("Data", None) self.assertEqual(self.widget.data, None) - self.assertIsNone(self.get_output(self.model_name)) + self.assertIsNone(self.get_output(self.widget.Outputs.model)) def test_input_data_learner_adequacy(self): """Check if error message is shown with inadequate data on input""" @@ -509,12 +509,12 @@ def test_output_learner(self): def test_output_model(self): """Check if model is on output after sending data and apply""" - self.assertIsNone(self.get_output(self.model_name)) + self.assertIsNone(self.get_output(self.widget.Outputs.model)) self.widget.apply_button.button.click() - self.assertIsNone(self.get_output(self.model_name)) + self.assertIsNone(self.get_output(self.widget.Outputs.model)) self.send_signal('Data', self.data) self.widget.apply_button.button.click() - model = self.get_output(self.model_name) + model = self.get_output(self.widget.Outputs.model) self.assertIsNotNone(model) self.assertIsInstance(model, self.widget.LEARNER.__returns__) self.assertIsInstance(model, self.model_class) @@ -535,7 +535,7 @@ def test_output_model_name(self): self.widget.name_line_edit.setText(new_name) self.send_signal("Data", self.data) self.widget.apply_button.button.click() - self.assertEqual(self.get_output(self.model_name).name, new_name) + self.assertEqual(self.get_output(self.widget.Outputs.model).name, new_name) def _get_param_value(self, learner, param): if isinstance(learner, Fitter): @@ -592,7 +592,7 @@ def test_parameters(self): "Mismatching setting for parameter '%s'" % parameter) if issubclass(self.widget.LEARNER, SklModel): - model = self.get_output(self.model_name) + model = self.get_output(self.widget.Outputs.model) if model is not None: self.assertEqual(self._get_param_value(model, parameter), value) self.assertFalse(self.widget.Error.active) diff --git a/Orange/widgets/tests/test_providers.py b/Orange/widgets/tests/test_providers.py index 47477638673..07fd9cf4c86 100644 --- a/Orange/widgets/tests/test_providers.py +++ b/Orange/widgets/tests/test_providers.py @@ -1,9 +1,10 @@ import unittest from Orange.classification import KNNLearner -from Orange.data import Table -from Orange.preprocess.preprocess import Preprocess +from Orange.modelling import TreeLearner +from Orange.regression import MeanLearner from Orange.widgets.utils.owlearnerwidget import OWBaseLearner +from Orange.widgets.utils.signals import Output class TestProviderMetaClass(unittest.TestCase): @@ -13,32 +14,26 @@ def assertChannelsEqual(self, first, second, msg=None): self.assertEqual(set(i.name for i in first), set(i[0] for i in second), msg) - def test_inputs(self): - inputs = [("Data", Table, "set_data"), - ("Preprocessor", Preprocess, "set_preprocessor")] - - class OWTestProvider(OWBaseLearner): + def test_widgets_do_not_share_outputs(self): + class WidgetA(OWBaseLearner): + name = "A" LEARNER = KNNLearner - name = "test widget" - - self.assertChannelsEqual(OWTestProvider.inputs, inputs) - def test_outputs(self): - expected_outputs = [ - ("Learner", KNNLearner), - ("Classifier", KNNLearner.__returns__), - ("Test", Table) - ] + class WidgetB(OWBaseLearner): + name = "B" + LEARNER = MeanLearner - class OWTestProvider(OWBaseLearner): - LEARNER = KNNLearner - name = "test widget" + self.assertEqual(WidgetA.Outputs.learner.type, KNNLearner) + self.assertEqual(WidgetB.Outputs.learner.type, MeanLearner) - outputs = [("Test", Table)] + class WidgetC(WidgetA): + name = "C" + LEARNER = TreeLearner - self.assertChannelsEqual(OWTestProvider.outputs, expected_outputs) + class Outputs(WidgetA.Outputs): + test = Output("test", str) - def test_class_without_learner(self): - with self.assertRaises(AttributeError): - class OWTestProvider(OWBaseLearner): - name = 'test' + self.assertEqual(WidgetC.Outputs.learner.type, TreeLearner) + self.assertEqual(WidgetC.Outputs.test.name, "test") + self.assertEqual(WidgetA.Outputs.learner.type, KNNLearner) + self.assertFalse(hasattr(WidgetA.Outputs, "test")) diff --git a/Orange/widgets/utils/__init__.py b/Orange/widgets/utils/__init__.py index a10e1033a9a..71068075980 100644 --- a/Orange/widgets/utils/__init__.py +++ b/Orange/widgets/utils/__init__.py @@ -1,3 +1,4 @@ +import inspect import sys from AnyQt.QtCore import QObject @@ -70,3 +71,18 @@ def dumpObjectTree(obj, _indent=0): file=sys.stderr) for child in obj.children(): dumpObjectTree(child, _indent + 1) + + +def getmembers(obj, predicate=None): + """Return all the members of an object in a list of (name, value) pairs sorted by name. + + Behaves like inspect.getmembers. If a type object is passed as a predicate, + only members of that type are returned. + """ + + if isinstance(predicate, type): + def mypredicate(x): + return isinstance(x, predicate) + else: + mypredicate = predicate + return inspect.getmembers(obj, mypredicate) diff --git a/Orange/widgets/utils/owlearnerwidget.py b/Orange/widgets/utils/owlearnerwidget.py index 5f6dcbc0447..c748b868ee0 100644 --- a/Orange/widgets/utils/owlearnerwidget.py +++ b/Orange/widgets/utils/owlearnerwidget.py @@ -1,5 +1,5 @@ -import numpy as np from copy import deepcopy +import numpy as np from AnyQt.QtCore import QTimer, Qt @@ -8,6 +8,7 @@ from Orange.preprocess.preprocess import Preprocess from Orange.widgets import gui from Orange.widgets.settings import Setting +from Orange.widgets.utils import getmembers from Orange.widgets.utils.signals import Output, Input from Orange.widgets.utils.sql import check_sql_input from Orange.widgets.widget import OWWidget, WidgetMetaClass, Msg @@ -15,20 +16,38 @@ class OWBaseLearnerMeta(WidgetMetaClass): """ Meta class for learner widgets + + OWBaseLearner declares two outputs, learner and model with + generic type (Learner and Model). + + This metaclass ensures that each of the subclasses gets + its own Outputs class with output that match the corresponding + learner. """ - def __new__(mcls, name, bases, attributes): - learner = attributes.get("LEARNER", None) - # classes with empty names are considered abstract - if attributes.get(name): - if learner is None: - raise AttributeError( - "'{}' must declare attribute LEARNER".format(name)) + def __new__(cls, name, bases, attributes): + def abstract_widget(): + return not attributes.get("name") + + def copy_outputs(template): + result = type("Outputs", (), {}) + for name, signal in getmembers(template, Output): + setattr(result, name, deepcopy(signal)) + return result + + obj = super().__new__(cls, name, bases, attributes) + if abstract_widget(): + return obj + + learner = attributes.get("LEARNER") + if not learner: + raise AttributeError( + "'{}' must declare attribute LEARNER".format(name)) - outputs = attributes["Outputs"] = deepcopy(attributes["Outputs"]) - outputs.learner.type = learner - outputs.model.type = learner.__returns__ + outputs = obj.Outputs = copy_outputs(obj.Outputs) + outputs.learner.type = learner + outputs.model.type = learner.__returns__ - return super().__new__(mcls, name, bases, attributes) + return obj class OWBaseLearner(OWWidget, metaclass=OWBaseLearnerMeta): @@ -64,8 +83,6 @@ class Inputs: preprocessor = Input("Preprocessor", Preprocess) class Outputs: - # Exact output types for each learner widget are set in - # OWBaseLearnerMeta meta class learner = Output("Learner", Learner, dynamic=False) model = Output("Model", Model, dynamic=False, replaces=["Classifier", "Predictor"]) @@ -75,7 +92,7 @@ def __init__(self): self.data = None self.valid_data = False self.learner = None - self.learner_name = self.LEARNER.name + self.learner_name = self.name self.model = None self.preprocessors = None self.outdated_settings = False diff --git a/Orange/widgets/utils/signals.py b/Orange/widgets/utils/signals.py index 1f8a650ea78..2f772bf2a5b 100644 --- a/Orange/widgets/utils/signals.py +++ b/Orange/widgets/utils/signals.py @@ -1,8 +1,8 @@ import copy -import inspect import itertools from Orange.canvas.registry.description import InputSignal, OutputSignal +from Orange.widgets.utils import getmembers # increasing counter for ensuring the order of Input/Output definitions # is preserved when going through the unordered class namespace of @@ -146,12 +146,6 @@ def send(self, value, id=None): signal_manager.send(self.widget, self.name, value, id) -def _get_members(obj, member_type): - def is_member_type(member): - return isinstance(member, member_type) - return inspect.getmembers(obj, is_member_type) - - class WidgetSignalsMixin: """Mixin for managing widget's input and output signals""" class Inputs: @@ -165,7 +159,7 @@ def __init__(self): def _bind_outputs(self): bound_cls = self.Outputs() - for name, signal in _get_members(bound_cls, Output): + for name, signal in getmembers(bound_cls, Output): setattr(bound_cls, name, signal.bound_signal(self)) setattr(self, "Outputs", bound_cls) @@ -217,7 +211,7 @@ def signal_from_args(args, signal_type): @classmethod def _check_input_handlers(cls): unbound = [signal.name - for _, signal in _get_members(cls.Inputs, Input) + for _, signal in getmembers(cls.Inputs, Input) if not signal.handler] if unbound: raise ValueError("unbound signal(s) in {}: {}". @@ -249,7 +243,7 @@ def get_signals(cls, direction): return old_style signal_class = getattr(cls, direction.title()) - signals = [signal for _, signal in _get_members(signal_class, _Signal)] + signals = [signal for _, signal in getmembers(signal_class, _Signal)] return list(sorted(signals, key=lambda s: s._seq_id))