diff --git a/dash_extensions/enrich.py b/dash_extensions/enrich.py index 28ee009..02be063 100644 --- a/dash_extensions/enrich.py +++ b/dash_extensions/enrich.py @@ -45,6 +45,7 @@ from dash.development.base_component import Component from flask import session from flask_caching.backends import FileSystemCache, RedisCache +from itertools import compress from more_itertools import flatten from collections import defaultdict from typing import Dict, Callable, List, Union, Any, Tuple, Optional, Generic, TypeVar @@ -952,8 +953,6 @@ def __init__(self, component_id, component_property): class TriggerTransform(DashTransform): - # NOTE: This transform cannot be implemented for clientside callbacks since the JS can't be modified from here. - def apply_serverside(self, callbacks): for callback in callbacks: is_trigger = [isinstance(item, Trigger) for item in callback.inputs] @@ -965,6 +964,24 @@ def apply_serverside(self, callbacks): callback.f = filter_args(is_trigger)(f) return callbacks + def apply_clientside(self, callbacks): + for callback in callbacks: + is_not_trigger = [not isinstance(item, Trigger) for item in callback.inputs] + # Check if any triggers are there. + if all(is_not_trigger): + continue + # If so, filter the callback args. + args = [f"arg{i}" for i in range(len(callback.inputs))] + filtered_args = compress(args, is_not_trigger) + if isinstance(callback.f, ClientsideFunction): + callback.f = f"window.dash_clientside['{callback.f.namespace}']['{callback.f.function_name}']" + callback.f = f""" +function({", ".join(args)}) {{ +const func = {callback.f}; +return func({", ".join(filtered_args)}); +}}""" + return callbacks + def filter_args(args_filter): def wrapper(f): diff --git a/tests/test_enrich.py b/tests/test_enrich.py index fa1b645..c599ff8 100644 --- a/tests/test_enrich.py +++ b/tests/test_enrich.py @@ -286,6 +286,42 @@ def update(n_clicks2, n_clicks4): assert log.text == "1-1" +def test_trigger_transform_clientside(dash_duo): + app = DashProxy(prevent_initial_callbacks=True, transforms=[TriggerTransform()]) + app.layout = html.Div([ + html.Button(id="btn1"), + html.Button(id="btn2"), + html.Button(id="btn3"), + html.Button(id="btn4"), + html.Div(id="log"), + ]) + + app.clientside_callback( + """(nClicks2, nClicks4) => `${nClicks2}-${nClicks4}`""", + Output("log", "children"), + Trigger("btn1", "n_clicks"), + Input("btn2", "n_clicks"), + Trigger("btn3", "n_clicks"), + State("btn4", "n_clicks")) + + # Check that the app works. + dash_duo.start_server(app) + log = dash_duo.find_element("#log") + assert log.text == "" + dash_duo.find_element("#btn1").click() + time.sleep(0.1) + assert log.text == "undefined-undefined" + dash_duo.find_element("#btn2").click() + time.sleep(0.1) + assert log.text == "1-undefined" + dash_duo.find_element("#btn4").click() + time.sleep(0.1) + assert log.text == "1-undefined" + dash_duo.find_element("#btn3").click() + time.sleep(0.1) + assert log.text == "1-1" + + @pytest.mark.parametrize( 'args, kwargs', [([Output("log", "children"), Input("right", "n_clicks")], dict()),