Skip to content

Commit

Permalink
contractcourt: break launchResolvers into two steps
Browse files Browse the repository at this point in the history
In this commit, we break the old `launchResolvers` into two steps - step
one is to launch the resolvers synchronously, and step two is to
actually waiting for the resolvers to be resolved. This is critical as
in the following commit we will require the resolvers to be launched at
the same blockbeat when a force close event is sent by the chain watcher.
  • Loading branch information
yyforyongyu committed Nov 25, 2024
1 parent 0d1e81b commit aeafa63
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 30 deletions.
75 changes: 71 additions & 4 deletions contractcourt/channel_arbitrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet,
// TODO(roasbeef): this isn't re-launched?
}

c.launchResolvers(unresolvedContracts)
c.resolveContracts(unresolvedContracts)

return nil
}
Expand Down Expand Up @@ -1343,7 +1343,7 @@ func (c *ChannelArbitrator) stateStep(

// Finally, we'll launch all the required contract resolvers.
// Once they're all resolved, we're no longer needed.
c.launchResolvers(resolvers)
c.resolveContracts(resolvers)

nextState = StateWaitingFullResolution

Expand Down Expand Up @@ -1566,18 +1566,75 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32,
return fn.Some(int32(deadline)), valueLeft, nil
}

// launchResolvers updates the activeResolvers list and starts the resolvers.
func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) {
// resolveContracts updates the activeResolvers list and starts to resolve each
// contract concurrently, and launches them.
func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) {
c.activeResolversLock.Lock()
c.activeResolvers = resolvers
c.activeResolversLock.Unlock()

// Launch all resolvers.
c.launchResolvers()

for _, contract := range resolvers {
c.wg.Add(1)
go c.resolveContract(contract)
}
}

// launchResolvers launches all the active resolvers concurrently.
func (c *ChannelArbitrator) launchResolvers() {
c.activeResolversLock.Lock()
resolvers := c.activeResolvers
c.activeResolversLock.Unlock()

// errChans is a map of channels that will be used to receive errors
// returned from launching the resolvers.
errChans := make(map[ContractResolver]chan error, len(resolvers))

// Launch each resolver in goroutines.
for _, r := range resolvers {
// If the contract is already resolved, there's no need to
// launch it again.
if r.IsResolved() {
log.Debugf("ChannelArbitrator(%v): skipping resolver "+
"%T as it's already resolved", c.cfg.ChanPoint,
r)

continue
}

// Create a signal chan.
errChan := make(chan error, 1)
errChans[r] = errChan

go func() {
err := r.Launch()
errChan <- err
}()
}

// Wait for all resolvers to finish launching.
for r, errChan := range errChans {
select {
case err := <-errChan:
if err == nil {
continue
}

log.Errorf("ChannelArbitrator(%v): unable to launch "+
"contract resolver(%T): %v", c.cfg.ChanPoint, r,
err)

case <-c.quit:
log.Debugf("ChannelArbitrator quit signal received, " +
"exit launchResolvers")

return
}
}
}

// advanceState is the main driver of our state machine. This method is an
// iterative function which repeatedly attempts to advance the internal state
// of the channel arbitrator. The state will be advanced until we reach a
Expand Down Expand Up @@ -2616,6 +2673,13 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) {
// loop.
currentContract = nextContract

// Launch the new contract.
err = currentContract.Launch()
if err != nil {
log.Errorf("Failed to launch %T: %v",
currentContract, err)
}

// If this contract is actually fully resolved, then
// we'll mark it as such within the database.
case currentContract.IsResolved():
Expand Down Expand Up @@ -3117,6 +3181,9 @@ func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error {
}
}

// Launch all active resolvers when a new blockbeat is received.
c.launchResolvers()

return nil
}

Expand Down
54 changes: 28 additions & 26 deletions contractcourt/channel_arbitrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
Expand Down Expand Up @@ -1091,12 +1092,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
}

// Send a notification that the expiry height has been reached.
//
// TODO(yy): remove the EpochChan and use the blockbeat below once
// resolvers are hooked with the blockbeat.
oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10}
// beat := chainio.NewBlockbeatFromHeight(10)
// chanArb.BlockbeatChan <- beat

// htlcOutgoingContestResolver is now transforming into a
// htlcTimeoutResolver and should send the contract off for incubation.
Expand Down Expand Up @@ -1139,8 +1135,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
default:
}

// Notify resolver that the second level transaction is spent.
oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx}
// Notify resolver that the output of the timeout tx has been spent.
oldNotifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: closeTx,
SpentOutPoint: &wire.OutPoint{},
SpenderTxHash: &closeTxid,
}

// At this point channel should be marked as resolved.
chanArbCtxNew.AssertStateTransitions(StateFullyResolved)
Expand Down Expand Up @@ -2820,27 +2820,28 @@ func TestChannelArbitratorAnchors(t *testing.T) {
}
chanArb.UpdateContractSignals(signals)

// Set current block height.
beat = newBeatFromHeight(int32(heightHint))
chanArbCtx.chanArb.BlockbeatChan <- beat

htlcAmt := lnwire.MilliSatoshi(1_000_000)

// Create testing HTLCs.
deadlineDelta := uint32(10)
deadlinePreimageDelta := deadlineDelta + 2
spendingHeight := uint32(beat.Height())
deadlineDelta := uint32(100)

deadlinePreimageDelta := deadlineDelta
htlcWithPreimage := channeldb.HTLC{
HtlcIndex: 99,
RefundTimeout: heightHint + deadlinePreimageDelta,
HtlcIndex: 99,
// RefundTimeout is 101.
RefundTimeout: spendingHeight + deadlinePreimageDelta,
RHash: rHash,
Incoming: true,
Amt: htlcAmt,
}
expectedDeadline := deadlineDelta/2 + spendingHeight

deadlineHTLCdelta := deadlineDelta + 3
deadlineHTLCdelta := deadlineDelta + 40
htlc := channeldb.HTLC{
HtlcIndex: 100,
RefundTimeout: heightHint + deadlineHTLCdelta,
HtlcIndex: 100,
// RefundTimeout is 141.
RefundTimeout: spendingHeight + deadlineHTLCdelta,
Amt: htlcAmt,
}

Expand Down Expand Up @@ -2925,7 +2926,9 @@ func TestChannelArbitratorAnchors(t *testing.T) {

//nolint:lll
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
SpendDetail: &chainntnfs.SpendDetail{},
SpendDetail: &chainntnfs.SpendDetail{
SpendingHeight: int32(spendingHeight),
},
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: closeTx,
ContractResolutions: fn.Some(lnwallet.ContractResolutions{
Expand Down Expand Up @@ -2989,16 +2992,14 @@ func TestChannelArbitratorAnchors(t *testing.T) {
// to htlcWithPreimage's CLTV.
require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines))
require.EqualValues(t,
heightHint+deadlinePreimageDelta/2,
expectedDeadline,
chanArbCtx.sweeper.deadlines[0], "want %d, got %d",
heightHint+deadlinePreimageDelta/2,
chanArbCtx.sweeper.deadlines[0],
expectedDeadline, chanArbCtx.sweeper.deadlines[0],
)
require.EqualValues(t,
heightHint+deadlinePreimageDelta/2,
expectedDeadline,
chanArbCtx.sweeper.deadlines[1], "want %d, got %d",
heightHint+deadlinePreimageDelta/2,
chanArbCtx.sweeper.deadlines[1],
expectedDeadline, chanArbCtx.sweeper.deadlines[1],
)
}

Expand Down Expand Up @@ -3146,7 +3147,8 @@ func assertResolverReport(t *testing.T, reports chan *channeldb.ResolverReport,
select {
case report := <-reports:
if !reflect.DeepEqual(report, expected) {
t.Fatalf("expected: %v, got: %v", expected, report)
t.Fatalf("expected: %v, got: %v", spew.Sdump(expected),
spew.Sdump(report))
}

case <-time.After(defaultTimeout):
Expand Down
11 changes: 11 additions & 0 deletions contractcourt/contract_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ type ContractResolver interface {
// resides within.
ResolverKey() []byte

// Launch starts the resolver by constructing an input and offering it
// to the sweeper. Once offered, it's expected to monitor the sweeping
// result in a goroutine invoked by calling Resolve.
//
// NOTE: We can call `Resolve` inside a goroutine at the end of this
// method to avoid calling it in the ChannelArbitrator. However, there
// are some DB-related operations such as SwapContract/ResolveContract
// which need to be done inside the resolvers instead, which needs a
// deeper refactoring.
Launch() error

// Resolve instructs the contract resolver to resolve the output
// on-chain. Once the output has been *fully* resolved, the function
// should return immediately with a nil ContractResolver value for the
Expand Down
3 changes: 3 additions & 0 deletions contractcourt/htlc_success_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ func (i *htlcResolverTestContext) resolve() {
i.resolverResultChan = make(chan resolveResult, 1)

go func() {
err := i.resolver.Launch()
require.NoError(i.t, err)

nextResolver, err := i.resolver.Resolve()
i.resolverResultChan <- resolveResult{
nextResolver: nextResolver,
Expand Down

0 comments on commit aeafa63

Please sign in to comment.