Skip to content
This repository has been archived by the owner on Dec 13, 2023. It is now read-only.

joinBindings implementation #204

Closed
wants to merge 13 commits into from
103 changes: 88 additions & 15 deletions src/Binding.lua
Original file line number Diff line number Diff line change
@@ -11,6 +11,19 @@ local function identity(value)
return value
end

--[[
Maps a table of bindings to their respective values. Used in Binding.join.
]]
local function mapBindingsToValues(bindings)
local values = {}

for key, binding in pairs(bindings) do
values[key] = binding:getValue()
end

return values
end

local Binding = {}

--[[
@@ -32,13 +45,13 @@ function bindingPrototype:getValue()

--[[
If our source is another binding but we're not subscribed, we'll
return the mapped value from our upstream binding.
return the mapped value from our upstream binding(s).

This allows us to avoid subscribing to our source until someone
has subscribed to us, and avoid creating dangling connections.
]]
if internalData.upstreamBinding ~= nil and internalData.upstreamDisconnect == nil then
return internalData.valueTransform(internalData.upstreamBinding:getValue())
if internalData.upstreamBindingCount > 0 then
return internalData.valueTransform(self:__getValueFromUpstreamBindings())
end

return internalData.value
@@ -53,13 +66,49 @@ function bindingPrototype:map(valueTransform)
end

local binding = Binding.create(valueTransform(self:getValue()))
local internalData = binding[InternalData]

binding[InternalData].valueTransform = valueTransform
binding[InternalData].upstreamBinding = self
internalData.valueTransform = valueTransform

internalData.upstreamBindings.source = self
internalData.upstreamBindingCount = internalData.upstreamBindingCount + 1

return binding
end

--[[
Determines the final (not yet transformed) value from upstream bindings
]]
function bindingPrototype:__getValueFromUpstreamBindings()
local internalData = self[InternalData]
local newValue = mapBindingsToValues(internalData.upstreamBindings)

if not internalData.isJoinedBinding then
--[[
If this is not a joined binding, there will always only be one upstream
binding.

To ensure that joined bindings with a single upstream binding always
result in a table, we use the internal variable isJoinedBinding
]]
local _, value = next(newValue)
newValue = value
end

return newValue
end

--[[
Disconnects all connections to upstream bindings
]]
function bindingPrototype:__upstreamDisconnect()
local internalData = self[InternalData]

for _, disconnect in pairs(internalData.upstreamConnections) do
disconnect()
end
end

--[[
Update a binding's value. This is only accessible by Roact.
]]
@@ -80,13 +129,17 @@ function Binding.subscribe(binding, handler)

--[[
If this binding is mapped to another and does not have any subscribers,
we need to create a subscription to our source binding so that updates
we need to create subscriptions to our source bindings so that updates
get passed along to us
]]
if internalData.upstreamBinding ~= nil and internalData.subscriberCount == 0 then
internalData.upstreamDisconnect = Binding.subscribe(internalData.upstreamBinding, function(value)
Binding.update(binding, value)
end)
if internalData.upstreamBindingCount > 0 and internalData.subscriberCount == 0 then
local function upstreamCallback()
Binding.update(binding, binding:__getValueFromUpstreamBindings())
end

for _, upstreamBinding in pairs(internalData.upstreamBindings) do
table.insert(internalData.upstreamConnections, Binding.subscribe(upstreamBinding, upstreamCallback))
end
end

local disconnect = internalData.changeSignal:subscribe(handler)
@@ -111,9 +164,8 @@ function Binding.subscribe(binding, handler)
If our subscribers count drops to 0, we can safely unsubscribe from
our source binding
]]
if internalData.subscriberCount == 0 and internalData.upstreamDisconnect ~= nil then
internalData.upstreamDisconnect()
internalData.upstreamDisconnect = nil
if internalData.subscriberCount == 0 then
binding:__upstreamDisconnect()
end
end
end
@@ -132,8 +184,10 @@ function Binding.create(initialValue)
subscriberCount = 0,

valueTransform = identity,
upstreamBinding = nil,
upstreamDisconnect = nil,
isJoinedBinding = false,
upstreamBindings = {},
upstreamConnections = {},
upstreamBindingCount = 0,
},
}

@@ -146,4 +200,23 @@ function Binding.create(initialValue)
return binding, setter
end

--[[
Creates a new binding which updates when any of the upstream bindings
updates, which can be further mapped into any value. This function will
be exposed to users of Roact.
]]
function Binding.join(bindings)
local joinedBinding = Binding.create(mapBindingsToValues(bindings))
local internalData = joinedBinding[InternalData]

internalData.isJoinedBinding = true

for key, binding in pairs(bindings) do
internalData.upstreamBindings[key] = binding
internalData.upstreamBindingCount = internalData.upstreamBindingCount + 1
end

return joinedBinding
end

return Binding
41 changes: 41 additions & 0 deletions src/Binding.spec.lua
Original file line number Diff line number Diff line change
@@ -21,6 +21,47 @@ return function()
end)
end)

describe("Binding.join", function()
it("should properly output values", function()
local binding1 = Binding.create(1)
local binding2 = Binding.create(2)

local joinedBinding = Binding.join({
binding1,
binding2,
})

local bindingValue = joinedBinding:getValue()
expect(bindingValue).to.be.a("table")
expect(bindingValue[1]).to.equal(1)
expect(bindingValue[2]).to.equal(2)
end)

it("should update when any one of the subscribed bindings updates", function()
local binding1, update1 = Binding.create(1)
local binding2, update2 = Binding.create(2)

local joinedBinding = Binding.join({
binding1,
binding2,
})

local spy = createSpy()
Binding.subscribe(joinedBinding, spy.value)

expect(spy.callCount).to.equal(0)
update1(3)
expect(spy.callCount).to.equal(1)
update2(4)
expect(spy.callCount).to.equal(2)

local bindingValue = joinedBinding:getValue()
expect(bindingValue).to.be.a("table")
expect(bindingValue[1]).to.equal(3)
expect(bindingValue[2]).to.equal(4)
end)
end)

describe("Binding object", function()
it("should provide a getter and setter", function()
local binding, update = Binding.create(1)
1 change: 1 addition & 0 deletions src/init.lua
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ local Roact = strict {
Portal = require(script.Portal),
createRef = require(script.createRef),
createBinding = Binding.create,
joinBindings = Binding.join,

Change = require(script.PropMarkers.Change),
Children = require(script.PropMarkers.Children),
1 change: 1 addition & 0 deletions src/init.spec.lua
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ return function()
createFragment = "function",
createRef = "function",
createBinding = "function",
joinBindings = "function",
mount = "function",
unmount = "function",
update = "function",