diff --git a/src/Binding.lua b/src/Binding.lua index 06a9e267..04ce1d0b 100644 --- a/src/Binding.lua +++ b/src/Binding.lua @@ -11,6 +11,26 @@ local function identity(value) return value end +--[[ + Determines if a table (array or dictionary) is empty +]] +local function isTableEmpty(value) + return not next(value) +end + +--[[ + Maps a table of bindings to their respective values. Used in Binding.join. +]] +local function mapBindingsToValues(bindings) + local values = {} + + for key, binding in pairs(bindings) do + values[key] = binding:getValue() + end + + return values +end + local Binding = {} --[[ @@ -32,13 +52,13 @@ function bindingPrototype:getValue() --[[ If our source is another binding but we're not subscribed, we'll - return the mapped value from our upstream binding. + return the mapped value from our upstream binding(s). This allows us to avoid subscribing to our source until someone has subscribed to us, and avoid creating dangling connections. ]] - if internalData.upstreamBinding ~= nil and internalData.upstreamDisconnect == nil then - return internalData.valueTransform(internalData.upstreamBinding:getValue()) + if not isTableEmpty(internalData.upstreamBindings) then + return internalData.valueTransform(self:__getValueFromUpstreamBindings()) end return internalData.value @@ -55,11 +75,46 @@ function bindingPrototype:map(valueTransform) local binding = Binding.create(valueTransform(self:getValue())) binding[InternalData].valueTransform = valueTransform - binding[InternalData].upstreamBinding = self + binding[InternalData].upstreamBindings.source = self return binding end +--[[ + Determines the final (not yet transformed) value from upstream bindings +]] +function bindingPrototype:__getValueFromUpstreamBindings() + local internalData = self[InternalData] + local newValue = mapBindingsToValues(internalData.upstreamBindings) + + if not internalData.isJoinedBinding then + --[[ + If this is not a joined binding, there will always only be one upstream + binding. + + To ensure that joined bindings with a single upstream binding always + result in a table, we use the internal variable isJoinedBinding + ]] + local _, value = next(newValue) + newValue = value + end + + return newValue +end + +--[[ + Disconnects all connections to upstream bindings +]] +function bindingPrototype:__upstreamDisconnect() + local internalData = self[InternalData] + + for _, disconnect in ipairs(internalData.upstreamConnections) do + disconnect() + end + + internalData.upstreamConnections = {} +end + --[[ Update a binding's value. This is only accessible by Roact. ]] @@ -80,13 +135,17 @@ function Binding.subscribe(binding, handler) --[[ If this binding is mapped to another and does not have any subscribers, - we need to create a subscription to our source binding so that updates + we need to create subscriptions to our source bindings so that updates get passed along to us ]] - if internalData.upstreamBinding ~= nil and internalData.subscriberCount == 0 then - internalData.upstreamDisconnect = Binding.subscribe(internalData.upstreamBinding, function(value) - Binding.update(binding, value) - end) + if not isTableEmpty(internalData.upstreamBindings) and internalData.subscriberCount == 0 then + local function upstreamCallback() + Binding.update(binding, binding:__getValueFromUpstreamBindings()) + end + + for _, upstreamBinding in pairs(internalData.upstreamBindings) do + table.insert(internalData.upstreamConnections, Binding.subscribe(upstreamBinding, upstreamCallback)) + end end local disconnect = internalData.changeSignal:subscribe(handler) @@ -111,9 +170,8 @@ function Binding.subscribe(binding, handler) If our subscribers count drops to 0, we can safely unsubscribe from our source binding ]] - if internalData.subscriberCount == 0 and internalData.upstreamDisconnect ~= nil then - internalData.upstreamDisconnect() - internalData.upstreamDisconnect = nil + if internalData.subscriberCount == 0 then + binding:__upstreamDisconnect() end end end @@ -132,8 +190,9 @@ function Binding.create(initialValue) subscriberCount = 0, valueTransform = identity, - upstreamBinding = nil, - upstreamDisconnect = nil, + isJoinedBinding = false, + upstreamBindings = {}, + upstreamConnections = {}, }, } @@ -146,4 +205,30 @@ function Binding.create(initialValue) return binding, setter end +--[[ + Creates a new binding which updates when any of the upstream bindings + updates, which can be further mapped into any value. This function will + be exposed to users of Roact. +]] +function Binding.join(bindings) + if config.typeChecks then + assert(typeof(bindings) == "table", "Bad arg #1 to Binding.join: expected table") + + for key, binding in pairs(bindings) do + assert(Type.of(binding) == Type.Binding, ("Non-binding value passed into Binding.join at index %q"):format(key)) + end + end + + local joinedBinding = Binding.create(mapBindingsToValues(bindings)) + local internalData = joinedBinding[InternalData] + + internalData.isJoinedBinding = true + + for key, binding in pairs(bindings) do + internalData.upstreamBindings[key] = binding + end + + return joinedBinding +end + return Binding \ No newline at end of file diff --git a/src/Binding.spec.lua b/src/Binding.spec.lua index d02fff84..d0360b21 100644 --- a/src/Binding.spec.lua +++ b/src/Binding.spec.lua @@ -21,6 +21,82 @@ return function() end) end) + describe("Binding.join", function() + it("should properly output values", function() + local binding1 = Binding.create(1) + local binding2 = Binding.create(2) + + local joinedBinding = Binding.join({ + binding1, + binding2, + }) + + local bindingValue = joinedBinding:getValue() + expect(bindingValue).to.be.a("table") + expect(bindingValue[1]).to.equal(1) + expect(bindingValue[2]).to.equal(2) + end) + + it("should update when any one of the subscribed bindings updates", function() + local binding1, update1 = Binding.create(1) + local binding2, update2 = Binding.create(2) + + local joinedBinding = Binding.join({ + binding1, + binding2, + }) + + local spy = createSpy() + Binding.subscribe(joinedBinding, spy.value) + + expect(spy.callCount).to.equal(0) + update1(3) + expect(spy.callCount).to.equal(1) + update2(4) + expect(spy.callCount).to.equal(2) + + local bindingValue = spy.values[1] + expect(bindingValue).to.be.a("table") + expect(bindingValue[1]).to.equal(3) + expect(bindingValue[2]).to.equal(4) + end) + + it("should return correct values when there are no subscriptions", function() + local binding1, update1 = Binding.create(1) + local binding2, update2 = Binding.create(2) + + local joinedBinding = Binding.join({ + binding1, + binding2, + }) + + update1("foo") + update2("bar") + + local bindingValue = joinedBinding:getValue() + expect(bindingValue).to.be.a("table") + expect(bindingValue[1]).to.equal("foo") + expect(bindingValue[2]).to.equal("bar") + end) + + it("should throw when a non-table value is passed", function() + expect(function() + Binding.join("hi") + end).to.throw() + end) + + it("should throw when a non-binding value is passed via table", function() + expect(function() + local binding = Binding.create(123) + + Binding.join({ + binding, + "abcde", + }) + end).to.throw() + end) + end) + describe("Binding object", function() it("should provide a getter and setter", function() local binding, update = Binding.create(1) diff --git a/src/init.lua b/src/init.lua index cbaa6f0a..f002f975 100644 --- a/src/init.lua +++ b/src/init.lua @@ -22,6 +22,7 @@ local Roact = strict { Portal = require(script.Portal), createRef = require(script.createRef), createBinding = Binding.create, + joinBindings = Binding.join, Change = require(script.PropMarkers.Change), Children = require(script.PropMarkers.Children), diff --git a/src/init.spec.lua b/src/init.spec.lua index 7b23a173..7fcf79c8 100644 --- a/src/init.spec.lua +++ b/src/init.spec.lua @@ -7,6 +7,7 @@ return function() createFragment = "function", createRef = "function", createBinding = "function", + joinBindings = "function", mount = "function", unmount = "function", update = "function",