diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 07eab9e1..696b4636 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -6,8 +6,6 @@ import ( "sync" "sync/atomic" - "github.com/lightningnetwork/lnd/sweep" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" @@ -15,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/sweep" ) // ErrChainArbExiting signals that the chain arbitrator is shutting down. @@ -135,6 +134,11 @@ type ChainArbitratorConfig struct { // Sweeper allows resolvers to sweep their final outputs. Sweeper *sweep.UtxoSweeper + + // SettleInvoice attempts to settle an existing invoice on-chain with + // the given payment hash. ErrInvoiceNotFound is returned if an invoice + // is not found. + SettleInvoice func(chainhash.Hash, lnwire.MilliSatoshi) error } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 5e02e722..154821cb 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -243,8 +243,7 @@ func (c *ChannelArbitrator) Start() error { } var ( - err error - unresolvedContracts []ContractResolver + err error ) log.Debugf("Starting ChannelArbitrator(%v), htlc_set=%v", @@ -332,23 +331,10 @@ func (c *ChannelArbitrator) Start() error { if startingState == StateWaitingFullResolution && nextState == StateWaitingFullResolution { - // We'll now query our log to see if there are any active - // unresolved contracts. If this is the case, then we'll - // relaunch all contract resolvers. - unresolvedContracts, err = c.log.FetchUnresolvedContracts() - if err != nil { + if err := c.relaunchResolvers(); err != nil { c.cfg.BlockEpochs.Cancel() return err } - - log.Infof("ChannelArbitrator(%v): relaunching %v contract "+ - "resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) - - c.activeResolvers = unresolvedContracts - for _, contract := range unresolvedContracts { - c.wg.Add(1) - go c.resolveContract(contract) - } } // TODO(roasbeef): cancel if breached @@ -358,6 +344,123 @@ func (c *ChannelArbitrator) Start() error { return nil } +// relauchResolvers relaunches the set of resolvers for unresolved contracts in +// order to provide them with information that's not immediately available upon +// starting the ChannelArbitrator. This information should ideally be stored in +// the database, so this only serves as a intermediate work-around to prevent a +// migration. +func (c *ChannelArbitrator) relaunchResolvers() error { + // We'll now query our log to see if there are any active + // unresolved contracts. If this is the case, then we'll + // relaunch all contract resolvers. + unresolvedContracts, err := c.log.FetchUnresolvedContracts() + if err != nil { + return err + } + + // Retrieve the commitment tx hash from the log. + contractResolutions, err := c.log.FetchContractResolutions() + if err != nil { + log.Errorf("unable to fetch contract resolutions: %v", + err) + return err + } + commitHash := contractResolutions.CommitHash + + // Reconstruct the htlc outpoints and data from the chain action log. + // The purpose of the constructed htlc map is to supplement to resolvers + // restored from database with extra data. Ideally this data is stored + // as part of the resolver in the log. This is a workaround to prevent a + // db migration. + htlcMap := make(map[wire.OutPoint]*channeldb.HTLC) + chainActions, err := c.log.FetchChainActions() + if err != nil { + log.Errorf("unable to fetch chain actions: %v", err) + return err + } + for _, htlcs := range chainActions { + for _, htlc := range htlcs { + outpoint := wire.OutPoint{ + Hash: commitHash, + Index: uint32(htlc.OutputIndex), + } + htlcMap[outpoint] = &htlc + } + } + + log.Infof("ChannelArbitrator(%v): relaunching %v contract "+ + "resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) + + for _, resolver := range unresolvedContracts { + supplementResolver(resolver, htlcMap) + } + + c.launchResolvers(unresolvedContracts) + + return nil +} + +// supplementResolver takes a resolver as it is restored from the log and fills +// in missing data from the htlcMap. +func supplementResolver(resolver ContractResolver, + htlcMap map[wire.OutPoint]*channeldb.HTLC) error { + + switch r := resolver.(type) { + + case *htlcSuccessResolver: + return supplementSuccessResolver(r, htlcMap) + + case *htlcIncomingContestResolver: + return supplementSuccessResolver( + &r.htlcSuccessResolver, htlcMap, + ) + + case *htlcTimeoutResolver: + return supplementTimeoutResolver(r, htlcMap) + + case *htlcOutgoingContestResolver: + return supplementTimeoutResolver( + &r.htlcTimeoutResolver, htlcMap, + ) + } + + return nil +} + +// supplementSuccessResolver takes a htlcSuccessResolver as it is restored from +// the log and fills in missing data from the htlcMap. +func supplementSuccessResolver(r *htlcSuccessResolver, + htlcMap map[wire.OutPoint]*channeldb.HTLC) error { + + res := r.htlcResolution + htlcPoint := res.HtlcPoint() + htlc, ok := htlcMap[htlcPoint] + if !ok { + return errors.New( + "htlc for success resolver unavailable", + ) + } + r.htlcAmt = htlc.Amt + return nil +} + +// supplementTimeoutResolver takes a htlcSuccessResolver as it is restored from +// the log and fills in missing data from the htlcMap. +func supplementTimeoutResolver(r *htlcTimeoutResolver, + htlcMap map[wire.OutPoint]*channeldb.HTLC) error { + + res := r.htlcResolution + htlcPoint := res.HtlcPoint() + htlc, ok := htlcMap[htlcPoint] + if !ok { + return errors.New( + "htlc for timeout resolver unavailable", + ) + } + r.htlcAmt = htlc.Amt + return nil +} + // Stop signals the ChannelArbitrator for a graceful shutdown. func (c *ChannelArbitrator) Stop() error { if !atomic.CompareAndSwapInt32(&c.stopped, 0, 1) { @@ -703,11 +806,7 @@ func (c *ChannelArbitrator) stateStep(triggerHeight uint32, // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.activeResolvers = htlcResolvers - for _, contract := range htlcResolvers { - c.wg.Add(1) - go c.resolveContract(contract) - } + c.launchResolvers(htlcResolvers) nextState = StateWaitingFullResolution @@ -741,6 +840,15 @@ func (c *ChannelArbitrator) stateStep(triggerHeight uint32, return nextState, closeTx, nil } +// launchResolvers updates the activeResolvers list and starts the resolvers. +func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { + c.activeResolvers = resolvers + for _, contract := range resolvers { + c.wg.Add(1) + go c.resolveContract(contract) + } +} + // advanceState is the main driver of our state machine. This method is an // iterative function which repeatedly attempts to advance the internal state // of the channel arbitrator. The state will be advanced until we reach a @@ -1071,33 +1179,11 @@ func (c *ChannelArbitrator) prepContractResolutions(htlcActions ChainActionMap, inResolutionMap := make(map[wire.OutPoint]lnwallet.IncomingHtlcResolution) for i := 0; i < len(incomingResolutions); i++ { inRes := incomingResolutions[i] - - // If we have a success transaction, then the htlc's outpoint - // is the transaction's only input. Otherwise, it's the claim - // point. - var htlcPoint wire.OutPoint - if inRes.SignedSuccessTx != nil { - htlcPoint = inRes.SignedSuccessTx.TxIn[0].PreviousOutPoint - } else { - htlcPoint = inRes.ClaimOutpoint - } - - inResolutionMap[htlcPoint] = inRes + inResolutionMap[inRes.HtlcPoint()] = inRes } for i := 0; i < len(outgoingResolutions); i++ { outRes := outgoingResolutions[i] - - // If we have a timeout transaction, then the htlc's outpoint - // is the transaction's only input. Otherwise, it's the claim - // point. - var htlcPoint wire.OutPoint - if outRes.SignedTimeoutTx != nil { - htlcPoint = outRes.SignedTimeoutTx.TxIn[0].PreviousOutPoint - } else { - htlcPoint = outRes.ClaimOutpoint - } - - outResolutionMap[htlcPoint] = outRes + outResolutionMap[outRes.HtlcPoint()] = outRes } // We'll create the resolver kit that we'll be cloning for each @@ -1155,6 +1241,7 @@ func (c *ChannelArbitrator) prepContractResolutions(htlcActions ChainActionMap, htlcResolution: resolution, broadcastHeight: height, payHash: htlc.RHash, + htlcAmt: htlc.Amt, ResolverKit: resKit, } htlcResolvers = append(htlcResolvers, resolver) @@ -1182,6 +1269,7 @@ func (c *ChannelArbitrator) prepContractResolutions(htlcActions ChainActionMap, htlcResolution: resolution, broadcastHeight: height, htlcIndex: htlc.HtlcIndex, + htlcAmt: htlc.Amt, ResolverKit: resKit, } htlcResolvers = append(htlcResolvers, resolver) @@ -1215,6 +1303,7 @@ func (c *ChannelArbitrator) prepContractResolutions(htlcActions ChainActionMap, htlcResolution: resolution, broadcastHeight: height, payHash: htlc.RHash, + htlcAmt: htlc.Amt, ResolverKit: resKit, }, } @@ -1241,10 +1330,11 @@ func (c *ChannelArbitrator) prepContractResolutions(htlcActions ChainActionMap, resKit.Quit = make(chan struct{}) resolver := &htlcOutgoingContestResolver{ - htlcTimeoutResolver{ + htlcTimeoutResolver: htlcTimeoutResolver{ htlcResolution: resolution, broadcastHeight: height, htlcIndex: htlc.HtlcIndex, + htlcAmt: htlc.Amt, ResolverKit: resKit, }, } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index aae59f2b..e910de0a 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -175,6 +175,9 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, *lnwallet.IncomingHtlcResolution, uint32) error { return nil }, + SettleInvoice: func(chainhash.Hash, lnwire.MilliSatoshi) error { + return nil + }, } // We'll use the resolvedChan to synchronize on call to diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 56337c9e..3a35c850 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -5,6 +5,9 @@ import ( "fmt" "io" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwallet" @@ -46,6 +49,10 @@ type htlcSuccessResolver struct { // TODO(roasbeef): send off to utxobundler sweepTx *wire.MsgTx + // htlcAmt is the original amount of the htlc, not taking into + // account any fees that may have to be paid if it goes on chain. + htlcAmt lnwire.MilliSatoshi + ResolverKit } @@ -169,6 +176,14 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { return nil, fmt.Errorf("quitting") } + // With the HTLC claimed, we can attempt to settle its + // corresponding invoice if we were the original destination. + err = h.SettleInvoice(h.payHash, h.htlcAmt) + if err != nil && err != channeldb.ErrInvoiceNotFound { + log.Errorf("Unable to settle invoice with payment "+ + "hash %x: %v", h.payHash, err) + } + // Once the transaction has received a sufficient number of // confirmations, we'll mark ourselves as fully resolved and exit. h.resolved = true @@ -234,6 +249,14 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { return nil, fmt.Errorf("quitting") } + // With the HTLC claimed, we can attempt to settle its corresponding + // invoice if we were the original destination. + err = h.SettleInvoice(h.payHash, h.htlcAmt) + if err != nil && err != channeldb.ErrInvoiceNotFound { + log.Errorf("Unable to settle invoice with payment "+ + "hash %x: %v", h.payHash, err) + } + h.resolved = true return nil, h.Checkpoint(h) } diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 3b4adf1f..55e9cc00 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -40,6 +40,10 @@ type htlcTimeoutResolver struct { // additional commitment state machine. htlcIndex uint64 + // htlcAmt is the original amount of the htlc, not taking into + // account any fees that may have to be paid if it goes on chain. + htlcAmt lnwire.MilliSatoshi + ResolverKit } diff --git a/lnd_test.go b/lnd_test.go index be94669d..3881e84c 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -4,23 +4,20 @@ package main import ( "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/hex" "fmt" "io" "io/ioutil" "os" "path/filepath" + "reflect" "strings" + "sync/atomic" "testing" "time" - "sync/atomic" - - "encoding/hex" - "reflect" - - "crypto/rand" - "crypto/sha256" - "github.com/btcsuite/btcd/btcjson" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -9320,8 +9317,9 @@ func testMultiHopReceiverChainClaim(net *lntest.NetworkHarness, t *harnessTest) defer shutdownAndAssert(net, t, carol) // With the network active, we'll now add a new invoice at Carol's end. + const invoiceAmt = 100000 invoiceReq := &lnrpc.Invoice{ - Value: 100000, + Value: invoiceAmt, } ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) carolInvoice, err := carol.AddInvoice(ctxt, invoiceReq) @@ -9527,6 +9525,25 @@ func testMultiHopReceiverChainClaim(net *lntest.NetworkHarness, t *harnessTest) t.Fatalf(predErr.Error()) } + // The invoice should show as settled for Carol, indicating that it was + // swept on-chain. + invoicesReq := &lnrpc.ListInvoiceRequest{} + invoicesResp, err := carol.ListInvoices(ctxb, invoicesReq) + if err != nil { + t.Fatalf("unable to retrieve invoices: %v", err) + } + if len(invoicesResp.Invoices) != 1 { + t.Fatalf("expected 1 invoice, got %d", len(invoicesResp.Invoices)) + } + invoice := invoicesResp.Invoices[0] + if invoice.State != lnrpc.Invoice_SETTLED { + t.Fatalf("expected invoice to be settled on chain") + } + if invoice.AmtPaidSat != invoiceAmt { + t.Fatalf("expected invoice to be settled with %d sat, got "+ + "%d sat", invoiceAmt, invoice.AmtPaidSat) + } + // We'll close out the channel between Alice and Bob, then shutdown // carol to conclude the test. ctxt, _ = context.WithTimeout(ctxb, channelCloseTimeout) @@ -10368,8 +10385,9 @@ func testMultiHopHtlcRemoteChainClaim(net *lntest.NetworkHarness, t *harnessTest defer shutdownAndAssert(net, t, carol) // With the network active, we'll now add a new invoice at Carol's end. + const invoiceAmt = 100000 invoiceReq := &lnrpc.Invoice{ - Value: 100000, + Value: invoiceAmt, } ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) carolInvoice, err := carol.AddInvoice(ctxt, invoiceReq) @@ -10611,6 +10629,25 @@ func testMultiHopHtlcRemoteChainClaim(net *lntest.NetworkHarness, t *harnessTest if err != nil { t.Fatalf(predErr.Error()) } + + // The invoice should show as settled for Carol, indicating that it was + // swept on-chain. + invoicesReq := &lnrpc.ListInvoiceRequest{} + invoicesResp, err := carol.ListInvoices(ctxb, invoicesReq) + if err != nil { + t.Fatalf("unable to retrieve invoices: %v", err) + } + if len(invoicesResp.Invoices) != 1 { + t.Fatalf("expected 1 invoice, got %d", len(invoicesResp.Invoices)) + } + invoice := invoicesResp.Invoices[0] + if invoice.State != lnrpc.Invoice_SETTLED { + t.Fatalf("expected invoice to be settled on chain") + } + if invoice.AmtPaidSat != invoiceAmt { + t.Fatalf("expected invoice to be settled with %d sat, got "+ + "%d sat", invoiceAmt, invoice.AmtPaidSat) + } } // testSwitchCircuitPersistence creates a multihop network to ensure the sender diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 8572ed62..f10f20a7 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -5499,6 +5499,30 @@ func newIncomingHtlcResolution(signer Signer, localChanCfg *channeldb.ChannelCon }, nil } +// HtlcPoint returns the htlc's outpoint on the commitment tx. +func (r *IncomingHtlcResolution) HtlcPoint() wire.OutPoint { + // If we have a success transaction, then the htlc's outpoint + // is the transaction's only input. Otherwise, it's the claim + // point. + if r.SignedSuccessTx != nil { + return r.SignedSuccessTx.TxIn[0].PreviousOutPoint + } + + return r.ClaimOutpoint +} + +// HtlcPoint returns the htlc's outpoint on the commitment tx. +func (r *OutgoingHtlcResolution) HtlcPoint() wire.OutPoint { + // If we have a timeout transaction, then the htlc's outpoint + // is the transaction's only input. Otherwise, it's the claim + // point. + if r.SignedTimeoutTx != nil { + return r.SignedTimeoutTx.TxIn[0].PreviousOutPoint + } + + return r.ClaimOutpoint +} + // extractHtlcResolutions creates a series of outgoing HTLC resolutions, and // the local key used when generating the HTLC scrips. This function is to be // used in two cases: force close, or a unilateral close. diff --git a/server.go b/server.go index f12053c1..85178f81 100644 --- a/server.go +++ b/server.go @@ -728,7 +728,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, DisableChannel: func(op wire.OutPoint) error { return s.announceChanStatus(op, true) }, - Sweeper: s.sweeper, + Sweeper: s.sweeper, + SettleInvoice: s.invoices.SettleInvoice, }, chanDB) s.breachArbiter = newBreachArbiter(&BreachConfig{