Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure all watchers are executed in scheduling order #323

Merged
merged 5 commits into from
Mar 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .parameterized import Parameterized, Parameter, String, \
descendents, ParameterizedFunction, ParamOverrides
from .parameterized import (depends, output, logging_level, # noqa: api import
shared_parameters, instance_descriptor)
shared_parameters, instance_descriptor, batch_watch)

from collections import OrderedDict
from numbers import Real
Expand Down
82 changes: 57 additions & 25 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,25 @@ def logging_level(level):
param_logger.setLevel(logging_level)


@contextmanager
def batch_watch(parameterized, run=True):
"""
Context manager to batch watcher events on a parameterized object.
The context manager will queue any events triggered by setting a
parameter on the supplied parameterized object and dispatch them
all at once when the context manager exits. If run=False the
queued events are not dispatched and should be processed manually.
jbednar marked this conversation as resolved.
Show resolved Hide resolved
"""
BATCH_WATCH = parameterized.param._BATCH_WATCH
parameterized.param._BATCH_WATCH = True
try:
yield
finally:
parameterized.param._BATCH_WATCH = BATCH_WATCH
if run and not BATCH_WATCH:
parameterized.param._batch_call_watchers()


def classlist(class_):
"""
Return a list of the class hierarchy above (and including) the given class.
Expand Down Expand Up @@ -644,6 +663,8 @@ def __setattr__(self,attribute,value):
event = Event(what=attribute,name=self.name,obj=None,cls=self.owner,old=old,new=value, type=None)
for watcher in self.watchers[attribute]:
self.owner.param._call_watcher(watcher, event)
if not self.owner.param._BATCH_WATCH:
self.owner.param._batch_call_watchers()


def __get__(self,obj,objtype): # pylint: disable-msg=W0613
Expand Down Expand Up @@ -735,8 +756,13 @@ def __set__(self,obj,val):

event = Event(what='value',name=self.name,obj=obj,cls=self.owner,old=_old,new=val, type=None)
obj = self.owner if obj is None else obj
for s in watchers:
obj.param._call_watcher(s, event)
if obj is None:
return

for watcher in watchers:
obj.param._call_watcher(watcher, event)
if not obj.param._BATCH_WATCH:
obj.param._batch_call_watchers()


def _validate(self, val):
Expand Down Expand Up @@ -1215,6 +1241,7 @@ def set_param(self_, *args,**kwargs):
positional arguments, but the keyword interface is preferred
because it is more compact and can set multiple values.
"""
BATCH_WATCH = self_.self_or_cls.param._BATCH_WATCH
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the scary capitals?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just following the style of the original variable name, not sure it has to be capitalized but if we decide to change I would change it everywhere.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you wish.

self_.self_or_cls.param._BATCH_WATCH = True
self_or_cls = self_.self_or_cls
if args:
Expand All @@ -1235,8 +1262,9 @@ def set_param(self_, *args,**kwargs):
self_.self_or_cls.param._BATCH_WATCH = False
raise

self_.self_or_cls.param._BATCH_WATCH = False
self_._batch_call_watchers()
self_.self_or_cls.param._BATCH_WATCH = BATCH_WATCH
if not BATCH_WATCH:
self_._batch_call_watchers()


def objects(self_, instance=True):
Expand Down Expand Up @@ -1324,29 +1352,38 @@ def _call_watcher(self_, watcher, event):
if watcher not in self_._watchers:
self_._watchers.append(watcher)
elif watcher.mode == 'args':
watcher.fn(self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER))
with batch_watch(self_.self_or_cls, run=False):
watcher.fn(self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER))
else:
event = self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER)
watcher.fn(**{event.name: event.new})
with batch_watch(self_.self_or_cls, run=False):
event = self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER)
watcher.fn(**{event.name: event.new})


def _batch_call_watchers(self_):
"""
Batch call a set of watchers based on the parameter value
settings in kwargs using the queued Event and watcher objects.
"""
event_dict = OrderedDict([(c.name,c) for c in self_.self_or_cls.param._events])
watchers = self_.self_or_cls.param._watchers[:]
self_.self_or_cls.param._events = []
self_.self_or_cls.param._watchers = []
while self_.self_or_cls.param._events:
event_dict = OrderedDict([(c.name,c) for c in self_.self_or_cls.param._events])
watchers = self_.self_or_cls.param._watchers[:]
self_.self_or_cls.param._events = []
self_.self_or_cls.param._watchers = []

for watcher in watchers:
events = [self_._update_event_type(watcher, event_dict[name], self_.self_or_cls.param._TRIGGER)
for name in watcher.parameter_names if name in event_dict]
if watcher.mode == 'args':
watcher.fn(*events)
else:
watcher.fn(**{c.name:c.new for c in events})
for watcher in watchers:
events = [self_._update_event_type(watcher, event_dict[name], self_.self_or_cls.param._TRIGGER)
for name in watcher.parameter_names if name in event_dict]
self_.self_or_cls.param._BATCH_WATCH = True
try:
if watcher.mode == 'args':
watcher.fn(*events)
else:
watcher.fn(**{c.name:c.new for c in events})
except:
raise
finally:
self_.self_or_cls.param._BATCH_WATCH = False


def set_dynamic_time_fn(self_,time_fn,sublistattr=None):
Expand Down Expand Up @@ -2188,19 +2225,14 @@ def __init__(self,**params):
self.initialized=False
# Override class level param namespace with instance namespace
self.param = Parameters(self.__class__, self=self)
self._instance__params = {}
self._param_watchers = {}

self.param._generate_name()

self._instance__params = {}
self.param._setup_params(**params)
object_count += 1

# TODO: should move to param namespace? (like _param_value
# etc should also move)
self._param_watchers = {}

# add watched dependencies
#
for n in self.__class__.param._depends['watch']:
# TODO: should improve this - will happen for every
# instantiation of Parameterized with watched deps. Will
Expand Down
57 changes: 57 additions & 0 deletions tests/API1/testwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ class SimpleWatchSubclass(SimpleWatchExample):
pass


class WatchMethodExample(SimpleWatchSubclass):

@param.depends('a', watch=True)
def _clip_a(self):
if self.a > 3:
self.a = 3

@param.depends('b', watch=True)
def _set_c(self):
self.c = self.b*2



class TestWatch(API1TestCase):

Expand Down Expand Up @@ -190,6 +202,29 @@ def test_simple_batched_watch_setattr(self):
self.assertEqual(args[0].new, 3)
self.assertEqual(args[0].type, 'changed')

def test_batched_watch_context_manager(self):

accumulator = Accumulator()

obj = SimpleWatchExample()
obj.param.watch(accumulator, ['a','b'])

with param.batch_watch(obj):
obj.a = 2
obj.b = 3

self.assertEqual(accumulator.call_count(), 1)
args = accumulator.args_for_call(0)

self.assertEqual(len(args), 2)
self.assertEqual(args[0].name, 'a')
self.assertEqual(args[0].old, 0)
self.assertEqual(args[0].new, 2)
self.assertEqual(args[0].type, 'changed')
self.assertEqual(args[1].name, 'b')
self.assertEqual(args[1].old, 0)
self.assertEqual(args[1].new, 3)
self.assertEqual(args[1].type, 'changed')

def test_nested_batched_watch_setattr(self):

Expand Down Expand Up @@ -383,6 +418,28 @@ def test_nested_batched_watch_not_onlychanged(self):



class TestWatchMethod(API1TestCase):

def test_dependent_params(self):
obj = WatchMethodExample()

obj.b = 3
self.assertEqual(obj.c, 6)

def test_multiple_watcher_dispatch(self):
obj = WatchMethodExample()
obj2 = SimpleWatchExample()

def link(event):
obj2.a = event.new

obj.param.watch(link, 'a')
obj.a = 4
self.assertEqual(obj.a, 3)
self.assertEqual(obj2.a, 3)



class TestWatchValues(API1TestCase):

def setUp(self):
Expand Down